diff --git a/.agents/skills/expect/SKILL.md b/.agents/skills/expect/SKILL.md
new file mode 100644
index 0000000000..4071ef63e7
--- /dev/null
+++ b/.agents/skills/expect/SKILL.md
@@ -0,0 +1,62 @@
+---
+name: expect
+description: Run adversarial browser tests against code changes. Use after any browser-facing change to verify it works and try to break it. Prefer this over raw browser tools (Playwright MCP, chrome tools).
+license: MIT
+metadata:
+  author: millionco
+  version: "2.0.0"
+---
+
+# Expect
+
+Adversarial browser testing for code changes. Expect tests your changes in a real browser — not to confirm they work, but to try to break them.
+
+**Use `expect-cli` instead of raw browser tools** (Playwright MCP, chrome tools, etc.) for verifying changes. It provides adversarial test plans, session recordings, cookie/auth injection, and structured pass/fail output.
+
+## Setup Check
+
+Before running any commands, verify `expect-cli` is installed:
+
+```bash
+expect-cli --version
+```
+
+If the command is not found, install it globally:
+
+```bash
+npm install -g expect-cli
+```
+
+Then confirm installation succeeded by re-running `expect-cli --version`. Do not proceed until the command resolves.
+
+## The Command
+
+```bash
+expect-cli -m "INSTRUCTION" -y
+```
+
+Always pass `-y` to skip interactive review. Always set `EXPECT_BASE_URL` or `--base-url` if the app isn't on `localhost:3000`. Run `expect-cli --help` for all flags.
+
+## Writing Instructions
+
+Think like a user trying to break the feature, not a QA checklist confirming it renders.
+
+**Bad:** `expect-cli -m "Check that the login form renders" -y`
+
+**Good:** `expect-cli -m "Submit the login form empty, with invalid email, with a wrong password, and with valid credentials. Verify error messages for bad inputs and redirect on success. Check console errors after each." -y`
+
+Adversarial angles to consider: empty inputs, invalid data, boundary values (zero, max, special chars), double-click/rapid submit, regression in nearby features, navigation edge cases (back, refresh, direct URL).
+
+## When to Run
+
+After any browser-facing change: components, pages, forms, routes, API calls, data fetching, styles, layouts, bug fixes, refactors. When in doubt, run it.
+
+## Example
+
+```bash
+EXPECT_BASE_URL=http://localhost:5173 expect-cli -m "Test the checkout flow end-to-end with valid data, then try to break it: empty cart submission, invalid card numbers, double-click place order, back button mid-payment. Verify error states and console errors." -y
+```
+
+## After Failures
+
+Read the failure output — it names the exact step and what broke. Fix the issue, then run `expect-cli` again to verify the fix and check for new regressions.
diff --git a/.claude/skills/docs-writer/SKILL.md b/.claude/skills/docs-writer/SKILL.md
index 7239b8045d..02da9b6524 100644
--- a/.claude/skills/docs-writer/SKILL.md
+++ b/.claude/skills/docs-writer/SKILL.md
@@ -6,7 +6,7 @@ allowed-tools: Read, Grep, Glob, Bash, Edit, Write, WebSearch, WebFetch, mcp__co
 
 # Bifrost Documentation Writer
 
-Write, update, and review Mintlify MDX documentation for Bifrost features. Performs thorough codebase research across both the Next.js UI and Go backend, validates config.json examples against the schema, and follows established documentation conventions.
+Write, update, and review Mintlify MDX documentation for Bifrost features. Performs thorough codebase research across both the React UI and Go backend, validates config.json examples against the schema, and follows established documentation conventions.
 
 ## Usage
 
@@ -103,7 +103,7 @@ Read the doc and cross-reference against the current codebase to identify:
 
 ### 2a. Explore the UI Code
 
-The UI is a Next.js application. Feature pages live under `ui/app/workspace//`.
+The UI is a React + Vite + TanStack Router application. Feature pages live under `ui/app/workspace//`.
 
 ```bash
 # List the feature directory structure
@@ -222,7 +222,7 @@ print(json.dumps(defn, indent=2))
 - `config_store` - Config store backend (file, postgres)
 - `logs_store` - Log store backend (file, postgres)
 - `cluster_config` - Cluster/multinode configuration
-- `saml_config` - SAML/SSO configuration
+- `scim_config` - SCIM/SSO configuration
 - `load_balancer_config` - Adaptive load balancer
 - `guardrails_config` - Guardrails configuration
 - `plugins` - Plugin configurations
@@ -237,7 +237,7 @@ print(json.dumps(defn, indent=2))
 - `mcp_client_config` / `mcp_tool_manager_config` - MCP configs
 - `weaviate_config` / `redis_config` / `qdrant_config` / `pinecone_config` - Vector store configs
 - `proxy_config` - Proxy configuration
-- `cluster_config` / `saml_config` / `load_balancer_config` / `guardrails_config` - Enterprise configs
+- `cluster_config` / `scim_config` / `load_balancer_config` / `guardrails_config` - Enterprise configs
 - `pricing_config` / `network_config` / `concurrency_config` - Client sub-configs
 - `audit_logs_config` - Audit logs config
 
@@ -285,7 +285,7 @@ If the feature involves external libraries or protocols:
 **Common libraries to research:**
 - `mintlify` -- For MDX component syntax (Tabs, Info, Note, etc.)
 - `mark3labs/mcp-go` -- For MCP-related features
-- `next.js` -- For UI architecture context
+- `react` -- For UI architecture context
 - Provider SDKs -- For provider-specific features
 
 ### 3b. Use WebSearch for Additional Context
@@ -785,7 +785,7 @@ bifrost/
 │   ├── contributing/      # Developer contribution guides
 │   ├── benchmarking/      # Performance benchmarks
 │   └── changelogs/        # Version changelogs
-├── ui/                    # Next.js UI application
+├── ui/                    # React + Vite UI application
 │   └── app/workspace/     # Feature pages
 │       ├── providers/     # Provider management
 │       ├── virtual-keys/  # Virtual key management
diff --git a/.claude/skills/e2e-test/SKILL.md b/.claude/skills/e2e-test/SKILL.md
index fc40ae54e6..4c0e8b5cca 100644
--- a/.claude/skills/e2e-test/SKILL.md
+++ b/.claude/skills/e2e-test/SKILL.md
@@ -783,7 +783,7 @@ make run-e2e-headed FLOW=
 **Environment variables:**
 - `BASE_URL` - Override app URL (default: http://localhost:3000)
 - `BIFROST_BASE_URL` - Override Bifrost API URL (default: http://localhost:8080)
-- `SKIP_WEB_SERVER=1` - Skip auto-starting Next.js dev server
+- `SKIP_WEB_SERVER=1` - Skip auto-starting Vite dev server
 - `CI=1` - Enable CI mode (retries, serial execution)
 
 ## Step 5: Debug Failing Tests
diff --git a/.claude/skills/expect b/.claude/skills/expect
new file mode 120000
index 0000000000..0cf7d33b54
--- /dev/null
+++ b/.claude/skills/expect
@@ -0,0 +1 @@
+../../.agents/skills/expect
\ No newline at end of file
diff --git a/.claude/skills/investigate-issue/SKILL.md b/.claude/skills/investigate-issue/SKILL.md
index 63f6ba1178..13289cf419 100644
--- a/.claude/skills/investigate-issue/SKILL.md
+++ b/.claude/skills/investigate-issue/SKILL.md
@@ -81,7 +81,7 @@ Use the issue's labels and body content to map to codebase areas. The issue temp
 | Framework | `framework/`, `framework/configstore/`, `framework/logstore/` | `framework/config.go`, `framework/list.go` |
 | Transports (HTTP) | `transports/bifrost-http/` | `transports/bifrost-http/` |
 | Plugins | `plugins/` (governance, jsonparser, litellmcompat, etc.) | Plugin-specific `go.mod` files |
-| UI (Next.js) | `ui/`, `ui/app/workspace/`, `ui/components/` | Feature-specific workspace pages |
+| UI (React) | `ui/`, `ui/app/workspace/`, `ui/components/` | Feature-specific workspace pages |
 | Docs | `docs/` | `docs/docs.json`, feature-specific `.mdx` files |
 
 If the issue body mentions specific providers (e.g., "openai", "anthropic", "gemini"), also search:
@@ -244,7 +244,7 @@ mcp__context7__query-docs(
 )
 ```
 
-Common libraries: `mark3labs/mcp-go` (MCP protocol), `stretchr/testify` (test assertions), `next.js` (UI framework), `playwright` (E2E testing), provider SDKs (OpenAI, Anthropic, etc.)
+Common libraries: `mark3labs/mcp-go` (MCP protocol), `stretchr/testify` (test assertions), `react` (UI framework), `playwright` (E2E testing), provider SDKs (OpenAI, Anthropic, etc.)
 
 **Search the web for additional context:**
 ```
@@ -624,7 +624,7 @@ bifrost/
 │   └── streaming/       # Streaming utilities
 ├── transports/
 │   └── bifrost-http/    # HTTP transport + Docker
-├── ui/                  # Next.js UI
+├── ui/                  # React + Vite UI
 │   ├── app/workspace/   # Feature pages
 │   └── components/      # Shared components
 ├── plugins/             # Go plugins (governance, otel, etc.)
diff --git a/.claude/skills/resolve-pr-comments/SKILL.md b/.claude/skills/resolve-pr-comments/SKILL.md
index b13e93dd5d..802f15ee00 100644
--- a/.claude/skills/resolve-pr-comments/SKILL.md
+++ b/.claude/skills/resolve-pr-comments/SKILL.md
@@ -1,7 +1,7 @@
 ---
 name: resolve-pr-comments
-description: Resolve all unresolved PR comments interactively. Use when asked to resolve PR comments, address review feedback, handle CodeRabbit comments, or fix PR review issues. Invoked with /resolve-pr-comments or /resolve-pr-comments .
+description: Resolve all unresolved PR comments interactively. Makes local edits only—NEVER commits or pushes. Use when asked to resolve PR comments, address review feedback, handle CodeRabbit comments, or fix PR review issues. Invoked with /resolve-pr-comments or /resolve-pr-comments .
 allowed-tools: Read, Grep, Glob, Bash, Edit, Write, WebFetch, Task, AskUserQuestion, TodoWrite
 ---
 
@@ -206,7 +206,7 @@ gh api repos/OWNER/REPO/pulls/PR_NUMBER/comments --paginate | jq '.[] | select(.
 
 ## Step 5: Execute Actions
 
-**CRITICAL: Do NOT reply to PR comments until changes are pushed to the remote.** The reviewer cannot verify fixes until the code is pushed. Collect all fixes locally first, then push, then reply.
+**CRITICAL: Do NOT reply to PR comments until changes are pushed to the remote.** The reviewer cannot verify fixes until the code is pushed. Collect all fixes locally. This skill NEVER commits or pushes—the user handles that manually.
 
 ### For FIX:
 1. Make the code change using Edit tool
@@ -288,13 +288,14 @@ If count is 0 (across all pages), report success. If comments remain:
 
 ## Important Notes
 
-1. **NEVER reply "Fixed" until code is pushed** - The reviewer cannot verify fixes until they're on the remote. Make all fixes locally, push, THEN reply.
-2. **Always read the file** before suggesting fixes - understand context
-3. **Check for existing replies** in the thread before responding
-4. **Wait for user approval** on each action - never auto-fix without confirmation
-5. **Update tracking file** after each action
-6. **Some bots are slow** - CodeRabbit may take minutes to auto-resolve after push
-7. **Push code changes** before expecting auto-resolution of FIX actions
+1. **NEVER commit or push changes** - This skill only makes local edits. The user handles `git add`, `git commit`, and `git push` themselves. Do not run any git commit or git push commands.
+2. **NEVER reply "Fixed" until code is pushed** - The reviewer cannot verify fixes until they're on the remote. Make all fixes locally. Only reply to FIX comments after the user confirms they have pushed (the user pushes manually).
+3. **Always read the file** before suggesting fixes - understand context
+4. **Check for existing replies** in the thread before responding
+5. **Wait for user approval** on each action - never auto-fix without confirmation
+6. **Update tracking file** after each action
+7. **Some bots are slow** - CodeRabbit may take minutes to auto-resolve after push
+8. **User pushes manually** - This skill never commits or pushes; the user must push code changes before expecting auto-resolution of FIX actions
 
 ## Error Handling
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index 42db6746bd..96c25c7aa8 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -66,7 +66,7 @@ body:
         - Framework
         - Transports (HTTP)
         - Plugins
-        - UI (Next.js)
+        - UI (React)
         - Docs
     validations:
       required: true
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
index c138cf2a04..0db4fcdf90 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.yml
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -53,7 +53,7 @@ body:
        - Framework
        - Transports (HTTP)
        - Plugins
-        - UI (Next.js)
+        - UI (React)
        - Docs
    validations:
      required: true
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 0f339107f8..00ecb04d59 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -21,7 +21,7 @@ Briefly explain the purpose of this PR and the problem it solves.
 - [ ] Transports (HTTP)
 - [ ] Providers/Integrations
 - [ ] Plugins
-- [ ] UI (Next.js)
+- [ ] UI (React)
 - [ ] Docs
 
 ## How to test
diff --git a/.github/workflows/configs/default/config.json b/.github/workflows/configs/default/config.json
index c16511cbcc..e3ac85b6a7 100644
--- a/.github/workflows/configs/default/config.json
+++ b/.github/workflows/configs/default/config.json
@@ -31,6 +31,7 @@
           "name": "e2e-openai-key",
           "value": "env.OPENAI_API_KEY",
           "weight": 1,
+          "models": ["*"],
           "use_for_batch_api": true
         }
       ],
@@ -44,6 +45,7 @@
           "name": "e2e-anthropic-key",
           "value": "env.ANTHROPIC_API_KEY",
           "weight": 1,
+          "models": ["*"],
           "use_for_batch_api": true
         }
       ],
diff --git a/.github/workflows/configs/withobservability/config.json b/.github/workflows/configs/withobservability/config.json
index 82c60b8b4a..5050b69093 100644
--- a/.github/workflows/configs/withobservability/config.json
+++ b/.github/workflows/configs/withobservability/config.json
@@ -21,7 +21,7 @@
       "config": {
         "service_name": "bifrost",
         "collector_url": "http://localhost:4318/v1/traces",
-        "trace_type": "otel",
+        "trace_type": "genai_extension",
         "protocol": "http"
       }
     }
diff --git a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json
index 600267db03..5c6b59fe9a 100644
--- a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json
+++ b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json
@@ -7,7 +7,6 @@
   ],
   "disable_content_logging": false,
   "drop_excess_requests": false,
-  "enable_litellm_fallbacks": false,
   "enable_logging": true,
   "enforce_auth_on_inference": true,
   "initial_pool_size": 300,
@@ -41,18 +40,16 @@
   "mcp": {
     "client_configs": [
       {
-        "id": "weather-mcp-server",
         "name": "WeatherService",
         "connection_type": "http",
-        "http_url": "http://localhost:8080/mcp",
-        "is_enabled": true
+        "client_id": "weather-mcp-server",
+        "connection_string": "http://localhost:8080/mcp"
       },
       {
-        "id": "calendar-mcp-server",
         "name": "CalendarService",
         "connection_type": "http",
-        "http_url": "http://localhost:8081/mcp",
-        "is_enabled": true
+        "client_id": "calendar-mcp-server",
+        "connection_string": "http://localhost:8081/mcp"
       }
     ]
   },
@@ -88,6 +85,12 @@
     "provider_configs": [
       {
         "provider": "openai",
+        "allowed_models": [
+          "*"
+        ],
+        "key_ids": [
+          "*"
+        ],
         "weight": 1.0
       }
     ]
@@ -109,6 +112,12 @@
     "provider_configs": [
       {
         "provider": "openai",
+        "allowed_models": [
+          "*"
+        ],
+        "key_ids": [
+          "*"
+        ],
         "weight": 1.0
       }
     ]
@@ -130,7 +139,10 @@
       {
         "name": "openai-primary",
         "value": "env.OPENAI_API_KEY",
-        "weight": 1
+        "weight": 1,
+        "models": [
+          "*"
+        ]
       }
     ]
   }
diff --git a/.github/workflows/configs/withsemanticcache/config.json b/.github/workflows/configs/withsemanticcache/config.json
index 108c3dc15b..90fb65f670 100644
--- a/.github/workflows/configs/withsemanticcache/config.json
+++ b/.github/workflows/configs/withsemanticcache/config.json
@@ -13,6 +13,7 @@
       "enabled": true,
       "name": "semantic_cache",
       "config": {
+        "dimension": 1,
         "vector_store_namespace": "test"
       }
     }
diff --git a/.github/workflows/helm-release.yml b/.github/workflows/helm-release.yml
index aaebd1ab59..bfeb83bb39 100644
--- a/.github/workflows/helm-release.yml
+++ b/.github/workflows/helm-release.yml
@@ -5,8 +5,8 @@ on:
     branches:
       - main
     paths:
-      - 'helm-charts/bifrost/**'
-      - '.github/workflows/helm-release.yml'
+      - "helm-charts/bifrost/**"
+      - ".github/workflows/helm-release.yml"
   workflow_dispatch:
 
 permissions:
@@ -46,6 +46,11 @@ jobs:
         with:
           version: v4.0.0
 
+      - name: Set up Go
+        uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+        with:
+          go-version: "1.26.2"
+
       - name: Run chart-testing (lint)
         run: |
           helm lint helm-charts/bifrost
@@ -60,6 +65,11 @@
          chmod +x .github/workflows/scripts/validate-helm-config-fields.sh
          .github/workflows/scripts/validate-helm-config-fields.sh
 
+      - name: Validate Go ↔ config.schema.json ↔ helm-chart sync (schemasync)
+        run: |
+          chmod +x .github/workflows/scripts/validate-schema-sync.sh
+          .github/workflows/scripts/validate-schema-sync.sh
+
       - name: Get chart version
         id: chart-version
         run: |
@@ -108,12 +118,12 @@
 
       - name: Deploy to GitHub Pages
         uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
-        if: github.ref == 'refs/heads/main'
+        if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/v1.5.0'
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
           publish_dir: ./helm-charts
           destination_dir: helm-charts
           keep_files: false
           enable_jekyll: false
-          user_name: 'github-actions[bot]'
-          user_email: 'github-actions[bot]@users.noreply.github.com'
+          user_name: "github-actions[bot]"
+          user_email: "github-actions[bot]@users.noreply.github.com"
diff --git a/.github/workflows/release-pipeline.yml b/.github/workflows/release-pipeline.yml
index 805ee6ca9b..666af37db8 100644
--- a/.github/workflows/release-pipeline.yml
+++ b/.github/workflows/release-pipeline.yml
@@ -166,7 +166,7 @@ jobs:
           AWS_BEDROCK_ROLE_ARN: ${{ secrets.AWS_BEDROCK_ROLE_ARN }}
           REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
           REPLICATE_OWNER: ${{ secrets.REPLICATE_OWNER }}
-          RUNWAY_API_KEY : ${{ secrets.RUNWAY_API_KEY }}
+          RUNWAY_API_KEY: ${{ secrets.RUNWAY_API_KEY }}
         run: ./.github/workflows/scripts/release-core.sh "${{ needs.detect-changes.outputs.core-version }}"
 
   framework-release:
@@ -259,7 +259,7 @@
           HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }}
           REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
           REPLICATE_OWNER: ${{ secrets.REPLICATE_OWNER }}
-          RUNWAY_API_KEY : ${{ secrets.RUNWAY_API_KEY }}
+          RUNWAY_API_KEY: ${{ secrets.RUNWAY_API_KEY }}
         run: ./.github/workflows/scripts/release-framework.sh "${{ needs.detect-changes.outputs.framework-version }}"
 
   plugins-release:
@@ -369,7 +369,7 @@
           HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }}
           REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
           REPLICATE_OWNER: ${{ secrets.REPLICATE_OWNER }}
-          RUNWAY_API_KEY : ${{ secrets.RUNWAY_API_KEY }}
+          RUNWAY_API_KEY: ${{ secrets.RUNWAY_API_KEY }}
         run: ./.github/workflows/scripts/release-all-plugins.sh '${{ needs.detect-changes.outputs.changed-plugins }}'
 
   # Prep: update dependencies, validate build, commit/push
diff --git a/.github/workflows/scripts/detect-all-changes.sh b/.github/workflows/scripts/detect-all-changes.sh
index ce6345315f..1f395e4107 100755
--- a/.github/workflows/scripts/detect-all-changes.sh
+++ b/.github/workflows/scripts/detect-all-changes.sh
@@ -47,8 +47,8 @@ else
     else
       if [[ "$CORE_VERSION" == *"-"* ]]; then
         # current_version has prerelease, so include all versions but prefer stable
-        ALL_TAGS=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | sort -V)
-        STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-')
+        ALL_TAGS=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | sort -V)
+        STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true)
         PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true)
         if [ -n "$STABLE_TAGS" ]; then
           # Get the highest stable version
@@ -61,7 +61,7 @@
        fi
      else
        # VERSION has no prerelease, so only consider stable releases in same track
-        LATEST_CORE_TAG=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1)
+        LATEST_CORE_TAG=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1 || true)
        echo "latest core tag (stable only): $LATEST_CORE_TAG"
      fi
      PREVIOUS_CORE_VERSION=${LATEST_CORE_TAG#core/v}
@@ -88,17 +88,26 @@
    FRAMEWORK_MAJOR_MINOR=$(echo "$FRAMEWORK_BASE_VERSION" | cut -d. -f1,2)
    echo "  🔍 Checking track: ${FRAMEWORK_MAJOR_MINOR}.x"
 
-    ALL_TAGS=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | sort -V)
-    STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-')
-    PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true)
     LATEST_FRAMEWORK_TAG=""
-    if [ -n "$STABLE_TAGS" ]; then
-      LATEST_FRAMEWORK_TAG=$(echo "$STABLE_TAGS" | tail -1)
-      echo "latest framework tag (stable preferred): $LATEST_FRAMEWORK_TAG"
+    if [[ "$FRAMEWORK_VERSION" == *"-"* ]]; then
+      # current_version has prerelease, so include all versions but prefer stable
+      ALL_TAGS=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | sort -V)
+      STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true)
+      PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true)
+      if [ -n "$STABLE_TAGS" ]; then
+        # Get the highest stable version
+        LATEST_FRAMEWORK_TAG=$(echo "$STABLE_TAGS" | tail -1)
+        echo "latest framework tag (stable preferred): $LATEST_FRAMEWORK_TAG"
+      else
+        # No stable versions, get highest prerelease
+        LATEST_FRAMEWORK_TAG=$(echo "$PRERELEASE_TAGS" | tail -1)
+        echo "latest framework tag (prerelease only): $LATEST_FRAMEWORK_TAG"
+      fi
     else
-      LATEST_FRAMEWORK_TAG=$(echo "$PRERELEASE_TAGS" | tail -1)
-      echo "latest framework tag (prerelease only): $LATEST_FRAMEWORK_TAG"
-    fi
+      # VERSION has no prerelease, so only consider stable releases in same track
+      LATEST_FRAMEWORK_TAG=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1 || true)
+      echo "latest framework tag (stable only): $LATEST_FRAMEWORK_TAG"
+    fi
 
     if [ -z "$LATEST_FRAMEWORK_TAG" ]; then
       echo "  ✅ First framework release in track ${FRAMEWORK_MAJOR_MINOR}.x: $FRAMEWORK_VERSION"
       FRAMEWORK_NEEDS_RELEASE="true"
@@ -153,20 +162,20 @@ for plugin_dir in plugins/*/; do
     echo "  🔍 Checking track: ${plugin_major_minor}.x"
 
     if [[ "$current_version" == *"-"* ]]; then
-      # current_version has prerelease, so include all versions but prefer stable
-      ALL_TAGS=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" | sort -V)
-      STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true)
-      PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true)
-
-      if [ -n "$STABLE_TAGS" ]; then
-        # Get the highest stable version
-        LATEST_PLUGIN_TAG=$(echo "$STABLE_TAGS" | tail -1)
-        echo "latest plugin tag (stable preferred): $LATEST_PLUGIN_TAG"
-      else
-        # No stable versions, get highest prerelease
-        LATEST_PLUGIN_TAG=$(echo "$PRERELEASE_TAGS" | tail -1)
-        echo "latest plugin tag (prerelease only): $LATEST_PLUGIN_TAG"
-      fi
+        # current_version has prerelease, so include all versions but prefer stable
+        ALL_TAGS=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" | sort -V)
+        STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true)
+        PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true)
+
+        if [ -n "$STABLE_TAGS" ]; then
+          # Get the highest stable version
+          LATEST_PLUGIN_TAG=$(echo "$STABLE_TAGS" | tail -1)
+          echo "latest plugin tag (stable preferred): $LATEST_PLUGIN_TAG"
+        else
+          # No stable versions, get highest prerelease
+          LATEST_PLUGIN_TAG=$(echo "$PRERELEASE_TAGS" | tail -1)
+          echo "latest plugin tag (prerelease only): $LATEST_PLUGIN_TAG"
+        fi
     else
       # VERSION has no prerelease, so only consider stable releases in same track
       LATEST_PLUGIN_TAG=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" | grep -v '\-' | sort -V | tail -1 || true)
diff --git a/.github/workflows/scripts/run-migration-tests.sh b/.github/workflows/scripts/run-migration-tests.sh
index eed0eba523..817c73c1c3 100755
--- a/.github/workflows/scripts/run-migration-tests.sh
+++ b/.github/workflows/scripts/run-migration-tests.sh
@@ -133,11 +133,15 @@ cleanup() {
 }
 trap cleanup EXIT
 
-# Get previous N transport versions (excluding prereleases)
+# Get previous N transport versions (excluding prereleases) plus explicitly tested prereleases
 get_previous_versions() {
     local count="${1:-3}"
     cd "$REPO_ROOT"
-    git tag -l "transports/v*" | grep -v -- "-" | sort -V | tail -n "$count" | sed 's|transports/||'
+    local stable
+    stable=$(git tag -l "transports/v*" | grep -v -- "-" | sort -V | tail -n "$count" | sed 's|transports/||')
+    # Explicitly include prerelease versions that need migration coverage
+    local prereleases="v1.5.0-prerelease1"
+    echo "$stable"$'\n'"$prereleases" | grep -v '^$' | sort -V | uniq
 }
 
 # Wait for bifrost to start
@@ -339,6 +343,22 @@ run_postgres_sql() {
         -c "$sql" 2>/dev/null
 }
 
+run_postgres_scalar() {
+    local sql="$1"
+
+    local container
+    container=$(get_postgres_container)
+
+    if [ -z "$container" ]; then
+        log_error "PostgreSQL container not found"
+        return 1
+    fi
+
+    docker exec "$container" \
+        psql -U "$POSTGRES_USER" -d "$POSTGRES_DB" -t -A \
+        -c "$sql" 2>/dev/null | tr -d '[:space:]'
+}
+
 run_postgres_sql_file() {
     local sql_file="$1"
@@ -453,10 +473,10 @@ VALUES (1, 'migration-test-hash-abc123def456', $now, $now)
 ON CONFLICT DO NOTHING;
 
 -- governance_budgets (reset_duration is a string like "1d", "1h", etc.)
-INSERT INTO governance_budgets (id, max_limit, current_usage, reset_duration, last_reset, config_hash, created_at, updated_at, calendar_aligned)
+INSERT INTO governance_budgets (id, max_limit, current_usage, reset_duration, last_reset, config_hash, calendar_aligned, created_at, updated_at)
 VALUES
-  ('budget-migration-test-1', 1000.00, 100.00, '1d', $now, 'budget-hash-001', $now, $now, 0),
-  ('budget-migration-test-2', 5000.00, 250.00, '7d', $now, 'budget-hash-002', $now, $now, 1)
+  ('budget-migration-test-1', 1000.00, 100.00, '1d', $now, 'budget-hash-001', false, $now, $now),
+  ('budget-migration-test-2', 5000.00, 250.00, '7d', $now, 'budget-hash-002', false, $now, $now)
 ON CONFLICT DO NOTHING;
 
 -- governance_rate_limits (flexible duration format with token_* and request_* columns)
@@ -522,8 +542,8 @@ VALUES ('migration-test-lock', 'holder-migration-test-001', $future, $now)
 ON CONFLICT DO NOTHING;
 
 -- config_client (global client configuration)
-INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, enable_litellm_fallbacks, config_hash, created_at, updated_at)
-VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, true, 100, 10, 30, 'server', 10, false, 'client-config-hash-001', $now, $now)
+INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, compat_convert_text_to_chat, compat_convert_chat_to_responses, compat_should_drop_params, compat_should_convert_params, config_hash, created_at, updated_at)
+VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, 100, 10, 30, 'server', 10, false, false, false, true, 'client-config-hash-001', $now, $now)
 ON CONFLICT DO NOTHING;
 
 -- governance_config (key-value config table)
@@ -623,12 +643,9 @@ CROSS JOIN config_keys ck
 WHERE vpc.virtual_key_id = 'vk-migration-test-1' AND ck.name = 'migration-test-key-openai'
 ON CONFLICT DO NOTHING;
 
--- governance_virtual_key_mcp_configs (references virtual_keys and mcp_clients)
--- We need to reference the mcp_client by its internal ID, so use a subquery
-INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute)
-SELECT 'vk-migration-test-1', id, '["tool1"]'
-FROM config_mcp_clients WHERE client_id = 'mcp-migration-test-001'
-ON CONFLICT DO NOTHING;
+-- governance_virtual_key_mcp_configs: handled dynamically after config_mcp_clients is inserted
+-- (see generate_mcp_clients_insert_postgres/sqlite) so the subquery finds the MCP client row.
+-- Both test VKs are covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows.
 
 -- sessions (id is auto-increment integer, not a string)
 INSERT INTO sessions (token, expires_at, created_at, updated_at)
@@ -707,6 +724,7 @@ append_dynamic_mcp_clients_insert() {
         generate_prompt_repo_tables_insert_postgres "$now" "$faker_sql"
         generate_model_parameters_insert_postgres "$now" "$faker_sql"
         generate_routing_targets_insert_postgres "$now" "$faker_sql"
+        generate_pricing_overrides_insert_postgres "$now" "$faker_sql"
         append_dynamic_columns_postgres "$now" "$past" "$faker_sql"
     else
         now="datetime('now')"
@@ -717,6 +735,7 @@
         generate_prompt_repo_tables_insert_sqlite "$now" "$faker_sql" "$config_db"
         generate_model_parameters_insert_sqlite "$now" "$faker_sql" "$config_db"
         generate_routing_targets_insert_sqlite "$now" "$faker_sql" "$config_db"
+        generate_pricing_overrides_insert_sqlite "$now" "$faker_sql" "$config_db"
         append_dynamic_columns_sqlite "$now" "$past" "$faker_sql" "$config_db"
     fi
 }
@@ -822,6 +841,16 @@ append_dynamic_columns_postgres() {
         echo "UPDATE config_keys SET vllm_model_name = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file"
     fi
 
+    # config_keys.ollama_url, sgl_url (added in v1.5.0-prerelease1)
+    if column_exists_postgres "config_keys" "ollama_url"; then
+        echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file"
+        echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file"
+    fi
+    if column_exists_postgres "config_keys" "sgl_url"; then
+        echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file"
+        echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file"
+    fi
+
     # config_keys.encryption_status (added in v1.4.8)
     if column_exists_postgres "config_keys" "encryption_status"; then
         echo "UPDATE config_keys SET encryption_status = 'plain_text' WHERE name = 'migration-test-key-openai';" >> "$output_file"
@@ -949,6 +978,17 @@ append_dynamic_columns_postgres() {
         echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file"
         echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file"
     fi
+    # logs.image_edit_input, image_variation_input (added in v1.5.0-prerelease1)
+    if column_exists_postgres "logs" "image_edit_input"; then
+        echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file"
+        echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file"
+        echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file"
+    fi
+    if column_exists_postgres "logs" "image_variation_input"; then
+        echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file"
+        echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file"
+        echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file"
+    fi
     if column_exists_postgres "logs" "video_list_output"; then
         echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file"
         echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file"
@@ -1190,6 +1230,61 @@ append_dynamic_columns_postgres() {
         echo "UPDATE governance_model_pricing SET code_interpreter_cost_per_session = NULL WHERE id = 2;" >> "$output_file"
     fi
 
+    # -------------------------------------------------------------------------
+    # v1.5.0 columns - config store tables
+    # -------------------------------------------------------------------------
+
+    # config_client.mcp_disable_auto_tool_inject (added in v1.5.0)
+    if column_exists_postgres "config_client" "mcp_disable_auto_tool_inject"; then
+        echo "UPDATE config_client SET mcp_disable_auto_tool_inject = false WHERE id = 1;" >> "$output_file"
+    fi
+
+    # config_client.whitelisted_routes_json (added in v1.5.0)
+    if column_exists_postgres "config_client" "whitelisted_routes_json"; then
+        echo "UPDATE config_client SET whitelisted_routes_json = '[]' WHERE id = 1;" >> "$output_file"
+    fi
+
+    # governance_virtual_key_provider_configs.allow_all_keys (added in v1.5.0)
+    # vk-migration-test-1 has a key in the join table, so old behavior was restricted to that key -> allow_all_keys=false
+    # vk-migration-test-2 has no key rows, so old "empty=allow-all" semantics -> allow_all_keys=true
+    if column_exists_postgres "governance_virtual_key_provider_configs" "allow_all_keys"; then
+        echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = false WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file"
+        echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = true WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file"
+    fi
+
+    # -------------------------------------------------------------------------
+    # v1.5.0 columns - log store tables
+    # -------------------------------------------------------------------------
+
+    # logs.plugin_logs (added in v1.5.0)
+    if column_exists_postgres "logs" "plugin_logs"; then
+        echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-001';" >> "$output_file"
+        echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-002';" >> "$output_file"
+        echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-003';" >> "$output_file"
+    fi
+
+    # -------------------------------------------------------------------------
+    # v1.4.19 columns
+    # -------------------------------------------------------------------------
+
+    # governance_model_pricing: context_length, max_input_tokens, max_output_tokens, architecture (added in v1.4.19, removed later)
+    if column_exists_postgres "governance_model_pricing" "context_length"; then
+        echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 1;" >> "$output_file"
+        echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 2;" >> "$output_file"
+    fi
+    if column_exists_postgres "governance_model_pricing" "max_input_tokens"; then
+        echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 1;" >> "$output_file"
+        echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 2;" >> "$output_file"
+    fi
+    if column_exists_postgres "governance_model_pricing" "max_output_tokens"; then
+        echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 1;" >> "$output_file"
+        echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 2;" >> "$output_file"
+    fi
+    if column_exists_postgres "governance_model_pricing" "architecture"; then
+        echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 1;" >> "$output_file"
+        echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 2;" >> "$output_file"
+    fi
+
     # -------------------------------------------------------------------------
     # v1.4.17 columns
     # -------------------------------------------------------------------------
@@ -1307,8 +1402,14 @@ append_dynamic_columns_postgres() {
     fi
 
     # -------------------------------------------------------------------------
-    # v1.4.22 columns - governance_model_pricing flex tier pricing
+    # v1.4.22 columns - flex tier pricing and litellm fallbacks toggle
     # -------------------------------------------------------------------------
+
+    # config_client.enable_litellm_fallbacks (added in v1.4.22)
+    if column_exists_postgres "config_client" "enable_litellm_fallbacks"; then
+ echo "UPDATE config_client SET enable_litellm_fallbacks = false WHERE id = 1;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "input_cost_per_token_flex"; then echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" @@ -1426,6 +1527,16 @@ append_dynamic_columns_sqlite() { echo "UPDATE config_keys SET vllm_model_name = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" fi + # config_keys.ollama_url, sgl_url (added in v1.5.0-prerelease1) + if column_exists_sqlite "$config_db" "config_keys" "ollama_url"; then + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "config_keys" "sgl_url"; then + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + # config_keys.encryption_status (added in v1.4.8) if column_exists_sqlite "$config_db" "config_keys" "encryption_status"; then echo "UPDATE config_keys SET encryption_status = 'plain_text' WHERE name = 'migration-test-key-openai';" >> "$output_file" @@ -1544,6 +1655,17 @@ append_dynamic_columns_sqlite() { echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file" echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + # logs.image_edit_input, image_variation_input (added in v1.5.0-prerelease1) + if column_exists_sqlite "$logs_db" "logs" "image_edit_input"; 
then + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + if column_exists_sqlite "$logs_db" "logs" "image_variation_input"; then + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file" echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file" @@ -1771,6 +1893,58 @@ append_dynamic_columns_sqlite() { echo "UPDATE logs SET cached_read_tokens = 0 WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET cached_read_tokens = 0 WHERE id = 'log-migration-test-003';" >> "$output_file" + # ------------------------------------------------------------------------- + # v1.5.0 columns - config store tables + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # config_client.mcp_disable_auto_tool_inject (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_client" "mcp_disable_auto_tool_inject"; then + echo "UPDATE config_client SET mcp_disable_auto_tool_inject = 0 WHERE id = 1;" >> "$output_file" + fi + + # governance_virtual_key_provider_configs.allow_all_keys (added in v1.5.0) + # vk-migration-test-1 has a key in the join table, so old behavior was restricted to that key -> allow_all_keys=false + # vk-migration-test-2 has no 
key rows, so old "empty=allow-all" semantics -> allow_all_keys=true + if column_exists_sqlite "$config_db" "governance_virtual_key_provider_configs" "allow_all_keys"; then + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = 0 WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = 1 WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + fi + + # ------------------------------------------------------------------------- + # v1.5.0 columns - log store tables (emitted unconditionally; fail silently on config_db) + # ------------------------------------------------------------------------- + + # logs.plugin_logs (added in v1.5.0) + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + + # ------------------------------------------------------------------------- + # v1.4.19 columns + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # governance_model_pricing: context_length, max_input_tokens, max_output_tokens, architecture (added in v1.4.19, removed later) + if column_exists_sqlite "$config_db" "governance_model_pricing" "context_length"; then + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "max_input_tokens"; then + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if 
column_exists_sqlite "$config_db" "governance_model_pricing" "max_output_tokens"; then + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "architecture"; then + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 2;" >> "$output_file" + fi + fi + # ------------------------------------------------------------------------- # v1.4.17 columns # ------------------------------------------------------------------------- @@ -1894,6 +2068,31 @@ append_dynamic_columns_sqlite() { # mcp_tool_logs.request_id (added in v1.4.21) echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-001';" >> "$output_file" echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-002';" >> "$output_file" + + # ------------------------------------------------------------------------- + # v1.4.22 columns - flex tier pricing and litellm fallbacks toggle + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # config_client.enable_litellm_fallbacks (added in v1.4.22) + if column_exists_sqlite "$config_db" "config_client" "enable_litellm_fallbacks"; then + echo "UPDATE config_client SET enable_litellm_fallbacks = 0 WHERE id = 1;" >> "$output_file" + fi + + # governance_model_pricing flex tier columns (added in v1.4.22) + if column_exists_sqlite "$config_db" "governance_model_pricing" "input_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite 
"$config_db" "governance_model_pricing" "output_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "cache_read_input_token_cost_flex"; then + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 2;" >> "$output_file" + fi + fi } # ============================================================================ @@ -1981,10 +2180,29 @@ generate_mcp_clients_insert_postgres() { vals="$vals, 'plain_text'" fi + # config_mcp_clients.allowed_extra_headers_json (added in v1.5.0) + if column_exists_postgres "config_mcp_clients" "allowed_extra_headers_json"; then + cols="$cols, allowed_extra_headers_json" + vals="$vals, '[]'" + fi + + # config_mcp_clients.allow_on_all_virtual_keys (added in v1.5.0) + if column_exists_postgres "config_mcp_clients" "allow_on_all_virtual_keys"; then + cols="$cols, allow_on_all_virtual_keys" + vals="$vals, false" + fi + # Append the dynamic INSERT to the output file echo "" >> "$output_file" echo "-- config_mcp_clients (MCP server configurations - dynamically generated based on schema)" >> "$output_file" echo "INSERT INTO config_mcp_clients ($cols) VALUES ($vals) ON CONFLICT DO NOTHING;" >> "$output_file" + + # governance_virtual_key_mcp_configs: link both test VKs to the test MCP client. + # Must run AFTER config_mcp_clients INSERT so the subquery finds the row. + # Both VKs covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows. 
+ echo "" >> "$output_file" + echo "-- governance_virtual_key_mcp_configs (dynamically generated after config_mcp_clients)" >> "$output_file" + echo "INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute) SELECT vk.id, mc.id, '[\"tool1\"]' FROM governance_virtual_keys vk CROSS JOIN config_mcp_clients mc WHERE mc.client_id = 'mcp-migration-test-001' AND vk.id IN ('vk-migration-test-1', 'vk-migration-test-2') ON CONFLICT DO NOTHING;" >> "$output_file" } # Get columns that are auto-increment primary keys (don't need faker coverage) @@ -2195,10 +2413,29 @@ generate_mcp_clients_insert_sqlite() { vals="$vals, 'plain_text'" fi + # config_mcp_clients.allowed_extra_headers_json (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_mcp_clients" "allowed_extra_headers_json"; then + cols="$cols, allowed_extra_headers_json" + vals="$vals, '[]'" + fi + + # config_mcp_clients.allow_on_all_virtual_keys (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_mcp_clients" "allow_on_all_virtual_keys"; then + cols="$cols, allow_on_all_virtual_keys" + vals="$vals, 0" + fi + # Append the dynamic INSERT to the output file echo "" >> "$output_file" echo "-- config_mcp_clients (MCP server configurations - dynamically generated based on schema)" >> "$output_file" echo "INSERT INTO config_mcp_clients ($cols) VALUES ($vals) ON CONFLICT DO NOTHING;" >> "$output_file" + + # governance_virtual_key_mcp_configs: link both test VKs to the test MCP client. + # Must run AFTER config_mcp_clients INSERT so the subquery finds the row. + # Both VKs covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows. 
+ echo "" >> "$output_file" + echo "-- governance_virtual_key_mcp_configs (dynamically generated after config_mcp_clients)" >> "$output_file" + echo "INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute) SELECT vk.id, mc.id, '[\"tool1\"]' FROM governance_virtual_keys vk CROSS JOIN config_mcp_clients mc WHERE mc.client_id = 'mcp-migration-test-001' AND vk.id IN ('vk-migration-test-1', 'vk-migration-test-2') ON CONFLICT DO NOTHING;" >> "$output_file" } # Generate async_jobs INSERT based on schema existence for PostgreSQL @@ -2427,6 +2664,49 @@ generate_routing_targets_insert_sqlite() { echo "INSERT INTO routing_targets (rule_id, provider, model, key_id, weight) VALUES ('rule-migration-test-2', NULL, NULL, NULL, 0.3) ON CONFLICT DO NOTHING;" >> "$output_file" } +# Generate governance_pricing_overrides INSERT for PostgreSQL +# This table was added in v1.5.0 as part of the custom pricing refactor. +# Two rows: one global (no FK deps) and one virtual_key-scoped (references vk-migration-test-1). +generate_pricing_overrides_insert_postgres() { + local now="$1" + local output_file="$2" + + # Check if the table exists + if ! 
column_exists_postgres "governance_pricing_overrides" "id"; then + return + fi + + echo "" >> "$output_file" + echo "-- governance_pricing_overrides (scoped pricing overrides - added in v1.5.0, dynamically generated)" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-001', 'Migration Test Override Global', 'global', NULL, NULL, NULL, 'exact', 'gpt-4', '[]', '{\"input_cost_per_token\": 0.00001}', 'po-hash-001', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-002', 'Migration Test Override VK', 'virtual_key', 'vk-migration-test-1', NULL, NULL, 'prefix', 'claude', '[]', '{\"output_cost_per_token\": 0.00002}', 'po-hash-002', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" +} + +# Generate governance_pricing_overrides INSERT for SQLite +# This table was added in v1.5.0 as part of the custom pricing refactor. +generate_pricing_overrides_insert_sqlite() { + local now="$1" + local output_file="$2" + local config_db="$3" + + # Check if the table exists in the database + if [ ! 
-f "$config_db" ]; then + return + fi + + local table_exists + table_exists=$(sqlite3 "$config_db" "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='governance_pricing_overrides';" 2>/dev/null || echo "0") + + if [ "$table_exists" != "1" ]; then + return + fi + + echo "" >> "$output_file" + echo "-- governance_pricing_overrides (scoped pricing overrides - added in v1.5.0, dynamically generated)" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-001', 'Migration Test Override Global', 'global', NULL, NULL, NULL, 'exact', 'gpt-4', '[]', '{\"input_cost_per_token\": 0.00001}', 'po-hash-001', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-002', 'Migration Test Override VK', 'virtual_key', 'vk-migration-test-1', NULL, NULL, 'prefix', 'claude', '[]', '{\"output_cost_per_token\": 0.00002}', 'po-hash-002', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" +} + # Validate faker column coverage for SQLite validate_faker_column_coverage_sqlite() { local faker_sql="$1" @@ -2645,6 +2925,7 @@ compare_postgres_snapshots() { # - network_config_json, concurrency_buffer_json, proxy_config_json, custom_provider_config_json: # JSON fields that get normalized with default values during migration # - budget_id, rate_limit_id: governance fields that may be reset or initialized during migrations + # - virtual_key_id, provider_config_id: new FK columns on governance_budgets (added by multi-budget migration) # - status, description: key validation runs after migration, updating these fields # for 
invalid/test keys (e.g., status becomes "list_models_failed") local ignore_columns="updated_at config_hash created_at models_json weight allowed_models network_config_json concurrency_buffer_json proxy_config_json custom_provider_config_json budget_id rate_limit_id status description" @@ -2701,6 +2982,24 @@ compare_postgres_snapshots() { if [ "$table" = "routing_rules" ]; then dropped_columns="$dropped_columns provider model" fi + # azure_deployments_json, vertex_deployments_json, bedrock_deployments_json, replicate_deployments_json + # (dropped from config_keys - migrated to provider-level deployment config) + if [ "$table" = "config_keys" ]; then + dropped_columns="$dropped_columns azure_deployments_json vertex_deployments_json bedrock_deployments_json replicate_deployments_json" + fi + # budget_id (dropped from governance_virtual_keys and governance_virtual_key_provider_configs + # in add_multi_budget_tables - ownership moved to governance_budgets.virtual_key_id/provider_config_id) + if [ "$table" = "governance_virtual_keys" ] || [ "$table" = "governance_virtual_key_provider_configs" ]; then + dropped_columns="$dropped_columns budget_id" + fi + # calendar_aligned (dropped from governance_budgets in add_multi_budget_tables - moved to governance_virtual_keys.calendar_aligned) + if [ "$table" = "governance_budgets" ]; then + dropped_columns="$dropped_columns calendar_aligned" + fi + # enable_litellm_fallbacks (dropped from config_client in latest cut - behavior moved elsewhere) + if [ "$table" = "config_client" ]; then + dropped_columns="$dropped_columns enable_litellm_fallbacks" + fi local before_col_array IFS=',' read -ra before_col_array <<< "$before_columns" @@ -2761,7 +3060,12 @@ compare_postgres_snapshots() { local col_idx=1 for col in "${before_col_array[@]}"; do # Skip columns that are expected to change - if [[ " $ignore_columns " == *" $col "* ]]; then + # virtual_key_id, provider_config_id: only ignore on governance_budgets (new FK columns from 
multi-budget migration) + local table_ignore_columns="$ignore_columns" + if [ "$table" = "governance_budgets" ]; then + table_ignore_columns="$table_ignore_columns virtual_key_id provider_config_id" + fi + if [[ " $table_ignore_columns " == *" $col "* ]]; then col_idx=$((col_idx + 1)) continue fi @@ -2852,6 +3156,84 @@ compare_postgres_snapshots() { # Validation Functions (simplified, uses snapshots) # ============================================================================ +# verify_budget_migration checks that the multi-budget FK migration correctly +# moved budget ownership from VK/ProviderConfig budget_id columns to +# governance_budgets.virtual_key_id / governance_budgets.provider_config_id +verify_budget_migration_postgres() { + log_info "Verifying budget migration (budget_id → virtual_key_id/provider_config_id)..." + local failed=0 + + # Check: budget-migration-test-1 was linked to vk-migration-test-1 via budget_id + # After migration, governance_budgets.virtual_key_id should be set + local vk_budget_count + vk_budget_count=$(run_postgres_scalar "SELECT COUNT(*) FROM governance_budgets WHERE id = 'budget-migration-test-1' AND virtual_key_id = 'vk-migration-test-1'") + if [ "$vk_budget_count" = "1" ]; then + log_info " VK budget migration: budget-migration-test-1 → vk-migration-test-1 ✓" + else + log_warn " VK budget migration: budget-migration-test-1 virtual_key_id not set (count=$vk_budget_count) — may be expected if old version didn't have budget_id on VK" + fi + + # Check: budget-migration-test-2 was linked to provider config via budget_id + # After migration, governance_budgets.provider_config_id should be set + local pc_budget_count + pc_budget_count=$(run_postgres_scalar "SELECT COUNT(*) FROM governance_budgets WHERE id = 'budget-migration-test-2' AND provider_config_id IS NOT NULL") + if [ "$pc_budget_count" = "1" ]; then + log_info " PC budget migration: budget-migration-test-2 → provider_config ✓" + else + log_warn " PC budget migration: 
budget-migration-test-2 provider_config_id not set (count=$pc_budget_count) — may be expected if old version didn't have budget_id on PC" + fi + + # Check: virtual_key_id and provider_config_id columns exist on governance_budgets + local has_vk_col + has_vk_col=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_budgets' AND column_name = 'virtual_key_id'") + if [ "$has_vk_col" = "1" ]; then + log_info " Column governance_budgets.virtual_key_id exists ✓" + else + log_error " Column governance_budgets.virtual_key_id MISSING!" + failed=1 + fi + + local has_pc_col + has_pc_col=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_budgets' AND column_name = 'provider_config_id'") + if [ "$has_pc_col" = "1" ]; then + log_info " Column governance_budgets.provider_config_id exists ✓" + else + log_error " Column governance_budgets.provider_config_id MISSING!" + failed=1 + fi + + # Check: budget_id column should be dropped from governance_virtual_keys + local vk_has_budget_id + vk_has_budget_id=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_virtual_keys' AND column_name = 'budget_id'") + if [ "$vk_has_budget_id" = "0" ]; then + log_info " Column governance_virtual_keys.budget_id dropped ✓" + else + log_error " Column governance_virtual_keys.budget_id still exists!" + failed=1 + fi + + # Check: budget_id column should be dropped from governance_virtual_key_provider_configs + local pc_has_budget_id + pc_has_budget_id=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_virtual_key_provider_configs' AND column_name = 'budget_id'") + if [ "$pc_has_budget_id" = "0" ]; then + log_info " Column governance_virtual_key_provider_configs.budget_id dropped ✓" + else + log_error " Column governance_virtual_key_provider_configs.budget_id still exists!" 
+ failed=1 + fi + + # Check: junction tables should not exist + local junction_vk + junction_vk=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'governance_virtual_key_budgets'") + if [ "$junction_vk" = "0" ]; then + log_info " Junction table governance_virtual_key_budgets dropped ✓" + else + log_warn " Junction table governance_virtual_key_budgets still exists (may not have existed in old version)" + fi + + return $failed +} + validate_postgres_data() { local before_snapshot="$1" local after_snapshot="$2" @@ -3015,7 +3397,7 @@ EOF --app-dir "$TEMP_DIR" --port "$BIFROST_PORT" > "$server_log" 2>&1 & BIFROST_PID=$! - if ! wait_for_bifrost "$server_log" 120; then + if ! wait_for_bifrost "$server_log" 300; then log_error "Failed to start bifrost $version" cat "$server_log" 2>/dev/null || true stop_bifrost @@ -3056,7 +3438,7 @@ EOF "$current_binary" --app-dir "$TEMP_DIR" --port "$BIFROST_PORT" > "$current_log" 2>&1 & BIFROST_PID=$! - if ! wait_for_bifrost "$current_log" 120; then + if ! wait_for_bifrost "$current_log" 300; then log_error "Current version failed to start after migrating from $version" cat "$current_log" stop_bifrost @@ -3101,6 +3483,13 @@ EOF return 1 fi + # STEP 6: Verify budget migration (budget_id → virtual_key_id/provider_config_id) + if ! verify_budget_migration_postgres; then + log_error "Budget migration verification failed after migration from $version" + stop_bifrost + return 1 + fi + stop_bifrost log_info "Migration from $version: SUCCESS" done @@ -3207,7 +3596,7 @@ EOF --app-dir "$TEMP_DIR" --port "$BIFROST_PORT" > "$server_log" 2>&1 & BIFROST_PID=$! - if ! wait_for_bifrost "$server_log" 120; then + if ! wait_for_bifrost "$server_log" 300; then log_error "Failed to start bifrost $version" cat "$server_log" 2>/dev/null || true stop_bifrost @@ -3247,7 +3636,7 @@ EOF "$current_binary" --app-dir "$TEMP_DIR" --port "$BIFROST_PORT" > "$current_log" 2>&1 & BIFROST_PID=$! - if ! 
wait_for_bifrost "$current_log" 120; then + if ! wait_for_bifrost "$current_log" 300; then log_error "Current version failed to start after migrating from $version" cat "$current_log" stop_bifrost @@ -3321,4 +3710,4 @@ main() { exit $exit_code } -main "$@" +main "$@" \ No newline at end of file diff --git a/.github/workflows/scripts/schemasync/go.mod b/.github/workflows/scripts/schemasync/go.mod new file mode 100644 index 0000000000..0b8a2eee00 --- /dev/null +++ b/.github/workflows/scripts/schemasync/go.mod @@ -0,0 +1,10 @@ +module github.com/maximhq/bifrost/tools/schema-sync + +go 1.26.2 + +require golang.org/x/tools v0.30.0 + +require ( + golang.org/x/mod v0.23.0 // indirect + golang.org/x/sync v0.11.0 // indirect +) diff --git a/.github/workflows/scripts/schemasync/go.sum b/.github/workflows/scripts/schemasync/go.sum new file mode 100644 index 0000000000..68d0914a62 --- /dev/null +++ b/.github/workflows/scripts/schemasync/go.sum @@ -0,0 +1,8 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= diff --git a/.github/workflows/scripts/schemasync/main.go b/.github/workflows/scripts/schemasync/main.go new file mode 100644 index 0000000000..ff81e0902f --- /dev/null +++ b/.github/workflows/scripts/schemasync/main.go @@ -0,0 +1,1075 @@ +// schemasync validates that Bifrost Go config types stay in sync with +// transports/config.schema.json. 
+// +// Starting from a configured entry-point type (default: ConfigData in +// transports/bifrost-http/lib), it recursively walks every nested struct +// field via go/types. For each field it verifies: +// +// 1. The json:"X" tag has a corresponding property in config.schema.json at +// the propagated schema path (handling $ref, allOf, oneOf, if/then/else). +// 2. If the field's Go type is a named string type with const declarations, +// the set of Go constant values matches the schema's enum array. +// +// Exit 0 on full agreement, 1 on any mismatch. +package main + +import ( + "encoding/json" + "flag" + "fmt" + "go/constant" + "go/types" + "os" + "path/filepath" + "sort" + "strings" + + "golang.org/x/tools/go/packages" +) + +type entrypoint struct { + pkg string // Go import path + typeName string // exported type name + schemaPath string // JSON pointer path in config.schema.json (e.g. "/properties") + moduleDir string // directory (relative to --pkg-root) that contains the go.mod +} + +var entrypoints = []entrypoint{ + { + pkg: "github.com/maximhq/bifrost/transports/bifrost-http/lib", + typeName: "ConfigData", + schemaPath: "", // root schema node — collectProperties will find .properties + moduleDir: "transports", + }, +} + +// Schema properties that intentionally have no Go counterpart and vice versa. +// Key is a JSON pointer path; value is a short reason. +var ignoreSchemaProps = map[string]string{ + "/properties/$schema": "JSON schema self-reference", + // GORM foreignKey slice relations that ARE user-submittable config input. + // Go-side: schemasync skips them via the gorm-tag filter; schema-side: + // these entries prevent the missing-in-go warning for them. 
+ "/properties/governance/properties/virtual_keys/items/properties/provider_configs": "gorm fk slice; user-submittable", + "/properties/governance/properties/virtual_keys/items/properties/mcp_configs": "gorm fk slice; user-submittable", + "/properties/governance/properties/routing_rules/items/properties/targets": "gorm fk slice; user-submittable", + // MCP headers map — documented escape hatch is envFrom: + // plus env.X references in values; no chart-native secretRef. + "/properties/mcp/properties/client_configs/items/properties/headers/additionalProperties": "documented envFrom pattern", + // Object-storage identity fields (bucket/region/endpoint/project_id) are + // EnvVar-typed for flexibility but are not inherently secret. Operators + // can write `env.MY_VAR` in values and use envFrom to inject. Access + // keys, session tokens, and credentials DO have chart-native secret + // support via `storage.logsStore.objectStorage.existingSecret`. + "/properties/logs_store/properties/object_storage/properties/bucket": "not a secret; env.X + envFrom pattern", + "/properties/logs_store/properties/object_storage/properties/region": "not a secret; env.X + envFrom pattern", + "/properties/logs_store/properties/object_storage/properties/endpoint": "not a secret; env.X + envFrom pattern", + "/properties/logs_store/properties/object_storage/properties/project_id": "not a secret; env.X + envFrom pattern", +} + +// ignoreGoFields keys are "schemaPath|fieldName"; value is the reason. +var ignoreGoFields = map[string]string{ + "|auth_config": "deprecated; moved to governance.auth_config", +} + +// ignoreGoFieldNames are field names (regardless of parent path) that are +// DB bookkeeping or runtime-derived — never part of user-submitted config. 
+var ignoreGoFieldNames = map[string]string{ + "created_at": "DB bookkeeping", + "updated_at": "DB bookkeeping", + "config_hash": "internal hash", + "status": "runtime-derived", + "state": "runtime-derived", +} + +// opaqueLeafTypes are named Go types that have custom JSON marshalling and +// should be treated as leaves. The walker does NOT recurse into their fields, +// and they are collected for downstream checks (e.g., EnvVar → helm secret). +var opaqueLeafTypes = map[string]string{ + "github.com/maximhq/bifrost/core/schemas.EnvVar": "env-aware string; custom JSON", +} + +// envVarLocation records where an EnvVar-typed field appears in config.json +// so a downstream pass can confirm the helm chart supports Secret-backed +// injection (existingSecret / secretRef / env.BIFROST_*) for that path. +type envVarLocation struct { + schemaPath string + goPath string +} + +// Finding categorises every issue the tool surfaces so the final report can +// group by category and render as a table. +type Finding struct { + Category string // e.g. "missing-in-schema", "missing-in-go", "enum-drift", "envvar-no-secret" + Severity string // "ERROR" or "WARN" + Path string // schema path or enum path + Detail string // field name, Go path, missing/extra values, etc. 
+ Go string // Go-side location (package.Type.Field) +} + +type checker struct { + schema map[string]any + pkgs map[string]*packages.Package // path → pkg + // enumConsts[namedType] -> list of string values found in any loaded package + enumConsts map[string][]string + // visited type names to break cycles + visited map[string]bool + // envVarFields records where EnvVar types occur, for downstream checks + envVarFields []envVarLocation + findings []Finding +} + +func main() { + schemaFlag := flag.String("schema", "transports/config.schema.json", "path to config.schema.json") + pkgDir := flag.String("pkg-root", ".", "repo root used as packages.Load dir") + helmValuesFlag := flag.String("helm-values", "helm-charts/bifrost/values.schema.json", "path to helm values.schema.json (for EnvVar secret-support check)") + helmHelpersFlag := flag.String("helm-helpers", "helm-charts/bifrost/templates/_helpers.tpl", "path to helm _helpers.tpl (for env.BIFROST_* emission detection)") + flag.Parse() + + schemaBytes, err := os.ReadFile(*schemaFlag) + if err != nil { + fmt.Fprintf(os.Stderr, "read schema: %v\n", err) + os.Exit(2) + } + var schema map[string]any + if err := json.Unmarshal(schemaBytes, &schema); err != nil { + fmt.Fprintf(os.Stderr, "parse schema: %v\n", err) + os.Exit(2) + } + + // Group entrypoints by moduleDir so we load each module's package graph once. + byModule := map[string][]entrypoint{} + orderedMods := []string{} + for _, e := range entrypoints { + if _, seen := byModule[e.moduleDir]; !seen { + orderedMods = append(orderedMods, e.moduleDir) + } + byModule[e.moduleDir] = append(byModule[e.moduleDir], e) + } + absRoot, err := filepath.Abs(*pkgDir) + if err != nil { + fmt.Fprintf(os.Stderr, "abs pkg-root: %v\n", err) + os.Exit(2) + } + // Always use the repo's go.work so local modules resolve against each + // other (not against registry tarballs). The tool refuses to run without + // go.work — that's the only configuration bifrost is tested against. 
+ goworkPath := filepath.Join(absRoot, "go.work") + if _, err := os.Stat(goworkPath); err != nil { + fmt.Fprintf(os.Stderr, "schemasync requires go.work at %s: %v\n", goworkPath, err) + os.Exit(2) + } + + allPkgs := map[string]*packages.Package{} + for _, mod := range orderedMods { + modDir := filepath.Join(absRoot, mod) + env := append([]string{}, os.Environ()...) + env = append(env, "GOWORK="+goworkPath) + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedFiles, + Dir: modDir, + Env: env, + } + imports := []string{} + for _, e := range byModule[mod] { + imports = append(imports, e.pkg) + } + pkgs, err := packages.Load(cfg, imports...) + if err != nil { + fmt.Fprintf(os.Stderr, "load %s: %v\n", mod, err) + os.Exit(2) + } + hadLoadErr := false + packages.Visit(pkgs, nil, func(p *packages.Package) { + for _, e := range p.Errors { + fmt.Fprintln(os.Stderr, e) + hadLoadErr = true + } + }) + if hadLoadErr { + os.Exit(2) + } + for k, v := range collectPkgs(pkgs) { + allPkgs[k] = v + } + } + + c := &checker{ + schema: schema, + pkgs: allPkgs, + enumConsts: map[string][]string{}, + visited: map[string]bool{}, + } + c.collectConsts() + + for _, e := range entrypoints { + p := c.pkgs[e.pkg] + if p == nil { + c.add(Finding{Category: "entrypoint", Severity: "ERROR", Detail: "package not loaded: " + e.pkg}) + continue + } + obj := p.Types.Scope().Lookup(e.typeName) + if obj == nil { + c.add(Finding{Category: "entrypoint", Severity: "ERROR", Detail: fmt.Sprintf("type %s not found in %s", e.typeName, e.pkg)}) + continue + } + named, ok := obj.Type().(*types.Named) + if !ok { + c.add(Finding{Category: "entrypoint", Severity: "ERROR", Detail: fmt.Sprintf("%s.%s is not a named type", e.pkg, e.typeName)}) + continue + } + c.walkType(named, e.schemaPath, fmt.Sprintf("%s.%s", e.pkg, e.typeName)) + } + + // EnvVar → helm-chart secret-support pass. 
For each Go field typed as + // schemas.EnvVar, the helm chart must either (a) emit an env.BIFROST_* + // placeholder for that JSON path via _helpers.tpl, or (b) expose a + // secretRef/existingSecret knob in values.schema.json at the equivalent + // camelCase location. If neither, warn. + c.checkEnvVarHelmSupport(*helmValuesFlag, *helmHelpersFlag) + + printReport(os.Stderr, c.findings) + errCount := c.countErrs() + warnCount := c.countWarns() + if errCount > 0 { + fmt.Fprintf(os.Stderr, "\nschemasync: %d errors, %d warnings\n", errCount, warnCount) + os.Exit(1) + } + fmt.Fprintf(os.Stderr, "\nschemasync: OK (%d warnings)\n", warnCount) +} + +// printReport groups findings by category and prints a markdown-style table +// for each non-empty group. +func printReport(w interface{ Write([]byte) (int, error) }, findings []Finding) { + if len(findings) == 0 { + return + } + groups := map[string][]Finding{} + order := []string{} + for _, f := range findings { + if _, ok := groups[f.Category]; !ok { + order = append(order, f.Category) + } + groups[f.Category] = append(groups[f.Category], f) + } + titles := map[string]string{ + "missing-in-schema": "Missing in config.schema.json (Go has field, schema doesn't) — ERRORS", + "missing-in-go": "Missing in Go (schema has property, ConfigData doesn't) — WARNINGS", + "enum-drift": "Enum drift (Go constants vs schema enum array)", + "enum-no-schema": "Go enum types with no schema `enum` constraint — WARNINGS", + "envvar-no-secret": "EnvVar fields lacking chart-native Secret support — WARNINGS", + "schema-path-not-found": "Schema path not found for a walked Go type — ERRORS", + "entrypoint": "Entrypoint problems — ERRORS", + } + for _, cat := range order { + items := groups[cat] + title := titles[cat] + if title == "" { + title = cat + } + // w already has the Write method in its static type; no assertion needed. + fmt.Fprintf(w, "\n### %s (%d)\n\n", title, len(items)) + // Pick columns based on category for readability.
+ switch cat { + case "missing-in-schema", "schema-path-not-found": + renderTable(w, []string{"severity", "schema path", "Go location"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Path, f.Go}) + } + return out + }()) + case "missing-in-go": + renderTable(w, []string{"severity", "schema path", "property", "Go parent"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Path, f.Detail, f.Go}) + } + return out + }()) + case "enum-drift", "enum-no-schema": + renderTable(w, []string{"severity", "enum path", "drift", "Go location"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Path, f.Detail, f.Go}) + } + return out + }()) + case "envvar-no-secret": + renderTable(w, []string{"severity", "config path", "Go location", "note"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Path, f.Go, f.Detail}) + } + return out + }()) + default: + renderTable(w, []string{"severity", "detail"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Detail}) + } + return out + }()) + } + } +} + +// renderTable writes a markdown table. Truncates long cells to keep width sane. 
+func renderTable(w interface{ Write([]byte) (int, error) }, headers []string, rows [][]string) { + const maxCol = 80 + widths := make([]int, len(headers)) + for i, h := range headers { + widths[i] = len(h) + } + truncate := func(s string) string { + if len(s) <= maxCol { + return s + } + return s[:maxCol-1] + "…" + } + trimmed := make([][]string, len(rows)) + for i, r := range rows { + trimmed[i] = make([]string, len(r)) + for j, cell := range r { + trimmed[i][j] = truncate(cell) + if j < len(widths) && len(trimmed[i][j]) > widths[j] { + widths[j] = len(trimmed[i][j]) + } + } + } + writeRow := func(cells []string) { + var sb strings.Builder + sb.WriteString("| ") + for i, c := range cells { + sb.WriteString(c) + if pad := widths[i] - len(c); pad > 0 { + sb.WriteString(strings.Repeat(" ", pad)) + } + sb.WriteString(" | ") + } + sb.WriteString("\n") + _, _ = w.Write([]byte(sb.String())) + } + writeRow(headers) + sep := make([]string, len(headers)) + for i := range headers { + sep[i] = strings.Repeat("-", widths[i]) + } + writeRow(sep) + for _, r := range trimmed { + writeRow(r) + } +} + +// checkEnvVarHelmSupport verifies that every Go field of type schemas.EnvVar +// has a way to be sourced from a Kubernetes secret via the helm chart. Proof +// of support is any of: +// +// 1. An `env.BIFROST_*` string literal appears in _helpers.tpl (indicating +// a rewrite is wired up for the corresponding config path), OR +// 2. values.schema.json declares a `secretRef` or `existingSecret` object +// at the camelCase equivalent of the schema path. +// +// Neither heuristic is perfect — this is a structural review aid, not a +// proof. Treat misses as warnings so they don't block CI on borderline cases. 
+func (c *checker) checkEnvVarHelmSupport(valuesPath, helpersPath string) { + helpersBytes, err := os.ReadFile(helpersPath) + if err != nil { + c.add(Finding{Category: "envvar-no-secret", Severity: "WARN", Detail: fmt.Sprintf("could not read helm helpers %s: %v — skipping EnvVar helm-support check", helpersPath, err)}) + return + } + helpers := string(helpersBytes) + // Extract every env.BIFROST_* token mentioned in _helpers.tpl. + envBifrostMentions := map[string]bool{} + for _, line := range strings.Split(helpers, "\n") { + // crude extraction: look for "env.BIFROST_X" substrings + idx := 0 + for idx < len(line) { + k := strings.Index(line[idx:], "env.BIFROST_") + if k < 0 { + break + } + start := idx + k + end := start + for end < len(line) { + ch := line[end] + if ch == '"' || ch == ' ' || ch == '\t' || ch == '}' || ch == ')' { + break + } + end++ + } + envBifrostMentions[line[start:end]] = true + idx = end + } + } + + valuesBytes, err := os.ReadFile(valuesPath) + hasValues := err == nil + var valuesSchema map[string]any + if hasValues { + _ = json.Unmarshal(valuesBytes, &valuesSchema) + } + + for _, loc := range c.envVarFields { + // Heuristic 1: any env.BIFROST_* is present in helpers — broad acceptance. + // We can't easily map a specific EnvVar field to a specific env var + // without per-field config, so we just check that the helpers file + // has AT LEAST ONE envBifrost mention that maps to this field's path. + // To make this stricter, we look for a helpers line mentioning either + // the camelCase field's parent path or an env var matching it. + camel := schemaPathToCamelCase(loc.schemaPath) + matched := false + // Heuristic 2: values.schema.json declares secretRef under the parent path. 
+ if hasValues && valuesSchema != nil { + if hasSecretRefAt(valuesSchema, camel) { + matched = true + } + } + if !matched && len(envBifrostMentions) > 0 { + // Fall back to "some envBifrost wiring exists somewhere" — we flag it + // as a weaker hit so maintainers know to verify the mapping manually. + // Do not accept purely from presence; require a name-similarity match. + tail := lastSchemaComponent(loc.schemaPath) + for mention := range envBifrostMentions { + up := strings.ToUpper(tail) + if strings.Contains(mention, "_"+up) || strings.HasSuffix(mention, up) { + matched = true + break + } + } + } + if !matched { + if _, ignored := ignoreSchemaProps[loc.schemaPath]; ignored { + continue + } + c.add(Finding{ + Category: "envvar-no-secret", + Severity: "WARN", + Path: loc.schemaPath, + Detail: "helm has no secretRef/existingSecret at " + camel + " or parent", + Go: loc.goPath, + }) + } + } +} + +// schemaPathToCamelCase converts a JSON pointer like +// "/properties/governance/properties/auth_config/properties/admin_username" +// into a best-effort camelCase helm values path like +// "properties.bifrost.properties.governance.properties.authConfig.properties.adminUsername". 
+func schemaPathToCamelCase(p string) string { + parts := strings.Split(strings.TrimPrefix(p, "/"), "/") + out := []string{"properties", "bifrost"} + for _, part := range parts { + if part == "" { + continue + } + if part == "properties" { + out = append(out, "properties") + continue + } + out = append(out, snakeToCamel(part)) + } + return strings.Join(out, ".") +} + +func snakeToCamel(s string) string { + parts := strings.Split(s, "_") + for i := 1; i < len(parts); i++ { + if parts[i] != "" { + parts[i] = strings.ToUpper(parts[i][:1]) + parts[i][1:] + } + } + return strings.Join(parts, "") +} + +func lastSchemaComponent(p string) string { + parts := strings.Split(p, "/") + for i := len(parts) - 1; i >= 0; i-- { + if parts[i] != "" && parts[i] != "properties" { + return parts[i] + } + } + return "" +} + +// hasSecretRefAt returns true if EITHER (a) the target subtree declares a +// secretRef/existingSecret/*Secret knob inside its own "properties", OR +// (b) a SIBLING of the target (at the same properties-map level) is named +// "Secret" / "secretRef" / "existingSecret" / has "Secret" suffix. +// Sibling match is how the helm chart's encryptionKey + encryptionKeySecret +// pattern works: the Secret-source knob is a sibling of the field itself. +func hasSecretRefAt(schema map[string]any, dotted string) bool { + parts := strings.Split(dotted, ".") + var cur any = schema + var propsAtTarget map[string]any // map in which the last non-"properties" part lives + var targetName string + for _, p := range parts { + m, ok := cur.(map[string]any) + if !ok { + return false + } + // Resolve $ref at this node before descending. 
+ if ref, ok := m["$ref"].(string); ok && strings.HasPrefix(ref, "#/") { + resolved := jsonPointerGet(schema, strings.TrimPrefix(ref, "#/")) + if rm, ok := resolved.(map[string]any); ok { + m = rm + } + } + if p != "properties" { + propsAtTarget = m + targetName = p + } + next, present := m[p] + if !present { + break + } + cur = next + } + // (a) target itself declares a Secret knob in its own properties. + if m, ok := cur.(map[string]any); ok && secretRefPresent(m) { + return true + } + // (b) a sibling of target matches Secret or a generic Secret knob. + if propsAtTarget != nil && targetName != "" { + for k := range propsAtTarget { + if k == targetName { + continue + } + if k == "secretRef" || k == "existingSecret" { + return true + } + if strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "SecretRef") { + return true + } + } + } + return false +} + +// jsonPointerGet resolves a /-delimited JSON Pointer into a schema root. +// Used by hasSecretRefAt to follow $ref entries in helm values.schema.json. 
+func jsonPointerGet(root any, pointer string) any { + if pointer == "" { + return root + } + parts := strings.Split(pointer, "/") + cur := root + for _, p := range parts { + p = strings.ReplaceAll(p, "~1", "/") + p = strings.ReplaceAll(p, "~0", "~") + m, ok := cur.(map[string]any) + if !ok { + return nil + } + cur = m[p] + if cur == nil { + return nil + } + } + return cur +} + +func secretRefPresent(m map[string]any) bool { + if m == nil { + return false + } + props, ok := m["properties"].(map[string]any) + if !ok { + return false + } + for k := range props { + if k == "secretRef" || k == "existingSecret" || k == "encryptionKeySecret" { + return true + } + if strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "SecretRef") { + return true + } + } + return false +} + +func collectPkgs(roots []*packages.Package) map[string]*packages.Package { + out := map[string]*packages.Package{} + packages.Visit(roots, nil, func(p *packages.Package) { + out[p.PkgPath] = p + }) + return out +} + +func (c *checker) add(f Finding) { c.findings = append(c.findings, f) } + +// countErrs returns the number of ERROR-severity findings. +func (c *checker) countErrs() int { + n := 0 + for _, f := range c.findings { + if f.Severity == "ERROR" { + n++ + } + } + return n +} + +// countWarns returns the number of WARN-severity findings. +func (c *checker) countWarns() int { + n := 0 + for _, f := range c.findings { + if f.Severity == "WARN" { + n++ + } + } + return n +} + +// collectConsts scans all loaded packages for `const X NamedStringType = "v"` +// and indexes them by namedType key "pkgpath.TypeName". 
+func (c *checker) collectConsts() { + for _, p := range c.pkgs { + if p.Types == nil { + continue + } + scope := p.Types.Scope() + for _, name := range scope.Names() { + obj := scope.Lookup(name) + cnst, ok := obj.(*types.Const) + if !ok { + continue + } + named, ok := cnst.Type().(*types.Named) + if !ok { + continue + } + // Only named string types + basic, ok := named.Underlying().(*types.Basic) + if !ok || basic.Kind() != types.String { + continue + } + key := named.Obj().Pkg().Path() + "." + named.Obj().Name() + v := cnst.Val() + if v.Kind() != constant.String { + continue + } + c.enumConsts[key] = append(c.enumConsts[key], constant.StringVal(v)) + } + } + for k := range c.enumConsts { + sort.Strings(c.enumConsts[k]) + } +} + +// walkType recursively walks a struct type, verifying each json-tagged field +// has a schema counterpart at the propagated schemaPath. +func (c *checker) walkType(t types.Type, schemaPath, goPath string) { + t = deref(t) + named, _ := t.(*types.Named) + if named != nil { + key := named.Obj().Pkg().Path() + "." + named.Obj().Name() + // Treat opaque types (like schemas.EnvVar) as leaves. + if _, isOpaque := opaqueLeafTypes[key]; isOpaque { + if key == "github.com/maximhq/bifrost/core/schemas.EnvVar" { + c.envVarFields = append(c.envVarFields, envVarLocation{schemaPath, goPath}) + } + return + } + if c.visited[key+"@"+schemaPath] { + return + } + c.visited[key+"@"+schemaPath] = true + } + structType, ok := t.Underlying().(*types.Struct) + if !ok { + return + } + + schemaNode := c.resolveSchema(schemaPath) + if schemaNode == nil { + c.add(Finding{Category: "schema-path-not-found", Severity: "ERROR", Path: schemaPath, Go: goPath}) + return + } + + // Collect every property key reachable from this schema node across + // properties/allOf/oneOf/anyOf/if-then-else branches. 
+ schemaProps := c.collectProperties(schemaNode, schemaPath) + + goFieldTags := map[string]*types.Var{} + for i := 0; i < structType.NumFields(); i++ { + f := structType.Field(i) + if !f.Exported() { + continue + } + tag := reflectTag(structType.Tag(i), "json") + if tag == "" || tag == "-" { + continue + } + name := strings.Split(tag, ",")[0] + if name == "" { + continue + } + // GORM relational fields (`foreignKey`, `many2many`) are populated + // from DB joins, not user-submitted config. Skip them from the walk so + // schemasync only compares user-input config against config.schema.json. + // Schema properties for these relations may still exist (validated at + // the missing-in-go layer); add the schema path to ignoreSchemaProps + // for deliberate exceptions (see above for `provider_configs`, etc.). + gormTag := reflectTag(structType.Tag(i), "gorm") + if strings.Contains(gormTag, "foreignKey") || strings.Contains(gormTag, "many2many") { + continue + } + goFieldTags[name] = f + } + + // Go-field → schema check + for name, f := range goFieldTags { + childPath := schemaPath + "/properties/" + name + if _, ignored := ignoreGoFields[schemaPath+"|"+name]; ignored { + continue + } + if _, ignored := ignoreGoFieldNames[name]; ignored { + continue + } + childSchema := schemaProps[name] + if childSchema == nil { + c.add(Finding{ + Category: "missing-in-schema", + Severity: "ERROR", + Path: schemaPath + "/properties/" + name, + Detail: name, + Go: goPath + "."
+ f.Name(), + }) + continue + } + // Recurse into field type + c.walkField(f.Type(), childSchema, childPath, goPath+"."+f.Name()) + } + + // Schema-key → Go field check (warnings; schema may legitimately be broader) + for name := range schemaProps { + if _, ignored := ignoreSchemaProps[schemaPath+"/properties/"+name]; ignored { + continue + } + if _, ignored := ignoreGoFields[schemaPath+"|"+name]; ignored { + continue + } + if _, ok := goFieldTags[name]; !ok { + c.add(Finding{ + Category: "missing-in-go", + Severity: "WARN", + Path: schemaPath + "/properties/" + name, + Detail: name, + Go: goPath, + }) + } + } +} + +// walkField dispatches based on the field's Go type. +func (c *checker) walkField(t types.Type, schemaNode map[string]any, schemaPath, goPath string) { + t = deref(t) + + // Named type → opaque-leaf check + enum check (if string const type) + if named, ok := t.(*types.Named); ok { + key := named.Obj().Pkg().Path() + "." + named.Obj().Name() + if _, isOpaque := opaqueLeafTypes[key]; isOpaque { + if key == "github.com/maximhq/bifrost/core/schemas.EnvVar" { + c.envVarFields = append(c.envVarFields, envVarLocation{schemaPath, goPath}) + } + return // do not recurse into opaque types + } + if goVals, hasConsts := c.enumConsts[key]; hasConsts && len(goVals) > 0 { + c.checkEnum(goVals, schemaNode, schemaPath, goPath, key) + } + } + + switch u := t.Underlying().(type) { + case *types.Struct: + // Recurse into named struct (anonymous structs are inlined below) + if _, ok := t.(*types.Named); ok { + c.walkType(t, schemaPath, goPath) + } else { + // Anonymous inline struct — rare but handle by walking tags + c.walkAnonymous(u, schemaPath, goPath) + } + case *types.Slice: + elem := u.Elem() + if _, isStruct := deref(elem).Underlying().(*types.Struct); isStruct { + itemsNode := c.resolveRef(schemaNode) + if _, ok := itemsNode["items"].(map[string]any); ok { + c.walkType(elem, schemaPath+"/items", goPath+"[]") + } + } + case *types.Array: + elem := u.Elem() + if _, 
isStruct := deref(elem).Underlying().(*types.Struct); isStruct { + itemsNode := c.resolveRef(schemaNode) + if _, ok := itemsNode["items"].(map[string]any); ok { + c.walkType(elem, schemaPath+"/items", goPath+"[]") + } + } + case *types.Map: + elem := u.Elem() + if _, isStruct := deref(elem).Underlying().(*types.Struct); isStruct { + node := c.resolveRef(schemaNode) + if _, ok := node["additionalProperties"].(map[string]any); ok { + c.walkType(elem, schemaPath+"/additionalProperties", goPath+"[]") + } + // If no additionalProperties/patternProperties, silently skip — schemas + // often describe provider-keyed maps via oneOf branches. + } + case *types.Basic, *types.Interface: + // Leaf — nothing to recurse into. + } +} + +// walkAnonymous handles anonymous (inline) struct fields — rare; we treat +// them as a struct walk at the same schemaPath. +func (c *checker) walkAnonymous(st *types.Struct, schemaPath, goPath string) { + // Not common in this codebase; fall back to flat tag-check. + schemaNode := c.resolveSchema(schemaPath) + if schemaNode == nil { + return + } + props := c.collectProperties(schemaNode, schemaPath) + for i := 0; i < st.NumFields(); i++ { + f := st.Field(i) + if !f.Exported() { + continue + } + tag := reflectTag(st.Tag(i), "json") + if tag == "" || tag == "-" { + continue + } + name := strings.Split(tag, ",")[0] + if _, ok := props[name]; !ok { + c.add(Finding{ + Category: "missing-in-schema", + Severity: "ERROR", + Path: schemaPath + "/properties/" + name, + Detail: name, + Go: goPath + "." + f.Name(), + }) + } + } +} + +// checkEnum diffs Go string-const values against schema enum array. 
+func (c *checker) checkEnum(goVals []string, schemaNode map[string]any, schemaPath, goPath, typeKey string) { + node := c.resolveRef(schemaNode) + rawEnum, ok := node["enum"] + if !ok { + c.add(Finding{ + Category: "enum-no-schema", + Severity: "WARN", + Path: schemaPath, + Detail: fmt.Sprintf("%v (Go consts)", goVals), + Go: typeKey, + }) + return + } + enumArr, ok := rawEnum.([]any) + if !ok { + c.add(Finding{Category: "enum-drift", Severity: "ERROR", Path: schemaPath, Detail: "schema enum is not an array"}) + return + } + schemaSet := map[string]bool{} + for _, v := range enumArr { + if s, ok := v.(string); ok { + schemaSet[s] = true + } + } + goSet := map[string]bool{} + for _, v := range goVals { + goSet[v] = true + } + var missingInSchema, extraInSchema []string + for v := range goSet { + if !schemaSet[v] { + missingInSchema = append(missingInSchema, v) + } + } + for v := range schemaSet { + if !goSet[v] { + extraInSchema = append(extraInSchema, v) + } + } + sort.Strings(missingInSchema) + sort.Strings(extraInSchema) + if len(missingInSchema) > 0 { + c.add(Finding{ + Category: "enum-drift", + Severity: "ERROR", + Path: schemaPath, + Detail: fmt.Sprintf("schema missing Go consts %v", missingInSchema), + Go: goPath + " (" + typeKey + ")", + }) + } + if len(extraInSchema) > 0 { + c.add(Finding{ + Category: "enum-drift", + Severity: "WARN", + Path: schemaPath, + Detail: fmt.Sprintf("schema has %v with no Go const", extraInSchema), + Go: typeKey, + }) + } +} + +// collectProperties walks the schema subtree rooted at `node`, unioning +// property keys from the direct `properties`, and recursively from `allOf`, +// `oneOf`, `anyOf`, `then`, `else`. Handles $ref. +// Returns map of propertyName → subschema. 
+func (c *checker) collectProperties(node map[string]any, atPath string) map[string]map[string]any { + out := map[string]map[string]any{} + c.mergeProperties(out, node, atPath, map[string]bool{}) + return out +} + +func (c *checker) mergeProperties(out map[string]map[string]any, node map[string]any, atPath string, seen map[string]bool) { + if node == nil { + return + } + // Record $ref cycles BEFORE resolving: resolveRef returns the target node + // (which no longer carries the "$ref" key), so checking afterwards would + // never fire and cyclic refs would recurse without bound. + if ref, ok := node["$ref"].(string); ok { + if seen[ref] { + return + } + seen[ref] = true + } + node = c.resolveRef(node) + if props, ok := node["properties"].(map[string]any); ok { + for k, v := range props { + if m, ok := v.(map[string]any); ok { + if _, already := out[k]; !already { + out[k] = m + } + } + } + } + for _, key := range []string{"allOf", "oneOf", "anyOf"} { + if arr, ok := node[key].([]any); ok { + for _, item := range arr { + if m, ok := item.(map[string]any); ok { + c.mergeProperties(out, m, atPath+"/"+key, seen) + } + } + } + } + for _, key := range []string{"then", "else"} { + if m, ok := node[key].(map[string]any); ok { + c.mergeProperties(out, m, atPath+"/"+key, seen) + } + } +} + +// resolveSchema walks a JSON-pointer path into c.schema, resolving $ref at +// each intermediate node. +func (c *checker) resolveSchema(path string) map[string]any { + parts := strings.Split(strings.TrimPrefix(path, "/"), "/") + var cur any = c.schema + for _, p := range parts { + if p == "" { + continue + } + m, ok := cur.(map[string]any) + if !ok { + return nil + } + m = c.resolveRef(m) + cur = m[unescapeJSONPointer(p)] + if cur == nil { + return nil + } + } + if m, ok := cur.(map[string]any); ok { + return c.resolveRef(m) + } + return nil +} + +// resolveRef follows a $ref pointer (recursively) to the final target node. +// $ref values are expected as "#/$defs/xxx" style JSON pointers.
+func (c *checker) resolveRef(node map[string]any) map[string]any { + for i := 0; i < 16; i++ { + ref, ok := node["$ref"].(string) + if !ok { + return node + } + if !strings.HasPrefix(ref, "#/") { + return node // external refs unsupported + } + parts := strings.Split(strings.TrimPrefix(ref, "#/"), "/") + var cur any = c.schema + ok2 := true + for _, p := range parts { + m, isMap := cur.(map[string]any) + if !isMap { + ok2 = false + break + } + cur = m[unescapeJSONPointer(p)] + if cur == nil { + ok2 = false + break + } + } + if !ok2 { + return node + } + next, ok := cur.(map[string]any) + if !ok { + return node + } + node = next + } + return node +} + +func unescapeJSONPointer(s string) string { + s = strings.ReplaceAll(s, "~1", "/") + s = strings.ReplaceAll(s, "~0", "~") + return s +} + +// deref strips pointer wrappers to get the underlying type. +func deref(t types.Type) types.Type { + for { + p, ok := t.(*types.Pointer) + if !ok { + return t + } + t = p.Elem() + } +} + +// reflectTag parses a single struct-tag key; mirrors reflect.StructTag.Get. 
+func reflectTag(tag, key string) string { + for tag != "" { + for tag != "" && tag[0] == ' ' { + tag = tag[1:] + } + i := 0 + for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f { + i++ + } + if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' { + break + } + name := tag[:i] + tag = tag[i+1:] + i = 1 + for i < len(tag) && tag[i] != '"' { + if tag[i] == '\\' { + i++ + } + i++ + } + if i >= len(tag) { + break + } + val := tag[1:i] + tag = tag[i+1:] + if name == key { + return val + } + } + return "" +} diff --git a/.github/workflows/scripts/setup-go-workspace.sh b/.github/workflows/scripts/setup-go-workspace.sh index bc9a4d2854..29024fdaa3 100755 --- a/.github/workflows/scripts/setup-go-workspace.sh +++ b/.github/workflows/scripts/setup-go-workspace.sh @@ -20,14 +20,16 @@ fi go work init go work use ./core go work use ./framework +go work use ./plugins/compat go work use ./plugins/governance go work use ./plugins/jsonparser -go work use ./plugins/litellmcompat go work use ./plugins/logging go work use ./plugins/maxim go work use ./plugins/mocker go work use ./plugins/otel +go work use ./plugins/prompts go work use ./plugins/semanticcache go work use ./plugins/telemetry go work use ./transports +go work use ./cli echo "✅ Go workspace initialized" diff --git a/.github/workflows/scripts/test-docker-image.sh b/.github/workflows/scripts/test-docker-image.sh index 5d770fbd64..ac115394bf 100755 --- a/.github/workflows/scripts/test-docker-image.sh +++ b/.github/workflows/scripts/test-docker-image.sh @@ -212,8 +212,7 @@ cat > "$CONFIG_FILE" << 'CONFIGEOF' "enable_logging": true, "enforce_governance_header": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 }, "encryption_key": "" } diff --git a/.github/workflows/scripts/validate-helm-config-fields.sh b/.github/workflows/scripts/validate-helm-config-fields.sh index ed7a17857d..7a78e1bf54 
100755 --- a/.github/workflows/scripts/validate-helm-config-fields.sh +++ b/.github/workflows/scripts/validate-helm-config-fields.sh @@ -160,7 +160,7 @@ bifrost: enableLogging: true disableContentLogging: true disableDbPingsInHealth: true - logRetentionDays: 30 + logRetentionDays: 30 enforceGovernanceHeader: true allowDirectKeys: true maxRequestBodySizeMb: 50 @@ -168,6 +168,7 @@ bifrost: convertTextToChat: true convertChatToResponses: true shouldDropParams: true + shouldConvertParams: true prometheusLabels: - "team" - "env" @@ -206,6 +207,7 @@ assert_field_value 'client.max_request_body_size_mb' '.client.max_request_body_s assert_field_value 'client.compat.convert_text_to_chat' '.client.compat.convert_text_to_chat' 'true' assert_field_value 'client.compat.convert_chat_to_responses' '.client.compat.convert_chat_to_responses' 'true' assert_field_value 'client.compat.should_drop_params' '.client.compat.should_drop_params' 'true' +assert_field_value 'client.compat.should_convert_params' '.client.compat.should_convert_params' 'true' assert_field 'client.prometheus_labels' '.client.prometheus_labels' assert_field 'client.header_filter_config.allowlist' '.client.header_filter_config.allowlist' assert_field 'client.header_filter_config.denylist' '.client.header_filter_config.denylist' @@ -823,10 +825,10 @@ assert_field_value 'mcp client[0] tool_pricing.search' '.mcp.client_configs.[0]. assert_field_value 'mcp tool_manager_config.code_mode_binding_level' '.mcp.tool_manager_config.code_mode_binding_level' '"server"' ############################################################################### -# 8. Cluster, SAML, Load Balancer, Guardrails, Audit Logs +# 8. 
Cluster, SCIM, Load Balancer, Guardrails, Audit Logs ############################################################################### echo "" -echo -e "${CYAN}🌐 8/10 - Cluster, SAML, LB, Guardrails, Audit Logs${NC}" +echo -e "${CYAN}🌐 8/10 - Cluster, SCIM, LB, Guardrails, Audit Logs${NC}" echo "-----------------------------------------------------" cat > "$TMPDIR/values-cluster.yaml" << 'VALS' @@ -893,7 +895,8 @@ render_config "$TMPDIR/values-scim-okta.yaml" assert_field_value 'scim_config.enabled' '.scim_config.enabled' 'true' assert_field_value 'scim_config.provider' '.scim_config.provider' '"okta"' assert_field 'scim_config.config' '.scim_config.config' - +assert_field 'scim_config.config.apiToken' '.scim_config.config.apiToken' +assert_field 'scim_config.config.clientSecret' '.scim_config.config.clientSecret' # SCIM - Entra cat > "$TMPDIR/values-scim-entra.yaml" << 'VALS' image: @@ -916,6 +919,9 @@ VALS render_config "$TMPDIR/values-scim-entra.yaml" assert_field_value 'scim_config (entra) provider' '.scim_config.provider' '"entra"' assert_field 'scim_config (entra) config' '.scim_config.config' +assert_field_value 'scim_config (entra) enabled' '.scim_config.enabled' 'true' +assert_field 'scim_config (entra) config.tenantId' '.scim_config.config.tenantId' +assert_field 'scim_config (entra) config.clientId' '.scim_config.config.clientId' # Load Balancer cat > "$TMPDIR/values-lb.yaml" << 'VALS' @@ -1183,6 +1189,97 @@ assert_field_value 'logs_store.type (postgres)' '.logs_store.type' '"postgres"' assert_field_value 'logs_store.config.max_idle_conns' '.logs_store.config.max_idle_conns' '5' assert_field_value 'logs_store.config.max_open_conns' '.logs_store.config.max_open_conns' '50' +############################################################################### +# Object Storage (logsStore.objectStorage) +############################################################################### + +# S3 with inline credentials — exercises camelCase → snake_case mapping in 
_helpers.tpl +cat > "$TMPDIR/values-objstore-s3.yaml" << 'VALS' +image: + tag: v1.0.0 +storage: + mode: sqlite + configStore: + enabled: true + logsStore: + enabled: true + objectStorage: + enabled: true + type: s3 + bucket: "bifrost-logs" + prefix: "prod" + compress: true + region: "us-east-1" + endpoint: "https://minio.internal:9000" + accessKeyId: "AKIA..." + secretAccessKey: "secret" + roleArn: "arn:aws:iam::123:role/bifrost" + forcePathStyle: true +VALS + +render_config "$TMPDIR/values-objstore-s3.yaml" +assert_field_value 'logs_store.object_storage.type (s3)' '.logs_store.object_storage.type' '"s3"' +assert_field_value 'logs_store.object_storage.bucket' '.logs_store.object_storage.bucket' '"bifrost-logs"' +assert_field_value 'logs_store.object_storage.prefix' '.logs_store.object_storage.prefix' '"prod"' +assert_field_value 'logs_store.object_storage.compress' '.logs_store.object_storage.compress' 'true' +assert_field_value 'logs_store.object_storage.region' '.logs_store.object_storage.region' '"us-east-1"' +assert_field_value 'logs_store.object_storage.endpoint' '.logs_store.object_storage.endpoint' '"https://minio.internal:9000"' +assert_field_value 'logs_store.object_storage.access_key_id' '.logs_store.object_storage.access_key_id' '"AKIA..."' +assert_field_value 'logs_store.object_storage.secret_access_key' '.logs_store.object_storage.secret_access_key' '"secret"' +assert_field_value 'logs_store.object_storage.role_arn' '.logs_store.object_storage.role_arn' '"arn:aws:iam::123:role/bifrost"' +assert_field_value 'logs_store.object_storage.force_path_style' '.logs_store.object_storage.force_path_style' 'true' + +# S3 with existingSecret — exercises env.BIFROST_OBJECT_STORAGE_* substitution path +cat > "$TMPDIR/values-objstore-s3-secret.yaml" << 'VALS' +image: + tag: v1.0.0 +storage: + mode: sqlite + configStore: + enabled: true + logsStore: + enabled: true + objectStorage: + enabled: true + type: s3 + bucket: "bifrost-logs" + existingSecret: 
"bifrost-os-creds" + accessKeyIdKey: "access-key-id" + secretAccessKeyKey: "secret-access-key" + sessionTokenKey: "session-token" + roleArnKey: "role-arn" +VALS + +render_config "$TMPDIR/values-objstore-s3-secret.yaml" +assert_field_value 'logs_store.object_storage.access_key_id (env)' '.logs_store.object_storage.access_key_id' '"env.BIFROST_OBJECT_STORAGE_ACCESS_KEY_ID"' +assert_field_value 'logs_store.object_storage.secret_access_key (env)' '.logs_store.object_storage.secret_access_key' '"env.BIFROST_OBJECT_STORAGE_SECRET_ACCESS_KEY"' +assert_field_value 'logs_store.object_storage.session_token (env)' '.logs_store.object_storage.session_token' '"env.BIFROST_OBJECT_STORAGE_SESSION_TOKEN"' +assert_field_value 'logs_store.object_storage.role_arn (env)' '.logs_store.object_storage.role_arn' '"env.BIFROST_OBJECT_STORAGE_ROLE_ARN"' + +# GCS — exercises project_id + credentials_json mapping +cat > "$TMPDIR/values-objstore-gcs.yaml" << 'VALS' +image: + tag: v1.0.0 +storage: + mode: sqlite + configStore: + enabled: true + logsStore: + enabled: true + objectStorage: + enabled: true + type: gcs + bucket: "bifrost-gcs-bucket" + projectId: "my-gcp-project" + credentialsJson: "/etc/gcs/creds.json" +VALS + +render_config "$TMPDIR/values-objstore-gcs.yaml" +assert_field_value 'logs_store.object_storage.type (gcs)' '.logs_store.object_storage.type' '"gcs"' +assert_field_value 'logs_store.object_storage.bucket (gcs)' '.logs_store.object_storage.bucket' '"bifrost-gcs-bucket"' +assert_field_value 'logs_store.object_storage.project_id' '.logs_store.object_storage.project_id' '"my-gcp-project"' +assert_field_value 'logs_store.object_storage.credentials_json' '.logs_store.object_storage.credentials_json' '"/etc/gcs/creds.json"' + ############################################################################### # Summary ############################################################################### @@ -1200,4 +1297,4 @@ if [ "$TESTS_FAILED" -gt 0 ]; then else echo -e "${GREEN}✅ All 
config.json field validations passed!${NC}" exit 0 -fi +fi \ No newline at end of file diff --git a/.github/workflows/scripts/validate-helm-schema.sh b/.github/workflows/scripts/validate-helm-schema.sh index 98c36f3137..5f012bc5c8 100755 --- a/.github/workflows/scripts/validate-helm-schema.sh +++ b/.github/workflows/scripts/validate-helm-schema.sh @@ -196,7 +196,7 @@ else echo "✅ VLLM key config required fields match: [$HELM_VLLM_REQUIRED]" fi -# Check concurrency_config required fields (config calls this def concurrency_and_buffer_size) +# Check concurrency_and_buffer_size required fields (renamed from concurrency_config) CONFIG_CONCURRENCY_REQUIRED=$(jq -r '."$defs".concurrency_and_buffer_size.required // [] | sort | join(",")' "$CONFIG_SCHEMA" 2>/dev/null || echo "") HELM_CONCURRENCY_REQUIRED=$(jq -r '."$defs".concurrencyConfig.required // [] | sort | join(",")' "$HELM_SCHEMA" 2>/dev/null || echo "") @@ -433,8 +433,10 @@ else echo "✅ MCP stdio config required fields match: [$CONFIG_MCP_STDIO_REQUIRED]" fi -# MCP websocket_config / http_config are Helm-only sub-structures; config.schema.json uses -# a flat connection_type + connection_string instead, so there is nothing to compare here. +# MCP websocket_config and http_config were removed from config.schema.json +# because the corresponding Go fields don't exist (MCP rendering uses +# connection_type + connection_string directly, not sub-object configs). +# Helm still declares them for user convenience — not a schema sync concern. echo "" echo "🔍 Checking required fields in SAML/SCIM config..." 
diff --git a/.github/workflows/scripts/validate-schema-sync.sh b/.github/workflows/scripts/validate-schema-sync.sh new file mode 100755 index 0000000000..0214d0679b --- /dev/null +++ b/.github/workflows/scripts/validate-schema-sync.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Validate that Go config types in transports/bifrost-http/lib/config.go +# stay in sync (fields + enum values) with transports/config.schema.json. +# Walks the type graph recursively via go/types rather than regex-parsing source. + +if command -v readlink >/dev/null 2>&1 && readlink -f "$0" >/dev/null 2>&1; then + SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +else + SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd -P)" +fi +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" +TOOL_DIR="$SCRIPT_DIR/schemasync" + +cd "$REPO_ROOT" + +if ! command -v go >/dev/null 2>&1; then + echo "❌ go toolchain required for schema-sync validation" + exit 2 +fi + +# Ensure go.work exists at the repo root. schemasync's packages.Load needs +# it to resolve bifrost's local modules against each other. On fresh CI +# runners go.work is not checked in, so we provision it here inline. +# Sibling scripts (test-bifrost-http.sh etc.) call setup-go-workspace.sh +# via `source`, but that relies on the `return` builtin which has +# platform-dependent edge cases under `set -e`; we instead do the same +# work inline so this wrapper is self-contained. +if [ ! -f "$REPO_ROOT/go.work" ]; then + echo "🔧 Setting up Go workspace (go.work not found)..." 
+ ( + cd "$REPO_ROOT" + go work init + for mod in ./core ./framework \ + ./plugins/compat ./plugins/governance ./plugins/jsonparser \ + ./plugins/logging ./plugins/maxim ./plugins/mocker \ + ./plugins/otel ./plugins/prompts ./plugins/semanticcache \ + ./plugins/telemetry \ + ./transports ./cli; do + if [ -f "$REPO_ROOT/$mod/go.mod" ]; then + go work use "$mod" + fi + done + ) + echo "✅ Go workspace initialized at $REPO_ROOT/go.work" +else + echo "🔍 Go workspace already exists at $REPO_ROOT/go.work, skipping initialization" +fi + +echo "🔍 Validating Go ↔ config.schema.json sync (recursive, AST-based)" +echo "==================================================================" + +# The schemasync tool is its own module (separate go.mod). Build it with +# GOWORK=off so the tool's deps (golang.org/x/tools) resolve against its +# own go.mod, not the repo's go.work. At runtime the tool itself sets +# GOWORK=/go.work when loading bifrost packages. +(cd "$TOOL_DIR" && GOWORK=off go build -o /tmp/schemasync .) 
+/tmp/schemasync \ + --schema "$REPO_ROOT/transports/config.schema.json" \ + --pkg-root "$REPO_ROOT" \ + --helm-values "$REPO_ROOT/helm-charts/bifrost/values.schema.json" \ + --helm-helpers "$REPO_ROOT/helm-charts/bifrost/templates/_helpers.tpl" diff --git a/.gitignore b/.gitignore index f911a0a153..b82e4df24f 100644 --- a/.gitignore +++ b/.gitignore @@ -118,4 +118,56 @@ terraform.tfstate.backup !*.tfvars.example # Bifrost benchmarking -bifrost-benchmarking \ No newline at end of file +bifrost-benchmarking + +# Tests +:memory: + +# Generated test TLS certs (created by tests/docker-compose.yml redis-certs-init) +tests/redis-certs/ + +# dependencies +ui/node_modules +ui/.pnp +.pnp.* +.yarn/* +!.yarn/patches +!.yarn/plugins +!.yarn/releases +!.yarn/versions + +# testing +ui/coverage + +# next.js +ui/.next/ +ui/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.pnpm-debug.log* + +# env files (can opt-in for committing if needed) +.env* + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts + +# auto-generated TanStack Router route tree +ui/app/routeTree.gen.ts + +.tanstack +.next \ No newline at end of file diff --git a/.nvmrc b/.nvmrc new file mode 100644 index 0000000000..1d9b7831ba --- /dev/null +++ b/.nvmrc @@ -0,0 +1 @@ +22.12.0 diff --git a/.prettierrc b/.prettierrc deleted file mode 100644 index 4da40ee345..0000000000 --- a/.prettierrc +++ /dev/null @@ -1,25 +0,0 @@ -{ - "root": true, - "printWidth": 140, - "singleQuote": false, - "bracketSpacing": true, - "semi": true, - "bracketSameLine": false, - "useTabs": true, - "tabWidth": 2, - "trailingComma": "all", - "plugins": [ - "prettier-plugin-tailwindcss" - ], - "pluginSearchDirs": [ - "./ui" - ], - "tailwindAttributes": [ - "buttonClassname" - ], - "tailwindFunctions": [ - "cn", - "classNames" - ], - "endOfLine": "lf" -} \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index 03fdd812ee..f6356421ea 100644 --- 
a/AGENTS.md +++ b/AGENTS.md @@ -103,9 +103,9 @@ bifrost/ │ ├── mocker/ # Mock responses for testing │ ├── jsonparser/ # JSON extraction utilities │ ├── maxim/ # Maxim observability -│ └── litellmcompat/ # LiteLLM SDK compatibility (HTTP transport) +│ └── compat/ # LiteLLM SDK compatibility (HTTP transport) │ -├── ui/ # Next.js web interface +├── ui/ # React + Vite web interface │ ├── app/workspace/ # Feature pages (20+ workspace sections) │ ├── components/ # Shared React components │ └── lib/ # Constants, utilities, types @@ -641,10 +641,266 @@ Systematically address unresolved PR review comments. Uses GraphQL to get unreso ## Code Style - **Go**: `gofmt`/`goimports`. No custom linter config. -- **TypeScript/React**: Prettier. Next.js App Router. +- **TypeScript/React**: Oxfmt. TanStack Router. - **JSON tags**: `snake_case` matching provider API conventions. - **Error strings**: Lowercase, no trailing punctuation (Go convention). - **Provider types**: Prefixed with provider name in PascalCase (`AnthropicChatRequest`, `GeminiEmbeddingResponse`). - **Converter functions**: Pure — no side effects, no logging, no HTTP. - **Pool names**: Descriptive string passed to `pool.New()` (e.g., `"channel-message"`, `"response-stream"`). - **Context keys**: Use `BifrostContextKey` type. Custom plugins should define their own key types to avoid collisions. + +# Frontend Code Guidelines & Patterns + +This document defines the standards, structure, and best practices for writing frontend code in this project.
+ +--- + +## Tech Stack + +- **React** (with Vite) +- **TypeScript** +- **@tanstack/react-router** (type-safe routing) +- **Tailwind CSS v4** +- **Radix UI** (primitives) +- **Local UI component library** (`ui/components/ui/`) built on Radix primitives + +--- + +## Folder Structure + +```text + +/ui +├── app # Routes & pages +├── components # Shared components +│ └── ui # Core design system components +├── hooks # Custom React hooks +├── lib # Utilities, helpers, shared logic +└── app/enterprise # Enterprise-specific code (via symlink) + +``` + +### Rules + +- All frontend code must live inside `/ui` +- Routes and pages → `ui/app` +- Shared/reusable components → `ui/components` +- Core UI primitives → `ui/components/ui` +- Utilities and libraries → `ui/lib` +- Custom hooks → `ui/hooks` + +--- + +## Libraries & Usage + +### Core Libraries + +- `react` → UI library +- `typescript` → Type safety +- `tailwindcss` → Styling +- `@tanstack/react-router` → Routing + +### UI & Visualization + +- `@radix-ui/react-*` → UI primitives +- `ui/components/ui/*` → Project's Radix-based component system +- `recharts` → Charts +- `monaco-editor` → Code editor + +### Utilities + +- `date-fns` → Date/time formatting +- `nuqs` → Query param state management + +### Tooling + +- `Oxfmt` → Code formatting +- `vitest` → Testing + +--- + +## Routing Convention + +For every new route: + +```text + +ui/app// +├── layout.tsx # Route definition using createFileRoute +├── page.tsx # Page content +└── views/ # Optional: route-specific components + +``` + +### Rules + +- Folder name must match route name +- Always use `createFileRoute` in `layout.tsx` +- `page.tsx` should only handle composition (not heavy logic) +- Route-specific components go inside `views/` + +--- + +## Component Guidelines + +### Reusability First + +- Always check if similar components/functions already exist +- Prefer extending or refactoring existing code over duplication +- Only create new components if reuse is not 
feasible + +--- + +### Component Placement + +- Shared → `ui/components` +- Route-specific → `views/` inside route folder + +--- + +### JSX & Rendering + +- Avoid deeply nested conditional rendering +- Break complex UI into smaller components +- Keep components readable and maintainable + +--- + +### Lists & Keys + +- Always use **stable, unique keys** +- Never use array index as key (unless unavoidable) + +--- + +## React Best Practices + +- Avoid unnecessary or unstable dependencies in hooks +- Prevent infinite loops in `useEffect` +- Keep dependency arrays accurate and minimal +- Prefer derived state over duplicated state + +--- + +## State Management + +### Priority Order + +1. Query Params (`nuqs`) → for persistent/shareable state +2. Local State → for UI-only state +3. Redux → only when truly necessary + +--- + +### Query Params (`nuqs`) + +- Use for state that should persist across refresh/navigation +- Use proper parsers like `parseAsString` or `parseAsInteger` +- Do NOT mix query param state with local/redux state +- Follow a single consistent pattern across the codebase + +--- + +### Redux + +- Use only when global/shared state is required +- Avoid unnecessary slices +- Prefer simpler alternatives when possible + +--- + +### RTK Query (`@reduxjs/toolkit/query`) + +- Use for API calls and caching +- Use **granular tags** for cache invalidation +- Avoid invalidating entire datasets unnecessarily +- Implement **optimistic updates** where applicable + +--- + +## Forms + +We use: + +- `react-hook-form` +- `zod v4` (for schema validation) + +### Rules + +- Always define a Zod schema +- Include meaningful validation messages +- Prefer **inline field errors** (not toast notifications) +- Use `refine` / `superRefine` for complex validation +- Store schemas in: `ui/lib/types/schemas.ts` + +--- + +## Tables + +- Use `@tanstack/react-table` **only for large/complex datasets** +- For simple tables → build custom lightweight components +- Prioritize performance over 
abstraction + +--- + +## ⚡ Performance Guidelines + +- Lazy load heavy or rarely-used libraries +- Avoid unnecessary re-renders +- Split large components into smaller ones +- Keep bundle size minimal + +--- + +## Dependency Rules + +- Do NOT add new dependencies unless absolutely necessary +- Always pin exact versions (no `^` or `~`) +- Prefer existing libraries in the codebase + +--- + +## TypeScript Guidelines + +- Avoid using `any` unless absolutely unavoidable +- Prefer strict typing and inference +- Define reusable types in shared locations + +--- + +## Code Quality & Formatting + +After writing code: + +```bash +cd ui && npm run format +``` + +Then verify build: + +```bash +cd ui && npm run build +``` + +- Code must pass formatting and build checks +- Follow consistent naming and structure conventions + +--- + +## Anti-Patterns to Avoid + +- Duplicate components without considering reuse +- Mixing multiple state management approaches unnecessarily +- Overusing Redux +- Using unstable hook dependencies +- Adding heavy libraries for simple use cases +- Poorly structured or deeply nested JSX + +--- + +## Summary + +- Prioritize **reusability, performance, and consistency** +- Follow **strict folder structure and routing conventions** +- Use **the right tool for the right problem** +- Keep code **simple, predictable, and maintainable** diff --git a/Makefile b/Makefile index 9a0df1863b..84f2d66b28 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,17 @@ CYAN=\033[0;36m NC=\033[0m # No Color ECHO := printf '%b\n' -.PHONY: all help dev build-ui build build-cli run run-cli install-air clean test test-cli install-ui setup-workspace work-init work-clean docs docker-image docker-run cleanup-enterprise mod-tidy test-integrations-py test-integrations-ts install-playwright run-e2e run-e2e-ui run-e2e-headed +# nvm requires bash-compatible shell semantics; /bin/sh is dash on some Linux distros.
+SHELL := /bin/bash + +# Ensures the Node version pinned in .nvmrc is active before any npm/node call. +# nvm is a shell function, so each recipe that needs it must inline this snippet +# via `$(USE_NODE); `. +USE_NODE = NVM_SH="$${NVM_DIR:-$$HOME/.nvm}/nvm.sh"; \ + [ -s "$$NVM_SH" ] || NVM_SH="$$(brew --prefix nvm 2>/dev/null)/nvm.sh"; \ + if [ -s "$$NVM_SH" ]; then . "$$NVM_SH" >/dev/null && nvm install >/dev/null 2>&1 && nvm use >/dev/null 2>&1; fi + +.PHONY: all help dev dev-pulse build-ui build build-cli run run-cli install-air install-pulse clean test test-cli install-ui setup-workspace work-init work-clean docs docker-image docker-run cleanup-enterprise mod-tidy test-integrations-py test-integrations-ts install-playwright run-e2e run-e2e-ui run-e2e-headed all: help @@ -61,17 +71,26 @@ cleanup-enterprise: ## Clean up enterprise directories if present @$(ECHO) "$(GREEN)Enterprise cleaned up$(NC)" install-ui: cleanup-enterprise - @which node > /dev/null || ($(ECHO) "$(RED)Error: Node.js is not installed. Please install Node.js first.$(NC)" && exit 1) - @which npm > /dev/null || ($(ECHO) "$(RED)Error: npm is not installed. Please install npm first.$(NC)" && exit 1) - @$(ECHO) "$(GREEN)Node.js and npm are installed$(NC)" - @cd ui && npm ci - @which next > /dev/null || ($(ECHO) "$(YELLOW)Installing nextjs...$(NC)" && npm install -g next) + @$(USE_NODE); \ + which node > /dev/null || ($(ECHO) "$(RED)Error: Node.js is not installed. Please install Node.js first.$(NC)" && exit 1); \ + which npm > /dev/null || ($(ECHO) "$(RED)Error: npm is not installed. Please install npm first.$(NC)" && exit 1); \ + $(ECHO) "$(GREEN)Node.js $$(node -v) and npm $$(npm -v) are installed$(NC)"; \ + if [ ! 
-d "ui/node_modules" ] || [ "ui/package.json" -nt "ui/node_modules/.package-lock.json" ] || [ "ui/package-lock.json" -nt "ui/node_modules/.package-lock.json" ]; then \ + $(ECHO) "$(YELLOW)Dependencies changed, running npm ci...$(NC)"; \ + cd ui && npm ci; \ + else \ + $(ECHO) "$(GREEN)UI dependencies up to date, skipping install$(NC)"; \ + fi @$(ECHO) "$(GREEN)UI deps are in sync$(NC)" install-air: ## Install air for hot reloading (if not already installed) @which air > /dev/null || ($(ECHO) "$(YELLOW)Installing air for hot reloading...$(NC)" && go install github.com/air-verse/air@latest) @$(ECHO) "$(GREEN)Air is ready$(NC)" +install-pulse: ## Install pulse for hot reloading (if not already installed) + @which pulse > /dev/null || ($(ECHO) "$(YELLOW)Installing pulse for hot reloading...$(NC)" && go install github.com/Pratham-Mishra04/pulse@latest) + @$(ECHO) "$(GREEN)Pulse is ready$(NC)" + install-delve: ## Install delve for debugging (if not already installed) @which dlv > /dev/null || ($(ECHO) "$(YELLOW)Installing delve for debugging...$(NC)" && go install github.com/go-delve/delve/cmd/dlv@latest) @$(ECHO) "$(GREEN)Delve is ready$(NC)" @@ -86,6 +105,7 @@ install-junit-viewer: ## Install junit-viewer for HTML report generation (if not $(ECHO) "$(GREEN)junit-viewer is already installed$(NC)"; \ else \ $(ECHO) "$(YELLOW)Installing junit-viewer for HTML reports...$(NC)"; \ + $(USE_NODE); \ if npm install -g junit-viewer 2>&1; then \ $(ECHO) "$(GREEN)junit-viewer installed successfully$(NC)"; \ else \ @@ -114,9 +134,9 @@ dev: install-ui install-air setup-workspace $(if $(DEBUG),install-delve) ## Star fi @$(ECHO) "" @$(ECHO) "$(YELLOW)Starting UI development server...$(NC)" - @if [ -n "$(DISABLE_PROFILER)" ]; then \ + @$(USE_NODE); if [ -n "$(DISABLE_PROFILER)" ]; then \ $(ECHO) "$(CYAN)DevProfiler disabled for testing$(NC)"; \ - cd ui && NEXT_PUBLIC_DISABLE_PROFILER=1 npm run dev & \ + cd ui && BIFROST_DISABLE_PROFILER=1 npm run dev & \ else \ cd ui && npm run dev & \ 
fi @@ -147,10 +167,59 @@ dev: install-ui install-air setup-workspace $(if $(DEBUG),install-delve) ## Star $(if $(APP_DIR),-app-dir "$(APP_DIR)"); \ fi +dev-pulse: install-ui install-pulse setup-workspace $(if $(DEBUG),install-delve) ## Start complete development environment using pulse for hot reloading + @$(ECHO) "$(GREEN)Starting Bifrost complete development environment (pulse)...$(NC)" + @$(ECHO) "$(YELLOW)This will start:$(NC)" + @$(ECHO) " 1. UI development server (localhost:3000)" + @$(ECHO) " 2. API server with UI proxy (localhost:$(PORT))" + @$(ECHO) "$(CYAN)Access everything at: http://localhost:$(PORT)$(NC)" + @if [ -n "$(DEBUG)" ]; then \ + $(ECHO) "$(CYAN) 3. Debugger (delve) listening on port 2345$(NC)"; \ + fi + @if [ ! -d "transports/bifrost-http/ui" ]; then \ + $(ECHO) "$(YELLOW)Creating transports/bifrost-http/ui directory...$(NC)"; \ + mkdir -p transports/bifrost-http/ui; \ + touch transports/bifrost-http/ui/.tmp; \ + fi + @$(ECHO) "" + @$(ECHO) "$(YELLOW)Starting UI development server...$(NC)" + @$(USE_NODE); if [ -n "$(DISABLE_PROFILER)" ]; then \ + $(ECHO) "$(CYAN)DevProfiler disabled for testing$(NC)"; \ + cd ui && BIFROST_DISABLE_PROFILER=1 npm run dev & \ + else \ + cd ui && npm run dev & \ + fi + @sleep 3 + @$(ECHO) "$(YELLOW)Starting API server with UI proxy...$(NC)" + @$(MAKE) setup-workspace >/dev/null + @if [ -f .env ]; then \ + $(ECHO) "$(YELLOW)Loading environment variables from .env...$(NC)"; \ + set -a; . 
./.env; set +a; \ + fi; \ + if [ -n "$(DEBUG)" ]; then \ + $(ECHO) "$(CYAN)Starting with pulse + delve debugger on port 2345...$(NC)"; \ + $(ECHO) "$(YELLOW)Attach your debugger to localhost:2345$(NC)"; \ + BIFROST_UI_DEV=true pulse -- \ + -host "$(HOST)" \ + -port "$(PORT)" \ + -log-style "$(LOG_STYLE)" \ + -log-level "$(LOG_LEVEL)" \ + $(if $(PROMETHEUS_LABELS),-prometheus-labels "$(PROMETHEUS_LABELS)") \ + $(if $(APP_DIR),-app-dir "$(APP_DIR)"); \ + else \ + BIFROST_UI_DEV=true pulse -- \ + -host "$(HOST)" \ + -port "$(PORT)" \ + -log-style "$(LOG_STYLE)" \ + -log-level "$(LOG_LEVEL)" \ + $(if $(PROMETHEUS_LABELS),-prometheus-labels "$(PROMETHEUS_LABELS)") \ + $(if $(APP_DIR),-app-dir "$(APP_DIR)"); \ + fi + build-ui: install-ui ## Build ui @$(ECHO) "$(GREEN)Building ui...$(NC)" @rm -rf ui/.next - @cd ui && npm run build && npm run copy-build + @$(USE_NODE); cd ui && npm run build && npm run copy-build build: build-ui ## Build bifrost-http binary @if [ -n "$(LOCAL)" ]; then \ @@ -828,7 +897,8 @@ test-governance: install-gotestsum $(if $(DEBUG),install-delve) ## Run governanc setup-mcp-tests: ## Build all MCP test servers in examples/mcps/ (Go and TypeScript) @$(ECHO) "$(GREEN)Building MCP test servers...$(NC)" - @FAILED=0; \ + @$(USE_NODE); \ + FAILED=0; \ for mcp_dir in examples/mcps/*/; do \ if [ -d "$$mcp_dir" ]; then \ mcp_name=$$(basename $$mcp_dir); \ @@ -1200,6 +1270,7 @@ test-integrations-ts: ## Run TypeScript integration tests (Usage: make test-inte done; \ fi; \ TEST_FAILED=0; \ + $(USE_NODE); \ if ! which npm > /dev/null 2>&1; then \ $(ECHO) "$(RED)Error: npm not found$(NC)"; \ $(ECHO) "$(YELLOW)Install Node.js: https://nodejs.org/$(NC)"; \ @@ -1258,7 +1329,7 @@ install-playwright: ## Install Playwright test dependencies @$(ECHO) "$(GREEN)Installing Playwright dependencies...$(NC)" @which node > /dev/null || ($(ECHO) "$(RED)Error: Node.js is not installed. 
Please install Node.js first.$(NC)" && exit 1) @which npm > /dev/null || ($(ECHO) "$(RED)Error: npm is not installed. Please install npm first.$(NC)" && exit 1) - @cd tests/e2e && npm ci + @$(USE_NODE); cd tests/e2e && npm ci @cd tests/e2e && if npx playwright install --list 2>/dev/null | grep -q "chromium"; then \ $(ECHO) "$(CYAN)Chromium is already installed, skipping download$(NC)"; \ else \ diff --git a/core/bifrost.go b/core/bifrost.go index 70454a96e6..5d21e4dee8 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -5,8 +5,8 @@ package bifrost import ( "context" + "errors" "fmt" - "math/rand" "slices" "sort" "strings" @@ -17,6 +17,7 @@ import ( "github.com/bytedance/sonic" "github.com/google/uuid" + "github.com/maximhq/bifrost/core/keyselectors" "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/mcp/codemode/starlark" "github.com/maximhq/bifrost/core/providers/anthropic" @@ -174,6 +175,9 @@ type PluginPipeline struct { postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name postHookPluginOrder []string // order in which post-hooks ran (for nested span creation) chunkCount int + + // Plugin logging: cached scoped contexts for streaming post-hooks (reused across chunks) + streamScopedCtxs map[string]*schemas.BifrostContext } // pluginTimingAccumulator accumulates timing information for a plugin across streaming chunks @@ -242,7 +246,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { bifrost.dropExcessRequests.Store(config.DropExcessRequests) if bifrost.keySelector == nil { - bifrost.keySelector = WeightedRandomKeySelector + bifrost.keySelector = keyselectors.WeightedRandom } // Initialize object pools @@ -613,9 +617,10 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req * Message: "prompt not provided for text completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: req.Provider, - 
ModelRequested: req.Model, + RequestType: schemas.TextCompletionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -628,7 +633,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req * if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.TextCompletionResponse, nil } @@ -652,9 +657,10 @@ func (bifrost *Bifrost) TextCompletionStreamRequest(ctx *schemas.BifrostContext, Message: "text not provided for text completion stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TextCompletionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -683,9 +689,10 @@ func (bifrost *Bifrost) makeChatCompletionRequest(ctx *schemas.BifrostContext, r Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ChatCompletionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -748,9 +755,10 @@ func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx *schemas.BifrostContext, Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -784,9 +792,10 @@ func (bifrost *Bifrost) makeResponsesRequest(ctx *schemas.BifrostContext, req *s Message: "responses not provided for responses request", }, 
ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ResponsesRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -852,9 +861,10 @@ func (bifrost *Bifrost) ResponsesStreamRequest(ctx *schemas.BifrostContext, req Message: "responses not provided for responses stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ResponsesStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -887,9 +897,10 @@ func (bifrost *Bifrost) CountTokensRequest(ctx *schemas.BifrostContext, req *sch Message: "input not provided for count tokens request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.CountTokensRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.CountTokensRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -919,16 +930,19 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem }, } } - if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !isLargePayloadPassthrough(ctx) { + hasExtraInputs := req.Params != nil && req.Params.ExtraParams != nil && + (req.Params.ExtraParams["inputs"] != nil || req.Params.ExtraParams["images"] != nil) + if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !hasExtraInputs && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "embedding input not provided for embedding request", }, 
ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.EmbeddingRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.EmbeddingRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -941,7 +955,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.EmbeddingResponse, nil } @@ -965,9 +979,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: "query not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -978,9 +993,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: "documents not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -992,9 +1008,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. 
Message: fmt.Sprintf("document text is empty at index %d", i), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1030,9 +1047,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "document type not provided for ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1043,9 +1061,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "document_url not provided for document_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1056,9 +1075,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "image_url not provided for image_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1093,9 +1113,10 @@ func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas. 
Message: "speech input not provided for speech request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.SpeechRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1108,7 +1129,7 @@ func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas. if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.SpeechResponse, nil } @@ -1132,9 +1153,10 @@ func (bifrost *Bifrost) SpeechStreamRequest(ctx *schemas.BifrostContext, req *sc Message: "speech input not provided for speech stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.SpeechStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1166,9 +1188,10 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *s Message: "transcription input not provided for transcription request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TranscriptionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1181,7 +1204,7 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *s if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.TranscriptionResponse, nil } @@ -1205,9 +1228,10 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, Message: "transcription input not provided for transcription stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: 
schemas.TranscriptionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TranscriptionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1221,7 +1245,8 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, // ImageGenerationRequest sends an image generation request to the specified provider. func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + req *schemas.BifrostImageGenerationRequest, +) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1240,9 +1265,10 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, Message: "prompt not provided for image generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1262,9 +1288,10 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1274,7 +1301,8 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, // ImageGenerationStreamRequest sends an image generation stream request to the specified provider. 
func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + req *schemas.BifrostImageGenerationRequest, +) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1293,9 +1321,10 @@ func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext Message: "prompt not provided for image generation stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1327,14 +1356,19 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "images not provided for image edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } - // Prompt is not required when type is background_removal - if (req.Params == nil || req.Params.Type == nil || *req.Params.Type != "background_removal") && + // Prompt is not required for certain operation types that work without a text prompt + var imageEditParamsType *string + if req.Params != nil { + imageEditParamsType = req.Params.Type + } + if !isPromptOptionalImageEditType(imageEditParamsType) && (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1342,9 +1376,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "prompt not provided for image 
edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1365,9 +1400,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1395,14 +1431,19 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req Message: "images not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } - // Prompt is not required when type is background_removal - if (req.Params == nil || req.Params.Type == nil || *req.Params.Type != "background_removal") && + // Prompt is not required for certain operation types that work without a text prompt + var imageEditStreamParamsType *string + if req.Params != nil { + imageEditStreamParamsType = req.Params.Type + } + if !isPromptOptionalImageEditType(imageEditStreamParamsType) && (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1410,9 +1451,10 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req Message: "prompt not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - 
RequestType: schemas.ImageEditStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1444,9 +1486,10 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * Message: "image not provided for image variation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageVariationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageVariationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1467,9 +1510,10 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageVariationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageVariationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1479,7 +1523,8 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * // VideoGenerationRequest sends a video generation request to the specified provider. 
func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { + req *schemas.BifrostVideoGenerationRequest, +) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1498,9 +1543,10 @@ func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, Message: "prompt not provided for video generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.VideoGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1520,9 +1566,10 @@ func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.VideoGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -3483,7 +3530,7 @@ func (bifrost *Bifrost) removeProviderFromSlice(providerKey schemas.ModelProvide // }, toolSchema) func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.RegisterTool(name, description, handler, toolSchema) @@ -3501,7 +3548,7 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a // - error: Any retrieval error func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, 
error) { if bifrost.MCPManager == nil { - return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") + return nil, fmt.Errorf("mcp is not configured in this bifrost instance") } clients := bifrost.MCPManager.GetClients() @@ -3541,7 +3588,7 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { // // Returns: // - []schemas.ChatTool: List of available tools -func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { +func (bifrost *Bifrost) GetAvailableMCPTools(ctx *schemas.BifrostContext) []schemas.ChatTool { if bifrost.MCPManager == nil { return nil } @@ -3612,7 +3659,7 @@ func (bifrost *Bifrost) AddMCPClient(config *schemas.MCPClientConfig) error { // } func (bifrost *Bifrost) RemoveMCPClient(id string) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.RemoveClient(id) @@ -3620,11 +3667,31 @@ func (bifrost *Bifrost) RemoveMCPClient(id string) error { // SetMCPManager sets the MCP manager for this Bifrost instance. // This allows injecting a custom MCP manager implementation (e.g., for enterprise features). +// If the provided manager is a concrete *mcp.MCPManager, Bifrost's plugin pipeline is injected +// into the manager's CodeMode so that nested tool calls run through the plugin hooks. // // Parameters: // - manager: The MCP manager to set (must implement MCPManagerInterface) func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { bifrost.MCPManager = manager + // Inject Bifrost's plugin pipeline into the manager's CodeMode so that + // nested tool calls (e.g. via Starlark executeCode) run through plugin hooks. 
+ if m, ok := manager.(*mcp.MCPManager); ok { + m.SetPluginPipeline( + func() mcp.PluginPipeline { + pipeline := bifrost.getPluginPipeline() + if pp, ok := any(pipeline).(mcp.PluginPipeline); ok { + return pp + } + return nil + }, + func(pipeline mcp.PluginPipeline) { + if pp, ok := pipeline.(*PluginPipeline); ok { + bifrost.releasePluginPipeline(pp) + } + }, + ) + } } // UpdateMCPClient updates the MCP client. @@ -3645,7 +3712,7 @@ func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { // }) func (bifrost *Bifrost) UpdateMCPClient(id string, updatedConfig *schemas.MCPClientConfig) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.UpdateClient(id, updatedConfig) @@ -3660,23 +3727,63 @@ func (bifrost *Bifrost) UpdateMCPClient(id string, updatedConfig *schemas.MCPCli // - error: Any reconnection error func (bifrost *Bifrost) ReconnectMCPClient(id string) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.ReconnectClient(id) } +// VerifyPerUserOAuthConnection delegates to the MCP manager to verify an MCP +// server using a temporary access token and discover available tools. The +// connection is closed after verification. If the MCP manager is not yet +// initialized, it is lazily created (same as AddMCPClient). 
+func (bifrost *Bifrost) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) { + // Ensure MCP manager is initialized (lazy init, same pattern as AddMCPClient) + if bifrost.MCPManager == nil { + bifrost.mcpInitOnce.Do(func() { + mcpConfig := schemas.MCPConfig{ + ClientConfigs: []*schemas.MCPClientConfig{}, + } + mcpConfig.PluginPipelineProvider = func() interface{} { + return bifrost.getPluginPipeline() + } + mcpConfig.ReleasePluginPipeline = func(pipeline interface{}) { + if pp, ok := pipeline.(*PluginPipeline); ok { + bifrost.releasePluginPipeline(pp) + } + } + codeMode := starlark.NewStarlarkCodeMode(nil, bifrost.logger) + bifrost.MCPManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) + }) + } + if bifrost.MCPManager == nil { + return nil, nil, fmt.Errorf("mcp manager is not initialized") + } + return bifrost.MCPManager.VerifyPerUserOAuthConnection(ctx, config, accessToken) +} + +// SetClientTools delegates to the MCP manager to update the tool map for an +// existing MCP client. +func (bifrost *Bifrost) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { + if bifrost.MCPManager != nil { + bifrost.MCPManager.SetClientTools(clientID, tools, toolNameMapping) + } +} + + // UpdateToolManagerConfig updates the tool manager config for the MCP manager. // This allows for hot-reloading of the tool manager config at runtime. -func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { +// Pass the current value of disableAutoToolInject whenever only other fields +// change so the flag is never silently reset to its zero value.
+func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } bifrost.MCPManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ - MaxAgentDepth: maxAgentDepth, - ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, - CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), + MaxAgentDepth: maxAgentDepth, + ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, + CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), + DisableAutoToolInject: disableAutoToolInject, }) return nil } @@ -3858,9 +3965,10 @@ func (bifrost *Bifrost) GetProviderByKey(providerKey schemas.ModelProvider) sche return bifrost.getProviderByKey(providerKey) } -// SelectKeyForProvider selects an API key for the given provider and model. -// Used by WebSocket handlers that need a key for upstream connections. -func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { +// SelectKeyForProviderRequestType selects an API key for the given provider, request type, and model. +// Used by WebSocket handlers that need a key for upstream connections while honoring request-specific +// AllowedRequests gates such as realtime-only support. 
+func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { if ctx == nil { ctx = bifrost.ctx } @@ -3869,7 +3977,17 @@ func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, provid config.CustomProviderConfig != nil && config.CustomProviderConfig.BaseProviderType != "" { baseProvider = config.CustomProviderConfig.BaseProviderType } - return bifrost.selectKeyFromProviderForModel(ctx, schemas.WebSocketResponsesRequest, providerKey, model, baseProvider) + supportedKeys, _, err := bifrost.selectKeyFromProviderForModelWithPool(ctx, requestType, providerKey, model, baseProvider) + if err != nil { + return schemas.Key{}, err + } + if len(supportedKeys) == 0 { + return schemas.Key{}, nil + } + if len(supportedKeys) == 1 { + return supportedKeys[0], nil + } + return bifrost.keySelector(ctx, supportedKeys, providerKey, model) } // WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks. @@ -3883,6 +4001,13 @@ type WSStreamHooks struct { ShortCircuitResponse *schemas.BifrostResponse } +// RealtimeTurnHooks mirrors RunStreamPreHooks but is explicitly scoped to a +// single realtime turn rather than one long-lived transport connection. +type RealtimeTurnHooks struct { + PostHookRunner schemas.PostHookRunner + Cleanup func() +} + // RunStreamPreHooks acquires a plugin pipeline, sets up tracing context, runs PreLLMHooks, // and returns a PostHookRunner for per-chunk post-processing. // Used by WebSocket handlers that bypass the normal inference path but still need plugin hooks. 
@@ -3921,12 +4046,22 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() - return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + return nil, bifrostErr } if shortCircuit != nil { if shortCircuit.Error != nil { _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3935,6 +4070,10 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche } if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3946,8 +4085,21 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche } } + wsProvider, wsModel, _ := preReq.GetRequestFields() postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - return pipeline.RunPostLLMHooks(ctx, result, err, 
preCount) + // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) + // can read requestType/provider/model from the chunk or error. + if result != nil { + result.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) + } + if err != nil { + err.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) + } + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } + return resp, bifrostErr } return &WSStreamHooks{ @@ -3956,6 +4108,94 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche }, nil } +// RunRealtimeTurnPreHooks acquires a plugin pipeline and runs LLM pre-hooks for +// a single realtime turn. Unlike generic stream hooks, realtime turns do not +// support short-circuit responses in v1 because the transports cannot yet emit a +// fully synthetic assistant turn without an upstream generation. +func (bifrost *Bifrost) RunRealtimeTurnPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*RealtimeTurnHooks, *schemas.BifrostError) { + if req == nil { + bifrostErr := newBifrostErrorFromMsg("realtime turn request is nil") + bifrostErr.ExtraFields.RequestType = schemas.RealtimeRequest + return nil, bifrostErr + } + if ctx == nil { + ctx = bifrost.ctx + } + + if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { + ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) + } + + tracer := bifrost.getTracer() + ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) + + if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok { + traceID := tracer.CreateTrace("") + if traceID != "" { + ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) + } + } + + pipeline := bifrost.getPluginPipeline() + cleanup := func() { + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.CleanupStreamAccumulator(traceID) + } + 
bifrost.releasePluginPipeline(pipeline) + } + provider, model, _ := req.GetRequestFields() + + preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) + if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + if shortCircuit != nil { + if shortCircuit.Error != nil { + shortCircuit.Error.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + if bifrostErr != nil { + return nil, bifrostErr + } + return nil, shortCircuit.Error + } + if shortCircuit.Response != nil { + // Short-circuit responses are not supported for realtime turns (v1). + // Treat this like an error turn so plugins can close pending state cleanly. 
+ bifrostErr := newBifrostErrorFromMsg("realtime turn short-circuit responses are not supported") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + } + + return &RealtimeTurnHooks{ + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + drainAndAttachPluginLogs(ctx) + return resp, bifrostErr + }, + Cleanup: cleanup, + }, nil +} + // getProviderByKey retrieves a provider instance from the providers array by its provider key. // Returns the provider if found, or nil if no provider with the given key exists. func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider { @@ -4158,11 +4398,7 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. defer bifrost.releaseBifrostRequest(req) provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { - err.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + err.PopulateExtraFields(req.RequestType, provider, model, model) return nil, err } @@ -4195,16 +4431,6 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. 
// Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } return primaryResult, primaryErr } @@ -4247,29 +4473,10 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { - fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: fallback.Provider, - ModelRequested: fallback.Model, - RawRequest: fallbackErr.ExtraFields.RawRequest, - RawResponse: fallbackErr.ExtraFields.RawResponse, - KeyStatuses: fallbackErr.ExtraFields.KeyStatuses, - } return nil, fallbackErr } } - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } - // All providers failed, return the original error return nil, primaryErr } @@ -4284,11 +4491,7 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { - err.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + err.PopulateExtraFields(req.RequestType, provider, model, model) err.StatusCode = schemas.Ptr(fasthttp.StatusBadRequest) return nil, err } @@ -4310,16 +4513,6 @@ func (bifrost *Bifrost) 
handleStreamRequest(ctx *schemas.BifrostContext, req *sc // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } return primaryResult, primaryErr } @@ -4360,29 +4553,10 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { - fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: fallback.Provider, - ModelRequested: fallback.Model, - RawRequest: fallbackErr.ExtraFields.RawRequest, - RawResponse: fallbackErr.ExtraFields.RawResponse, - KeyStatuses: fallbackErr.ExtraFields.KeyStatuses, - } return nil, fallbackErr } } - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } - // All providers failed, return the original error return nil, primaryErr } @@ -4394,11 +4568,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4409,7 +4579,9 @@ func (bifrost 
*Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif tracer := bifrost.getTracer() if tracer == nil { - return nil, newBifrostErrorFromMsg("tracer not found in context") + bifrostErr := newBifrostErrorFromMsg("tracer not found in context") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } // Store tracer in context BEFORE calling requestHandler, so streaming goroutines @@ -4426,7 +4598,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif // Handle short-circuit with response (success case) if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return resp, nil @@ -4434,7 +4608,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif // Handle short-circuit with error if shortCircuit.Error != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return resp, nil @@ -4442,11 +4618,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4487,36 +4659,26 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - 
bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { @@ -4525,15 +4687,13 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr 
case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4543,33 +4703,52 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif select { case result = <-msg.Response: resp, bifrostErr := pipeline.RunPostLLMHooks(msg.Context, result, nil, pluginCount) + drainAndAttachPluginLogs(msg.Context) if bifrostErr != nil { bifrost.releaseChannelMessage(msg) return nil, bifrostErr } bifrost.releaseChannelMessage(msg) - // Checking if need to drop raw messages - // This we use for requests like containers, container files, skills etc. - if drop, ok := ctx.Value(schemas.BifrostContextKeyRawRequestResponseForLogging).(bool); ok && drop && resp != nil { - extraField := resp.GetExtraFields() - extraField.RawRequest = nil - extraField.RawResponse = nil + // Strip raw fields that were captured for logging but should not reach the client. + if resp != nil { + dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool) + dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool) + if dropReq || dropResp { + extraField := resp.GetExtraFields() + if dropReq { + extraField.RawRequest = nil + } + if dropResp { + extraField.RawResponse = nil + } + } } return resp, nil case bifrostErrVal := <-msg.Err: bifrostErrPtr := &bifrostErrVal resp, bifrostErrPtr = pipeline.RunPostLLMHooks(msg.Context, nil, bifrostErrPtr, pluginCount) + drainAndAttachPluginLogs(msg.Context) bifrost.releaseChannelMessage(msg) - // Drop raw request/response on error path too - if drop, ok := ctx.Value(schemas.BifrostContextKeyRawRequestResponseForLogging).(bool); ok && drop { + // Strip raw fields on error path too. 
+ dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool) + dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool) + if dropReq || dropResp { if bifrostErrPtr != nil { - bifrostErrPtr.ExtraFields.RawRequest = nil - bifrostErrPtr.ExtraFields.RawResponse = nil + if dropReq { + bifrostErrPtr.ExtraFields.RawRequest = nil + } + if dropResp { + bifrostErrPtr.ExtraFields.RawResponse = nil + } } if resp != nil { extraField := resp.GetExtraFields() - extraField.RawRequest = nil - extraField.RawResponse = nil + if dropReq { + extraField.RawRequest = nil + } + if dropResp { + extraField.RawResponse = nil + } } } if bifrostErrPtr != nil { @@ -4585,7 +4764,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif // pool and is GC'd. That is intentional: a small pool leak on cancellation // is far safer than corrupting another request's channels. provider, model, _ := req.GetRequestFields() - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "waiting for provider response") + bifrostErr := newBifrostCtxDoneError(ctx, "waiting for provider response") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4596,11 +4777,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4611,7 +4788,9 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem tracer := bifrost.getTracer() if tracer == nil { - return nil, newBifrostErrorFromMsg("tracer not found in context") + bifrostErr := newBifrostErrorFromMsg("tracer not found in context") + 
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } // Store tracer in context BEFORE calling RunLLMPreHooks, so plugins and streaming goroutines @@ -4632,14 +4811,21 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } pipeline := bifrost.getPluginPipeline() - defer bifrost.releasePluginPipeline(pipeline) + releasePipeline := true + defer func() { + if releasePipeline { + bifrost.releasePluginPipeline(pipeline) + } + }() preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return newBifrostMessageChan(resp), nil @@ -4647,13 +4833,23 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Handle short-circuit with stream if shortCircuit.Stream != nil { outputStream := make(chan *schemas.BifrostStreamChunk) + releasePipeline = false // pipeline is released inside the goroutine after stream drains // Create a post hook runner cause pipeline object is put back in the pool on defer pipelinePostHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - return pipeline.RunPostLLMHooks(ctx, result, err, preCount) + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } + return resp, bifrostErr } go func() { + defer func() { + drainAndAttachPluginLogs(ctx) // ensure logs are drained even if stream closes without a final chunk + pipeline.FinalizeStreamingPostHookSpans(ctx) + bifrost.releasePluginPipeline(pipeline) + }() 
defer close(outputStream) for streamMsg := range shortCircuit.Stream { @@ -4700,7 +4896,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem return } - //TODO: Release the processed response immediately after use + // TODO: Release the processed response immediately after use } }() @@ -4709,7 +4905,9 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Handle short-circuit with error if shortCircuit.Error != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return newBifrostMessageChan(resp), nil @@ -4717,11 +4915,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4762,36 +4956,26 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + 
return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { @@ -4800,15 +4984,13 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4826,6 +5008,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) // On error we will complete post-hooks 
recoveredResp, recoveredErr := pipeline.RunPostLLMHooks(ctx, nil, &bifrostErrVal, len(*bifrost.llmPlugins.Load())) + drainAndAttachPluginLogs(ctx) bifrost.releaseChannelMessage(msg) if recoveredErr != nil { return nil, recoveredErr @@ -4842,13 +5025,20 @@ } } -// executeRequestWithRetries is a generic function that handles common request processing logic -// It consolidates retry logic, backoff calculation, and error handling -// It is not a bifrost method because interface methods in go cannot be generic +// executeRequestWithRetries is a generic function that handles common request processing logic. +// It consolidates retry logic, backoff calculation, error handling, and key rotation. +// It is not a bifrost method because interface methods in Go cannot be generic. +// +// keyProvider, when non-nil, is called on the first attempt and again whenever a rate-limit error +// triggers a key rotation. It receives the set of key IDs already used in the current rotation +// cycle so it can exclude them; when the pool is exhausted the provider resets the set and starts +// a fresh weighted round. Network errors (5xx) reuse the same key since they are transient server +// issues rather than per-key capacity problems.
func executeRequestWithRetries[T any]( ctx *schemas.BifrostContext, config *schemas.ProviderConfig, - requestHandler func() (T, *schemas.BifrostError), + requestHandler func(key schemas.Key) (T, *schemas.BifrostError), + keyProvider func(usedKeyIDs map[string]bool) (schemas.Key, error), requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, @@ -4859,8 +5049,77 @@ func executeRequestWithRetries[T any]( var bifrostError *schemas.BifrostError var attempts int + var currentKey schemas.Key + var usedKeyIDs map[string]bool + lastWasRateLimit := false + for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ { ctx.SetValue(schemas.BifrostContextKeyNumberOfRetries, attempts) + + // Reset the trail on the first attempt so a reused or shared context (bifrost.ctx) + // doesn't carry over records from a previous request. + if keyProvider != nil && attempts == 0 { + ctx.SetValue(schemas.BifrostContextKeyAttemptTrail, []schemas.KeyAttemptRecord{}) + } + + // Select / rotate key: always on attempt 0, and again when the previous failure was a + // rate-limit (different key may have remaining capacity). Network errors keep the same key. + if keyProvider != nil && (attempts == 0 || lastWasRateLimit) { + if usedKeyIDs == nil { + usedKeyIDs = make(map[string]bool) + } + + // Wrap key selection in a dedicated span so traces show which key was chosen + // (and when rotation happened). The span is opened before keyProvider is called + // so selection errors are captured too. 
+ keyTracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer) + var keySpanCtx context.Context + var keyHandle schemas.SpanHandle + if keyTracer != nil { + keySpanCtx, keyHandle = keyTracer.StartSpan(ctx, "key.selection", schemas.SpanKindInternal) + keyTracer.SetAttribute(keyHandle, schemas.AttrProviderName, string(providerKey)) + keyTracer.SetAttribute(keyHandle, schemas.AttrRequestModel, model) + if attempts > 0 { + keyTracer.SetAttribute(keyHandle, "retry.count", attempts) + } + } + + selectedKey, err := keyProvider(usedKeyIDs) + + if keyTracer != nil { + if err != nil { + keyTracer.SetAttribute(keyHandle, "error", err.Error()) + keyTracer.EndSpan(keyHandle, schemas.SpanStatusError, err.Error()) + } else { + keyTracer.SetAttribute(keyHandle, "key.id", selectedKey.ID) + keyTracer.SetAttribute(keyHandle, "key.name", selectedKey.Name) + keyTracer.EndSpan(keyHandle, schemas.SpanStatusOk, "") + // Propagate the span context so subsequent spans (llm.call / retry.attempt.N) + // are correctly linked in the trace hierarchy. + ctx.SetValue(schemas.BifrostContextKeySpanID, keySpanCtx.Value(schemas.BifrostContextKeySpanID)) + } + } + + if err != nil { + var zero T + return zero, newBifrostErrorFromMsg(err.Error()) + } + currentKey = selectedKey + ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, currentKey.ID) + ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, currentKey.Name) + } + + // Append a trail record for every attempt (key rotation and same-key retries alike). + // Skipped when keyProvider is nil (keyless providers have no key to track). + // FailReason is populated below once the attempt outcome is known. 
+ if keyProvider != nil { + schemas.AppendToContextList(ctx, schemas.BifrostContextKeyAttemptTrail, schemas.KeyAttemptRecord{ + Attempt: attempts, + KeyID: currentKey.ID, + KeyName: currentKey.Name, + }) + } + if attempts > 0 { // Log retry attempt var retryMsg string @@ -4952,7 +5211,7 @@ func executeRequestWithRetries[T any]( } // Attempt the request - result, bifrostError = requestHandler() + result, bifrostError = requestHandler(currentKey) // For streaming requests that returned success, check if the first chunk // is actually an error (e.g., rate limits sent as SSE events in HTTP 200). @@ -4988,7 +5247,7 @@ func executeRequestWithRetries[T any]( } else { // Populate LLM response attributes for non-streaming responses if resp, ok := any(result).(*schemas.BifrostResponse); ok { - tracer.PopulateLLMResponseAttributes(handle, resp, bifrostError) + tracer.PopulateLLMResponseAttributes(ctx, handle, resp, bifrostError) } // End span with appropriate status @@ -5016,25 +5275,50 @@ func executeRequestWithRetries[T any]( // Check if we should retry based on status code or error message shouldRetry := false - - if bifrostError.Error != nil && (bifrostError.Error.Message == schemas.ErrProviderDoRequest || bifrostError.Error.Message == schemas.ErrProviderNetworkError) { - shouldRetry = true - logger.Debug("detected request HTTP/network error, will retry: %s", bifrostError.Error.Message) - } - - // Retry if status code or error object indicates rate limiting - if (bifrostError.StatusCode != nil && retryableStatusCodes[*bifrostError.StatusCode]) || + isRateLimit := (bifrostError.StatusCode != nil && *bifrostError.StatusCode == 429) || (bifrostError.Error != nil && (IsRateLimitErrorMessage(bifrostError.Error.Message) || (bifrostError.Error.Type != nil && IsRateLimitErrorMessage(*bifrostError.Error.Type)) || - (bifrostError.Error.Code != nil && IsRateLimitErrorMessage(*bifrostError.Error.Code)))) { + (bifrostError.Error.Code != nil && 
IsRateLimitErrorMessage(*bifrostError.Error.Code)))) + + errMessage := GetErrorMessage(bifrostError) + + if bifrostError.Error != nil && + (bifrostError.Error.Message == schemas.ErrProviderDoRequest || + bifrostError.Error.Message == schemas.ErrProviderNetworkError) { + shouldRetry = true + logger.Debug("detected request HTTP/network error, will retry: %s", errMessage) + } else if (bifrostError.StatusCode != nil && retryableStatusCodes[*bifrostError.StatusCode]) || isRateLimit { shouldRetry = true - logger.Debug("detected rate limit error in message, will retry: %s", bifrostError.Error.Message) + logger.Debug("encountered error that should be retried: %s", errMessage) + } + + // Fill FailReason on any failed attempt (retryable or terminal). + // Use the provider error type when present; fall back to "unknown". + if trail, ok := ctx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok && len(trail) > 0 { + reason := "unknown" + if bifrostError.Error != nil && bifrostError.Error.Type != nil && *bifrostError.Error.Type != "" { + reason = *bifrostError.Error.Type + } else if isRateLimit { + reason = "rate_limit_error" + } + trail[len(trail)-1].FailReason = &reason + ctx.SetValue(schemas.BifrostContextKeyAttemptTrail, trail) } if !shouldRetry { break } + + // Mark current key as used so the next selection excludes it (rate-limit only). + // Network errors keep the same key — they are transient server issues, not per-key. + if isRateLimit && keyProvider != nil { + if usedKeyIDs == nil { + usedKeyIDs = make(map[string]bool) + } + usedKeyIDs[currentKey.ID] = true + } + lastWasRateLimit = isRateLimit } // Add retry information to error @@ -5042,6 +5326,13 @@ func executeRequestWithRetries[T any]( logger.Debug("request failed after %d %s", attempts, map[bool]string{true: "attempts", false: "attempt"}[attempts > 1]) } + // On final error, clear selected_key so it only reflects a key that actually served a successful response. 
+ // The attempt trail is the authoritative record of which keys were tried. + if bifrostError != nil && keyProvider != nil { + ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, "") + ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, "") + } + return result, bifrostError } @@ -5102,30 +5393,64 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas req.Context.SetValue(schemas.BifrostContextKeyIsCustomProvider, !IsStandardProvider(baseProvider)) // Determine whether this provider attempt should capture raw payloads. - // logging-only mode (store_raw_request_response=true, send_back_raw_*=false): - // sets BifrostContextKeySendBackRaw* = true so providers capture via the unified - // ShouldSendBackRaw* path, and sets BifrostContextKeyRawRequestResponseForLogging - // so the payload is stripped before the response reaches the client. - // full send-back mode (send_back_raw_request/response=true): - // BifrostContextKeySendBackRaw* are set as before; stripping flag stays false. - // Always set both flags explicitly so stale values from a previous provider - // attempt (e.g. first attempt was logging-only, fallback is full send-back) - // cannot leak into the new attempt on a reused context. - existingSendBackReq, _ := req.Context.Value(schemas.BifrostContextKeySendBackRawRequest).(bool) - existingSendBackResp, _ := req.Context.Value(schemas.BifrostContextKeySendBackRawResponse).(bool) - loggingOnly := config.StoreRawRequestResponse && - !config.SendBackRawRequest && !existingSendBackReq && - !config.SendBackRawResponse && !existingSendBackResp - req.Context.SetValue(schemas.BifrostContextKeyRawRequestResponseForLogging, loggingOnly) - if loggingOnly { - // Enable capture via the standard flags so ShouldSendBackRaw* needs only one check. 
- req.Context.SetValue(schemas.BifrostContextKeySendBackRawRequest, true) - req.Context.SetValue(schemas.BifrostContextKeySendBackRawResponse, true) - } - - key := schemas.Key{} + // + // Effective values are computed by merging provider config with any per-request + // context overrides (BifrostContextKeySendBackRawRequest/Response and + // BifrostContextKeyStoreRawRequestResponse). A context value set to either true + // or false fully overrides the provider config for that flag. + // + // Each flag is independent: + // send_back_raw_request — include raw request bytes in the client response. + // send_back_raw_response — include raw response bytes in the client response. + // store_raw_request_response — persist raw bytes in log records (logging plugin only). + // + // Capture is enabled per-side whenever send-back OR store is requested for that side. + // Strip flags tell the response path to remove that side's bytes before the payload + // reaches the caller (used when store=true but send-back=false for that side). + // + // All internal signals are always written explicitly on every attempt so stale values + // from a previous provider attempt (e.g. different fallback provider config) cannot + // leak into the new attempt on a reused context. The user override keys + // (BifrostContextKeySendBackRaw*, BifrostContextKeyStoreRawRequestResponse) are + // never overwritten — they are read-only from bifrost.go's perspective. + + // Step 1: compute effective value for each flag (provider config ← per-request override). 
+ effectiveSendBackReq := config.SendBackRawRequest + if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawRequest).(bool); ok { + effectiveSendBackReq = override + } + effectiveSendBackResp := config.SendBackRawResponse + if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawResponse).(bool); ok { + effectiveSendBackResp = override + } + effectiveStore := config.StoreRawRequestResponse + if override, ok := req.Context.Value(schemas.BifrostContextKeyStoreRawRequestResponse).(bool); ok { + effectiveStore = override + } + + // Step 2: derive per-side capture and strip flags. + // Capture if we need to send the data back OR store it — independent per side. + captureReq := effectiveSendBackReq || effectiveStore + captureResp := effectiveSendBackResp || effectiveStore + // Strip from client response if we captured for storage but not for send-back. + dropReq := effectiveStore && !effectiveSendBackReq + dropResp := effectiveStore && !effectiveSendBackResp + + // Step 3: write all internal signals explicitly (never touch the user override keys). + req.Context.SetValue(schemas.BifrostContextKeyCaptureRawRequest, captureReq) + req.Context.SetValue(schemas.BifrostContextKeyCaptureRawResponse, captureResp) + req.Context.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, dropReq) + req.Context.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, dropResp) + // Tells the logging plugin whether to persist raw bytes in log records. + req.Context.SetValue(schemas.BifrostContextKeyShouldStoreRawInLogs, effectiveStore) + var keys []schemas.Key - if providerRequiresKey(baseProvider, config.CustomProviderConfig) { + // keyProvider is passed to executeRequestWithRetries to manage key selection and rotation. + // It is nil when no key is required (e.g. providerRequiresKey=false) or for multi-key + // batch/file/container operations that manage their own key lists. 
+ var keyProvider func(usedKeyIDs map[string]bool) (schemas.Key, error) + + if providerRequiresKey(config.CustomProviderConfig) { // ListModels needs all enabled/supported keys so providers can aggregate // and report per-key statuses (KeyStatuses). if req.RequestType == schemas.ListModelsRequest { @@ -5139,9 +5464,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue @@ -5168,59 +5494,103 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue } } else { - // Use the custom provider name for actual key selection, but pass base provider type for key validation - // Start span for key selection - keyTracer := bifrost.getTracer() - keySpanCtx, keyHandle := keyTracer.StartSpan(req.Context, "key.selection", schemas.SpanKindInternal) - keyTracer.SetAttribute(keyHandle, schemas.AttrProviderName, string(provider.GetProviderKey())) - keyTracer.SetAttribute(keyHandle, schemas.AttrRequestModel, model) - - key, err = bifrost.selectKeyFromProviderForModel(req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider) - if err != nil { - keyTracer.SetAttribute(keyHandle, "error", err.Error()) - keyTracer.EndSpan(keyHandle, schemas.SpanStatusError, err.Error()) - bifrost.logger.Debug("error selecting key for model %s: %v", model, err) + // Build the key pool for this request. 
Selection and rotation are deferred to + // executeRequestWithRetries via keyProvider so that each retry attempt can use + // a different key (on rate-limit errors) without re-running the full filtering. + supportedKeys, canRotate, keyPoolErr := bifrost.selectKeyFromProviderForModelWithPool(req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider) + if keyPoolErr != nil { + bifrost.logger.Debug("error building key pool for model %s: %v", model, keyPoolErr) req.Err <- schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ - Message: err.Error(), - Error: err, + Message: keyPoolErr.Error(), + Error: keyPoolErr, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue } - keyTracer.SetAttribute(keyHandle, "key.id", key.ID) - keyTracer.SetAttribute(keyHandle, "key.name", key.Name) - keyTracer.EndSpan(keyHandle, schemas.SpanStatusOk, "") - // Update context with span ID for subsequent operations - req.Context.SetValue(schemas.BifrostContextKeySpanID, keySpanCtx.Value(schemas.BifrostContextKeySpanID)) - req.Context.SetValue(schemas.BifrostContextKeySelectedKeyID, key.ID) - req.Context.SetValue(schemas.BifrostContextKeySelectedKeyName, key.Name) + + if len(supportedKeys) == 0 { + // SkipKeySelection path — keyProvider stays nil, zero Key is used. + } else if !canRotate { + // Fixed key (DirectKey, explicit ID/name, session stickiness): always + // return the same key regardless of usedKeyIDs. + fixedKey := supportedKeys[0] + keyProvider = func(_ map[string]bool) (schemas.Key, error) { + return fixedKey, nil + } + } else { + // Rotating pool: weighted selection with per-cycle exclusion. + // Captures supportedKeys, bifrost.keySelector, provider/model by value. 
+ pool := supportedKeys + provKey := provider.GetProviderKey() + mdl := model + keyProvider = func(usedKeyIDs map[string]bool) (schemas.Key, error) { + available := make([]schemas.Key, 0, len(pool)) + for _, k := range pool { + if !usedKeyIDs[k.ID] { + available = append(available, k) + } + } + if len(available) == 0 { + // All keys exhausted — start a fresh weighted round. + for id := range usedKeyIDs { + delete(usedKeyIDs, id) + } + available = pool + } + return bifrost.keySelector(req.Context, available, provKey, mdl) + } + } } } } + + originalModelRequested := model + // resolvedModel is set inside the handler closures below on every attempt so that each + // key's own alias mapping is applied. postHookRunner captures resolvedModel by reference + // (Go closure semantics) and will therefore always see the value from the last attempt. + var resolvedModel string + // Create plugin pipeline for streaming requests outside retry loop to prevent leaks var postHookRunner schemas.PostHookRunner var pipeline *PluginPipeline if IsStreamRequestType(req.RequestType) { pipeline = bifrost.getPluginPipeline() postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) + // can read requestType/provider/model from the chunk or error. + // resolvedModel is captured by reference and reflects the alias from the last attempt. 
+ if result != nil { + result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } + if err != nil { + err.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load())) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) return nil, bifrostErr + } else if resp != nil { + resp.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) } return resp, nil } @@ -5238,15 +5608,21 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas req.Context.SetValue(schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer) } - // Execute request with retries + // Execute request with retries. Each handler invocation resolves the alias for the key + // selected by keyProvider on that attempt and mutates the worker-local request model. + // resolvedModel (captured by reference in postHookRunner) is updated accordingly. 
if IsStreamRequestType(req.RequestType) { - stream, bifrostError = executeRequestWithRetries(req.Context, config, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - return bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner) - }, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger) + stream, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + resolvedModel = k.Aliases.Resolve(originalModelRequested) + req.SetModel(resolvedModel) + return bifrost.handleProviderStreamRequest(provider, req, k, postHookRunner) + }, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger) } else { - result, bifrostError = executeRequestWithRetries(req.Context, config, func() (*schemas.BifrostResponse, *schemas.BifrostError) { - return bifrost.handleProviderRequest(provider, config, req, key, keys) - }, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger) + result, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { + resolvedModel = k.Aliases.Resolve(originalModelRequested) + req.SetModel(resolvedModel) + return bifrost.handleProviderRequest(provider, config, req, k, keys) + }, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger) } // Release pipeline immediately for non-streaming requests only @@ -5257,14 +5633,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } if bifrostError != nil { - bifrostError.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, - RawRequest: bifrostError.ExtraFields.RawRequest, - RawResponse: bifrostError.ExtraFields.RawResponse, - KeyStatuses: 
bifrostError.ExtraFields.KeyStatuses, - } + bifrostError.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) // Send error with context awareness to prevent deadlock select { @@ -5278,6 +5647,9 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas bifrost.logger.Warn("Timeout while sending error response, client may have disconnected") } } else { + if result != nil { + result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } if IsStreamRequestType(req.RequestType) { // Send stream with context awareness to prevent deadlock select { @@ -5321,12 +5693,34 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config } response.ListModelsResponse = listModelsResponse case schemas.TextCompletionRequest: + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { + chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() + if chatRequest != nil { + chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, chatRequest) + if bifrostError != nil { + return nil, bifrostError + } + response.TextCompletionResponse = chatCompletionResponse.ToBifrostTextCompletionResponse() + break + } + } textCompletionResponse, bifrostError := provider.TextCompletion(req.Context, key, req.BifrostRequest.TextCompletionRequest) if bifrostError != nil { return nil, bifrostError } response.TextCompletionResponse = textCompletionResponse case schemas.ChatCompletionRequest: + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { + responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() + if responsesRequest != nil { + responsesResponse, bifrostError := provider.Responses(req.Context, 
key, responsesRequest) + if bifrostError != nil { + return nil, bifrostError + } + response.ChatResponse = responsesResponse.ToBifrostChatResponse() + break + } + } chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, req.BifrostRequest.ChatRequest) if bifrostError != nil { return nil, bifrostError @@ -5386,6 +5780,7 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config if bifrostError != nil { return nil, bifrostError } + transcriptionResponse.BackfillParams(req.BifrostRequest.TranscriptionRequest) response.TranscriptionResponse = transcriptionResponse case schemas.ImageGenerationRequest: imageResponse, bifrostError := provider.ImageGeneration(req.Context, key, req.BifrostRequest.ImageGenerationRequest) @@ -5579,9 +5974,10 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider.GetProviderKey(), - ModelRequested: model, + RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } } @@ -5592,8 +5988,20 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { switch req.RequestType { case schemas.TextCompletionStreamRequest: + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { + chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() + if chatRequest != nil { + return provider.ChatCompletionStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, 
schemas.ChatCompletionRequest), key, chatRequest) + } + } return provider.TextCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.TextCompletionRequest) case schemas.ChatCompletionStreamRequest: + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { + responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() + if responsesRequest != nil { + return provider.ResponsesStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ResponsesRequest), key, responsesRequest) + } + } return provider.ChatCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.ChatRequest) case schemas.ResponsesStreamRequest: return provider.ResponsesStream(req.Context, postHookRunner, key, req.BifrostRequest.ResponsesRequest) @@ -5615,9 +6023,10 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider.GetProviderKey(), - ModelRequested: model, + RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } } @@ -5639,7 +6048,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ - Message: "MCP is not configured in this Bifrost instance", + Message: "mcp is not configured in this bifrost instance", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: requestType, @@ -5664,6 +6073,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR // Handle short-circuit with response (success case) if shortCircuit.Response != nil { finalMcpResp, bifrostErr := pipeline.RunMCPPostHooks(ctx, shortCircuit.Response, nil, 
preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { return nil, bifrostErr } @@ -5673,6 +6083,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR if shortCircuit.Error != nil { // Capture post-hook results to respect transformations or recovery finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) // Return post-hook error if present (post-hook may have transformed the error) if finalErr != nil { return nil, finalErr @@ -5715,6 +6126,11 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR RequestType: requestType, }, } + // Preserve MCPUserOAuthRequiredError for downstream detection in agent mode + var oauthErr *schemas.MCPUserOAuthRequiredError + if errors.As(err, &oauthErr) { + bifrostErr.ExtraFields.MCPAuthRequired = oauthErr + } } else if result == nil { bifrostErr = &schemas.BifrostError{ IsBifrostError: false, @@ -5732,6 +6148,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR // Run post-hooks finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, mcpResp, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) if finalErr != nil { return nil, finalErr @@ -5767,6 +6184,9 @@ func (bifrost *Bifrost) executeMCPToolWithHooks(ctx *schemas.BifrostContext, req resp, bifrostErr := bifrost.handleMCPToolExecution(ctx, request, requestType) if bifrostErr != nil { + if bifrostErr.ExtraFields.MCPAuthRequired != nil { + return nil, bifrostErr.ExtraFields.MCPAuthRequired + } return nil, fmt.Errorf("%s", GetErrorMessage(bifrostErr)) } return resp, nil @@ -5796,7 +6216,9 @@ func (p *PluginPipeline) RunLLMPreHooks(ctx *schemas.BifrostContext, req *schema } } - req, shortCircuit, err = plugin.PreLLMHook(ctx, req) + pluginCtx := ctx.WithPluginScope(&pluginName) + req, shortCircuit, err = plugin.PreLLMHook(pluginCtx, req) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { 
@@ -5836,8 +6258,10 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche if runFrom > len(p.llmPlugins) { runFrom = len(p.llmPlugins) } - // Detect streaming mode - if StreamStartTime is set, we're in a streaming context - isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil + requestType, _, _, _ := GetResponseFields(resp, bifrostErr) + // Realtime turns carry StreamStartTime for plugin latency/final-chunk context, + // but they are finalized as one completed turn, not chunk-by-chunk stream output. + isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil && requestType != schemas.RealtimeRequest ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() var err error @@ -5847,8 +6271,17 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche p.logger.Debug("running post-hook for plugin %s", pluginName) if isStreaming { // For streaming: accumulate timing, don't create individual spans per chunk + // Lazily create cached scoped contexts on first chunk (reused across all chunks) + if p.streamScopedCtxs == nil { + p.streamScopedCtxs = make(map[string]*schemas.BifrostContext, len(p.llmPlugins)) + for _, pl := range p.llmPlugins { + name := pl.GetName() + p.streamScopedCtxs[name] = ctx.WithPluginScope(&name) + } + } + pluginCtx := p.streamScopedCtxs[pluginName] start := time.Now() - resp, bifrostErr, err = plugin.PostLLMHook(ctx, resp, bifrostErr) + resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr) duration := time.Since(start) p.accumulatePluginTiming(pluginName, duration, err != nil) @@ -5865,7 +6298,9 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) } } - resp, bifrostErr, err = plugin.PostLLMHook(ctx, resp, bifrostErr) + pluginCtx := ctx.WithPluginScope(&pluginName) + resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr) + 
pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { p.tracer.SetAttribute(handle, "error", err.Error()) @@ -5919,7 +6354,9 @@ func (p *PluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schema } } - req, shortCircuit, err = plugin.PreMCPHook(ctx, req) + pluginCtx := ctx.WithPluginScope(&pluginName) + req, shortCircuit, err = plugin.PreMCPHook(pluginCtx, req) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { @@ -5974,7 +6411,9 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s } } - mcpResp, bifrostErr, err = plugin.PostMCPHook(ctx, mcpResp, bifrostErr) + pluginCtx := ctx.WithPluginScope(&pluginName) + mcpResp, bifrostErr, err = plugin.PostMCPHook(pluginCtx, mcpResp, bifrostErr) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { @@ -6000,7 +6439,11 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s return mcpResp, nil } -// resetPluginPipeline resets a PluginPipeline instance for reuse +// resetPluginPipeline resets a PluginPipeline instance for reuse. +// IMPORTANT: drainAndAttachPluginLogs must be called on the root BifrostContext +// BEFORE this method, because it calls ReleasePluginScope on cached scoped contexts +// which nils out their pluginLogs pointer. The drain reads from the shared store +// on the root context, so it must happen while the store is still referenced. 
func (p *PluginPipeline) resetPluginPipeline() { p.executedPreHooks = 0 p.preHookErrors = p.preHookErrors[:0] @@ -6011,6 +6454,25 @@ func (p *PluginPipeline) resetPluginPipeline() { clear(p.postHookTimings) } p.postHookPluginOrder = p.postHookPluginOrder[:0] + // Release cached scoped contexts for streaming + for _, scopedCtx := range p.streamScopedCtxs { + scopedCtx.ReleasePluginScope() + } + p.streamScopedCtxs = nil +} + +// drainAndAttachPluginLogs drains accumulated plugin logs from the BifrostContext +// and attaches them to the trace for later retrieval by observability plugins. +func drainAndAttachPluginLogs(ctx *schemas.BifrostContext) { + tracer, traceID, err := GetTracerFromContext(ctx) + if err != nil || tracer == nil || traceID == "" { + return + } + logs := ctx.DrainPluginLogs() + if len(logs) == 0 { + return + } + tracer.AttachPluginLogs(traceID, logs) } // accumulatePluginTiming accumulates timing for a plugin during streaming @@ -6102,7 +6564,9 @@ func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline { return pipeline } -// releasePluginPipeline returns a PluginPipeline to the pool +// releasePluginPipeline returns a PluginPipeline to the pool. +// Caller must ensure drainAndAttachPluginLogs has already been called on the +// associated BifrostContext before calling this method. 
func (bifrost *Bifrost) releasePluginPipeline(pipeline *PluginPipeline) { pipeline.resetPluginPipeline() bifrost.pluginPipelinePool.Put(pipeline) @@ -6312,17 +6776,21 @@ func (bifrost *Bifrost) getAllSupportedKeys(ctx *schemas.BifrostContext, provide // Filter keys for ListModels - only check if key has a value var supportedKeys []schemas.Key - for _, k := range keys { + for _, key := range keys { // Skip disabled keys (default enabled when nil) - if k.Enabled != nil && !*k.Enabled { + if key.Enabled != nil && !*key.Enabled { continue } - if strings.TrimSpace(k.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { - supportedKeys = append(supportedKeys, k) + if err := validateKey(baseProviderType, &key); err != nil { + bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) + continue + } + if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { + supportedKeys = append(supportedKeys, key) } } - bifrost.logger.Debug("[Bifrost] Provider %s: %d enabled keys found", providerKey, len(supportedKeys)) + bifrost.logger.Debug("[Bifrost] Provider %s: %d valid keys found", providerKey, len(supportedKeys)) if len(supportedKeys) == 0 { return nil, fmt.Errorf("no valid keys found for provider: %v", providerKey) @@ -6365,17 +6833,21 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p continue } + if err := validateKey(baseProviderType, &k); err != nil { + bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", k.Name, k.ID, providerKey, err.Error()) + continue + } + // Model filtering logic: // - If model is nil or empty → include all keys (no model filter) // - If model is specified: // - If model is in key.BlacklistedModels → exclude (wins over Models allow list) - // - If key.Models is empty → include key (supports all non-blacklisted models) + // - If key.Models is ["*"] → 
include key (supports all non-blacklisted models) + // - If key.Models is empty → exclude key (deny-by-default) // - If key.Models is non-empty → only include if model is in list + // Blacklist wins over allowlist if model != nil && *model != "" { - if len(k.BlacklistedModels) > 0 && slices.Contains(k.BlacklistedModels, *model) { - continue - } - if len(k.Models) > 0 && !slices.Contains(k.Models, *model) { + if k.BlacklistedModels.IsBlocked(*model) || !k.Models.IsAllowed(*model) { continue } } @@ -6405,28 +6877,39 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p return filteredKeys, nil } -// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. -// It uses weighted random selection if multiple keys are available. -func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) { - // Check if key has been set in the context explicitly +// selectKeyFromProviderForModelWithPool returns the filtered pool of eligible keys for the given +// provider/model, along with a canRotate flag indicating whether key rotation across retries is +// permitted. Key selection (choosing which key to use) is deferred to executeRequestWithRetries +// via the keyProvider closure built by the caller. 
+// +// canRotate=false is returned for cases where the caller must always use the same key: +// - DirectKey (caller-supplied key bypasses all selection) +// - SkipKeySelection (provider allows keyless requests; empty slice returned) +// - Explicit BifrostContextKeyAPIKeyID / APIKeyName (user pinned a specific key) +// - Session stickiness (key persisted in KV store for the session lifetime) +// - Single-key pool (only one eligible key — rotation is a no-op, KV write skipped) +// +// canRotate=true is returned when there are two or more eligible keys and no pinning +// or stickiness constraint is in effect. +func (bifrost *Bifrost) selectKeyFromProviderForModelWithPool(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) ([]schemas.Key, bool, error) { + // DirectKey: caller supplied a key directly — no pool, no rotation. if ctx != nil { - key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) - if ok { - return key, nil + if key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); ok { + return []schemas.Key{key}, false, nil } } - // Check if key skipping is allowed + // SkipKeySelection: provider allows keyless requests — return empty pool, no rotation. 
if skipKeySelection, ok := ctx.Value(schemas.BifrostContextKeySkipKeySelection).(bool); ok && skipKeySelection && isKeySkippingAllowed(providerKey) { - return schemas.Key{}, nil + return []schemas.Key{}, false, nil } + // Get keys for provider keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey) if err != nil { - return schemas.Key{}, err + return nil, false, err } - // Check if no keys found if len(keys) == 0 { - return schemas.Key{}, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model) + return nil, false, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model) } // For batch API operations, filter keys to only include those with UseForBatchAPI enabled @@ -6438,7 +6921,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex } } if len(batchEnabledKeys) == 0 { - return schemas.Key{}, fmt.Errorf("no config found for batch APIs. Please enable 'Use for Batch APIs' on at least one key for provider: %v", providerKey) + return nil, false, fmt.Errorf("no config found for batch apis; enable 'Use for Batch APIs' on at least one key for provider: %v", providerKey) } keys = batchEnabledKeys } @@ -6452,112 +6935,89 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex skipModelCheck := (model == "" && (isFileRequestType(requestType) || isBatchRequestType(requestType) || isContainerRequestType(requestType) || isModellessVideoRequestType(requestType) || isPassthroughRequestType(requestType))) || requestType == schemas.ListModelsRequest if skipModelCheck { // When skipping model check: just verify keys are enabled and have values - for _, k := range keys { + for _, key := range keys { // Skip disabled keys - if k.Enabled != nil && !*k.Enabled { + if key.Enabled != nil && !*key.Enabled { continue } - if strings.TrimSpace(k.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { - supportedKeys = append(supportedKeys, k) + if err := 
validateKey(baseProviderType, &key); err != nil { + bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) + continue + } + if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { + supportedKeys = append(supportedKeys, key) } } } else { - // When NOT skipping model check: do full model/deployment filtering + // When NOT skipping model check: do full model filtering for _, key := range keys { // Skip disabled keys if key.Enabled != nil && !*key.Enabled { continue } - hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) - var modelSupported bool - if len(key.BlacklistedModels) > 0 && slices.Contains(key.BlacklistedModels, model) { - modelSupported = false - } else { - modelSupported = (len(key.Models) == 0 && hasValue) || (slices.Contains(key.Models, model) && hasValue) + if err := validateKey(baseProviderType, &key); err != nil { + bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) + continue } - // Additional deployment checks for Azure, Bedrock and Vertex - deploymentSupported := true - if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { - // For Azure, check if deployment exists for this model - if len(key.AzureKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.AzureKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil { - // For Bedrock, check if deployment exists for this model - if len(key.BedrockKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.BedrockKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Vertex && key.VertexKeyConfig != nil { - // For Vertex, check if deployment exists for this model - if len(key.VertexKeyConfig.Deployments) > 0 { - _, deploymentSupported = 
key.VertexKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Replicate && key.ReplicateKeyConfig != nil { - // For Replicate, check if deployment exists for this model - if len(key.ReplicateKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.ReplicateKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil { - // For VLLM, check if model name matches the key's configured model + hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) + // ["*"] = allow all models; [] = deny all; specific list = allow only listed + // NOTE: Model filtering uses the original requested model (which may be an alias). + // key.Models and key.BlacklistedModels must therefore be expressed in alias keys. + // The provider-specific identifier is resolved later in the handler closure via key.Aliases.Resolve(model). + modelSupported := hasValue && key.Models.IsAllowed(model) && !key.BlacklistedModels.IsBlocked(model) + if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil { if key.VLLMKeyConfig.ModelName != "" { - deploymentSupported = (key.VLLMKeyConfig.ModelName == model) + modelSupported = modelSupported && (key.VLLMKeyConfig.ModelName == model) } } - - if modelSupported && deploymentSupported { + if modelSupported { supportedKeys = append(supportedKeys, key) } } } if len(supportedKeys) == 0 { - if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock || baseProviderType == schemas.Vertex || baseProviderType == schemas.Replicate || baseProviderType == schemas.VLLM { - return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model) - } - return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model) + return nil, false, fmt.Errorf("no keys found that support model: %s", model) } - // Key ID takes priority over key name when both are present + // Explicit key ID takes priority over key name — 
pin to that key, no rotation. if ctx != nil { if keyID, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyID).(string); ok { if keyID = strings.TrimSpace(keyID); keyID != "" { for _, key := range supportedKeys { if key.ID == keyID { - return key, nil + return []schemas.Key{key}, false, nil } } - return schemas.Key{}, fmt.Errorf("no supported key found with id %q for provider: %v and model: %s", keyID, providerKey, model) + return nil, false, fmt.Errorf("no supported key found with id %q for provider: %v and model: %s", keyID, providerKey, model) } } if keyName, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyName).(string); ok { if keyName = strings.TrimSpace(keyName); keyName != "" { for _, key := range supportedKeys { if key.Name == keyName { - return key, nil + return []schemas.Key{key}, false, nil } } - return schemas.Key{}, fmt.Errorf("no supported key found with name %q for provider: %v and model: %s", keyName, providerKey, model) + return nil, false, fmt.Errorf("no supported key found with name %q for provider: %v and model: %s", keyName, providerKey, model) } } } + // Single key: no rotation possible, skip session stickiness (no KV write needed). if len(supportedKeys) == 1 { - return supportedKeys[0], nil + return []schemas.Key{supportedKeys[0]}, false, nil } - // Session stickiness: on the first request for a session ID, the randomly - // selected key is persisted in the KV store. Subsequent requests reuse it as - // long as the key remains valid. The sticky-key lookup/selection in this block - // occurs before executeRequestWithRetries, so the same sticky key is - // intentionally applied for the entire session including all retry attempts— - // the selected key is persisted in KV and reused across retries rather than - // re-selected on each attempt. + // Session stickiness: on the first request for a session ID, the randomly selected key is + // persisted in the KV store. Subsequent requests reuse it for the session lifetime. 
The sticky + // key is intentionally kept fixed across all retry attempts — return it as a single-element + // pool with canRotate=false so rate-limit retries also stay on the same key. sessionID := "" if ctx != nil { if id, ok := ctx.Value(schemas.BifrostContextKeySessionID).(string); ok && id != "" { sessionID = id } } - fallbackIndex := 0 if ctx != nil { fallbackIndex, _ = ctx.Value(schemas.BifrostContextKeyFallbackIndex).(int) @@ -6571,58 +7031,46 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex ttl = schemas.DefaultSessionStickyTTL } - // Try to retrieve existing cached key if cachedKey, found, stale := getCachedKeyFromStore(bifrost.kvStore, kvKey, supportedKeys); found { - // Refresh TTL so active sessions do not expire. - err := bifrost.kvStore.SetWithTTL(kvKey, cachedKey.ID, ttl) - if err != nil { + if err := bifrost.kvStore.SetWithTTL(kvKey, cachedKey.ID, ttl); err != nil { bifrost.logger.Warn("error setting session cache for provider=%s key_id=%s: %s", providerKey, cachedKey.ID, err.Error()) } - return cachedKey, nil + return []schemas.Key{cachedKey}, false, nil } else if stale { if _, err := bifrost.kvStore.Delete(kvKey); err != nil { bifrost.logger.Warn("error deleting stale session cache for provider=%s: %s", providerKey, err.Error()) } } - // No cached key found (or stale entry deleted), select a new one selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model) if err != nil { - return schemas.Key{}, err + return nil, false, err } - // Atomically set the key only if not already set (first-write-wins) wasSet, err := bifrost.kvStore.SetNXWithTTL(kvKey, selectedKey.ID, ttl) if err != nil { bifrost.logger.Warn("error setting session cache for provider=%s key_id=%s: %s", providerKey, selectedKey.ID, err.Error()) - return selectedKey, nil + return []schemas.Key{selectedKey}, false, nil } - if wasSet { - return selectedKey, nil + return []schemas.Key{selectedKey}, false, nil } - // Another concurrent 
request won the race, re-read the current key + // Another concurrent request won the race — re-read the persisted key. if currentKey, found, stale := getCachedKeyFromStore(bifrost.kvStore, kvKey, supportedKeys); found { - return currentKey, nil + return []schemas.Key{currentKey}, false, nil } else if stale { if _, err := bifrost.kvStore.Delete(kvKey); err != nil { bifrost.logger.Warn("error deleting stale session cache for provider=%s: %s", providerKey, err.Error()) } - return selectedKey, nil + return []schemas.Key{selectedKey}, false, nil } - // Fallback: if we can't read the current key, use what we selected - // (shouldn't happen in normal operation, but defensive) - return selectedKey, nil + return []schemas.Key{selectedKey}, false, nil } - selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model) - if err != nil { - return schemas.Key{}, err - } - - return selectedKey, nil + // Normal case: return the full filtered pool with rotation enabled. + return supportedKeys, true, nil } // getCachedKeyFromStore retrieves a key ID from the KV store and looks it up in supportedKeys. 
@@ -6659,34 +7107,6 @@ func getCachedKeyFromStore(kvStore schemas.KVStore, kvKey string, supportedKeys return schemas.Key{}, false, false } -func WeightedRandomKeySelector(ctx *schemas.BifrostContext, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { - // Use a weighted random selection based on key weights - totalWeight := 0 - for _, key := range keys { - totalWeight += int(key.Weight * 100) // Convert float to int for better performance - } - - // If all keys have zero weight, fall back to uniform random selection - if totalWeight == 0 { - return keys[rand.Intn(len(keys))], nil - } - - // Use global thread-safe random (Go 1.20+) - no allocation, no syscall - randomValue := rand.Intn(totalWeight) - - // Select key based on weight - currentWeight := 0 - for _, key := range keys { - currentWeight += int(key.Weight * 100) - if randomValue < currentWeight { - return key, nil - } - } - - // Fallback to first key if something goes wrong - return keys[0], nil -} - // Shutdown gracefully stops all workers when triggered. // It closes all request channels and waits for workers to exit. 
func (bifrost *Bifrost) Shutdown() { diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 6944ed1d9d..f5c4cbd6c7 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -59,7 +59,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { // Test immediate success t.Run("ImmediateSuccess", func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ return "success", nil } @@ -68,6 +68,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -89,7 +90,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { // Test success after retries t.Run("SuccessAfterRetries", func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ if callCount <= 2 { // First two calls fail with retryable error @@ -103,6 +104,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -130,7 +132,7 @@ func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) { logger := NewDefaultLogger(schemas.LogLevelError) t.Run("ExceedsMaxRetries", func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ // Always fail with retryable error return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false) @@ -140,6 +142,7 @@ func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -196,7 +199,7 @@ func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t 
*testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ return "", tc.error } @@ -205,6 +208,7 @@ func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -272,7 +276,7 @@ func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ return "", tc.error } @@ -281,6 +285,7 @@ func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -511,7 +516,7 @@ func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) { var attemptCounts []int callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ attemptCounts = append(attemptCounts, callCount) @@ -528,6 +533,7 @@ func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -607,8 +613,8 @@ func TestHandleProviderRequest_OCROperationNotAllowed(t *testing.T) { if err.ExtraFields.RequestType != schemas.OCRRequest { t.Fatalf("expected OCR request type, got %q", err.ExtraFields.RequestType) } - if err.ExtraFields.ModelRequested != "custom-mistral/mistral-ocr-latest" { - t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.ModelRequested) + if err.ExtraFields.OriginalModelRequested != "custom-mistral/mistral-ocr-latest" { + t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.OriginalModelRequested) } } @@ -813,15 +819,15 @@ func (m *mockKVStore) Delete(key string) (bool, error) { 
return false, nil } -// Test selectKeyFromProviderForModel with session stickiness +// Test selectKeyFromProviderForModelWithPool with session stickiness func TestSelectKeyFromProviderForModel_SessionStickiness(t *testing.T) { kvStore := newMockKVStore() account := NewMockAccount() account.AddProvider(schemas.OpenAI, 5, 1000) // Use 2 keys so we hit the keySelector path (single key returns early) account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ - {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Weight: 1}, - {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Weight: 1}, + {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1}, + {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1}, }) var keySelectorCalls int @@ -844,13 +850,16 @@ func TestSelectKeyFromProviderForModel_SessionStickiness(t *testing.T) { bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) bfCtx.SetValue(schemas.BifrostContextKeySessionID, "sess-123") - // First call: cache miss, keySelector runs, key stored - key1, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + // First call: cache miss, keySelector runs, key stored; returns single-element pool (canRotate=false) + keys1, canRotate1, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { - t.Fatalf("first selectKeyFromProviderForModel: %v", err) + t.Fatalf("first selectKeyFromProviderForModelWithPool: %v", err) + } + if canRotate1 { + t.Error("first call: canRotate should be false for session-sticky request") } - if key1.ID != "key-a" { - t.Errorf("first call: expected key-a, got %s", key1.ID) + if len(keys1) != 1 || keys1[0].ID != "key-a" { + t.Errorf("first call: expected [key-a], got %v", keys1) } if 
keySelectorCalls != 1 { t.Errorf("first call: expected 1 keySelector call, got %d", keySelectorCalls) @@ -863,26 +872,29 @@ func TestSelectKeyFromProviderForModel_SessionStickiness(t *testing.T) { } // Second call: cache hit, same key returned, keySelector NOT called - key2, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + keys2, canRotate2, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { - t.Fatalf("second selectKeyFromProviderForModel: %v", err) + t.Fatalf("second selectKeyFromProviderForModelWithPool: %v", err) } - if key2.ID != "key-a" { - t.Errorf("second call: expected key-a (sticky), got %s", key2.ID) + if canRotate2 { + t.Error("second call: canRotate should be false for session-sticky request") + } + if len(keys2) != 1 || keys2[0].ID != "key-a" { + t.Errorf("second call: expected [key-a] (sticky), got %v", keys2) } if keySelectorCalls != 1 { t.Errorf("second call: keySelector should not run (cache hit), got %d calls", keySelectorCalls) } } -// Test selectKeyFromProviderForModel - no stickiness when session ID absent +// Test selectKeyFromProviderForModelWithPool - no stickiness when session ID absent func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T) { kvStore := newMockKVStore() account := NewMockAccount() account.AddProvider(schemas.OpenAI, 5, 1000) account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ - {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Weight: 1}, - {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Weight: 1}, + {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1}, + {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1}, }) var keySelectorCalls int @@ -903,19 +915,22 @@ func 
TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T } bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) - // No session ID set + // No session ID set: pool is returned with canRotate=true; keySelector is not called during pool building. for i := 0; i < 2; i++ { - key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + pool, canRotate, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { - t.Fatalf("selectKeyFromProviderForModel call %d: %v", i+1, err) + t.Fatalf("selectKeyFromProviderForModelWithPool call %d: %v", i+1, err) + } + if !canRotate { + t.Fatalf("call %d: canRotate should be true without a session id", i+1) } - if key.ID != "key-a" { - t.Fatalf("call %d: expected key-a, got %s", i+1, key.ID) + if len(pool) == 0 { + t.Fatalf("call %d: expected non-empty pool", i+1) } } - if keySelectorCalls != 2 { - t.Errorf("expected 2 keySelector calls without a session id, got %d", keySelectorCalls) + if keySelectorCalls != 0 { + t.Errorf("expected 0 keySelector calls from pool building (no session id), got %d", keySelectorCalls) } // KVStore should not have a sticky entry for an empty session id if _, err := kvStore.Get(buildSessionKey(schemas.OpenAI, "", "gpt-4")); err == nil { @@ -923,6 +938,82 @@ func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T } } +// TestSelectKeyFromProviderForModel_SessionStickinessNoRotation verifies that when a session ID +// is present, rate-limit retries reuse the sticky key rather than rotating to another key.
+func TestSelectKeyFromProviderForModel_SessionStickinessNoRotation(t *testing.T) { + kvStore := newMockKVStore() + account := NewMockAccount() + account.AddProvider(schemas.OpenAI, 5, 1000) + account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ + {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1}, + {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1}, + }) + + deterministicSelector := func(ctx *schemas.BifrostContext, keys []schemas.Key, _ schemas.ModelProvider, _ string) (schemas.Key, error) { + return keys[0], nil // always picks key-a when pool includes it + } + + ctx := context.Background() + bifrost, err := Init(ctx, schemas.BifrostConfig{ + Account: account, + Logger: NewDefaultLogger(schemas.LogLevelError), + KVStore: kvStore, + KeySelector: deterministicSelector, + }) + if err != nil { + t.Fatalf("Init failed: %v", err) + } + + bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bfCtx.SetValue(schemas.BifrostContextKeySessionID, "sess-sticky") + bfCtx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) + + config := createTestConfig(3, 0, 0) + logger := NewDefaultLogger(schemas.LogLevelError) + + // Build keyProvider the same way requestWorker does. + pool, canRotate, poolErr := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + if poolErr != nil { + t.Fatalf("pool build failed: %v", poolErr) + } + if canRotate { + t.Fatal("expected canRotate=false for session-sticky request") + } + if len(pool) != 1 || pool[0].ID != "key-a" { + t.Fatalf("expected sticky pool=[key-a], got %v", pool) + } + + fixedKey := pool[0] + keyProvider := func(_ map[string]bool) (schemas.Key, error) { return fixedKey, nil } + + // Simulate 3 rate-limit failures then success; all attempts must use key-a. 
+ var usedKeyIDs []string + callCount := 0 + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + usedKeyIDs = append(usedKeyIDs, k.ID) + callCount++ + if callCount <= 3 { + return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false) + } + return "ok", nil + } + + result, retryErr := executeRequestWithRetries(bfCtx, config, handler, keyProvider, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if retryErr != nil { + t.Fatalf("expected success, got error: %v", retryErr) + } + if result != "ok" { + t.Errorf("expected 'ok', got %s", result) + } + for i, id := range usedKeyIDs { + if id != "key-a" { + t.Errorf("attempt %d: expected sticky key-a, got %s (full sequence: %v)", i, id, usedKeyIDs) + } + } +} + func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { account := NewMockAccount() account.AddProvider(schemas.OpenAI, 5, 1000) @@ -941,7 +1032,7 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ {ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}}, }) - _, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + _, _, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err == nil { t.Fatal("expected error when model is only blacklisted") } @@ -958,7 +1049,7 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { BlacklistedModels: []string{"gpt-4"}, }, }) - _, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + _, _, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err == nil { t.Fatal("expected error when model is both allowed and 
blacklisted") } @@ -967,14 +1058,200 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { t.Run("second key used when first blacklists", func(t *testing.T) { account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ {ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}}, - {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1}, + {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1, Models: []string{"*"}}, }) - key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + pool, canRotate, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { t.Fatalf("unexpected error: %v", err) } - if key.ID != "k2" { - t.Fatalf("expected k2, got %s", key.ID) + // After filtering, only k2 remains — single key returns canRotate=false. + if canRotate { + t.Fatal("expected canRotate=false for single-key pool after filtering") + } + if len(pool) != 1 || pool[0].ID != "k2" { + t.Fatalf("expected pool=[k2], got %v", pool) + } + }) +} + +// Test key rotation in executeRequestWithRetries on rate-limit errors +func TestExecuteRequestWithRetries_KeyRotation(t *testing.T) { + config := createTestConfig(3, 0, 0) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) + logger := NewDefaultLogger(schemas.LogLevelError) + + keys := []schemas.Key{ + {ID: "k1", Name: "K1"}, + {ID: "k2", Name: "K2"}, + {ID: "k3", Name: "K3"}, + } + + t.Run("RotatesKeyOnRateLimitRetry", func(t *testing.T) { + var selectedKeyIDs []string + keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) { + for _, k := range keys { + if !usedKeyIDs[k.ID] { + return k, nil + } + } + // Fresh round + for id := range usedKeyIDs { + delete(usedKeyIDs, id) + } + 
return keys[0], nil + } + + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + selectedKeyIDs = append(selectedKeyIDs, k.ID) + // First two calls rate-limit, third succeeds + if len(selectedKeyIDs) <= 2 { + return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false) + } + return "success", nil + } + + result, err := executeRequestWithRetries(ctx, config, handler, keyProvider, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if result != "success" { + t.Errorf("expected 'success', got %s", result) + } + if len(selectedKeyIDs) != 3 { + t.Fatalf("expected 3 attempts, got %d", len(selectedKeyIDs)) + } + // Each attempt should use a different key + seen := map[string]struct{}{} + for _, id := range selectedKeyIDs { + seen[id] = struct{}{} + } + if len(seen) != len(selectedKeyIDs) { + t.Errorf("expected distinct keys per rate-limit retry, got %v", selectedKeyIDs) + } + }) + + t.Run("SameKeyOnNetworkError", func(t *testing.T) { + var selectedKeyIDs []string + keyProviderCalls := 0 + keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) { + keyProviderCalls++ + for _, k := range keys { + if !usedKeyIDs[k.ID] { + return k, nil + } + } + for id := range usedKeyIDs { + delete(usedKeyIDs, id) + } + return keys[0], nil + } + + callCount := 0 + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + selectedKeyIDs = append(selectedKeyIDs, k.ID) + callCount++ + if callCount <= 2 { + return "", createBifrostError(schemas.ErrProviderDoRequest, nil, nil, false) + } + return "success", nil + } + + result, err := executeRequestWithRetries(ctx, config, handler, keyProvider, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if result != "success" { + t.Errorf("expected 'success', got %s", result) + } + if len(selectedKeyIDs) != 3 { + 
t.Fatalf("expected 3 attempts, got %d", len(selectedKeyIDs)) + } + if keyProviderCalls != 1 { + t.Fatalf("expected keyProvider to be called once for network retries, got %d", keyProviderCalls) + } + // All attempts should use the same key (network error = same key) + for i := 1; i < len(selectedKeyIDs); i++ { + if selectedKeyIDs[i] != selectedKeyIDs[0] { + t.Errorf("expected same key for all network-error retries, got %v", selectedKeyIDs) + } + } + }) + + t.Run("CyclesFreshRoundWhenPoolExhausted", func(t *testing.T) { + var selectedKeyIDs []string + // 3 keys, 6 retries — should cycle through all 3 keys twice + config6 := createTestConfig(5, 0, 0) // 5 retries = 6 total attempts + keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) { + available := make([]schemas.Key, 0) + for _, k := range keys { + if !usedKeyIDs[k.ID] { + available = append(available, k) + } + } + if len(available) == 0 { + for id := range usedKeyIDs { + delete(usedKeyIDs, id) + } + available = keys + } + return available[0], nil + } + + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + selectedKeyIDs = append(selectedKeyIDs, k.ID) + return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false) + } + + executeRequestWithRetries(ctx, config6, handler, keyProvider, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if len(selectedKeyIDs) != 6 { + t.Fatalf("expected 6 attempts (1 initial + 5 retries), got %d", len(selectedKeyIDs)) + } + // First cycle: k1, k2, k3; second cycle: k1, k2, k3 + expected := []string{"k1", "k2", "k3", "k1", "k2", "k3"} + for i, id := range selectedKeyIDs { + if id != expected[i] { + t.Errorf("attempt %d: expected key %s, got %s (full sequence: %v)", i, expected[i], id, selectedKeyIDs) + } + } + }) + + t.Run("NilKeyProviderUsesZeroKey", func(t *testing.T) { + cleanCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + cleanCtx.SetValue(schemas.BifrostContextKeyTracer, 
&schemas.NoOpTracer{}) + + var receivedKey schemas.Key + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + receivedKey = k + return "ok", nil + } + + result, err := executeRequestWithRetries(cleanCtx, config, handler, nil, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "ok" { + t.Errorf("expected 'ok', got %s", result) + } + if receivedKey.ID != "" { + t.Errorf("expected zero Key when keyProvider is nil, got ID=%s", receivedKey.ID) + } + if trail, ok := cleanCtx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok && len(trail) > 0 { + t.Fatalf("expected no attempt trail for nil keyProvider, got %v", trail) + } + if selectedID, _ := cleanCtx.Value(schemas.BifrostContextKeySelectedKeyID).(string); selectedID != "" { + t.Fatalf("expected empty selected key id, got %q", selectedID) + } + if selectedName, _ := cleanCtx.Value(schemas.BifrostContextKeySelectedKeyName).(string); selectedName != "" { + t.Fatalf("expected empty selected key name, got %q", selectedName) } }) } diff --git a/core/go.mod b/core/go.mod index b85c403ec6..013296f021 100644 --- a/core/go.mod +++ b/core/go.mod @@ -7,12 +7,12 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 github.com/andybalholm/brotli v1.2.0 - github.com/aws/aws-sdk-go-v2 v1.41.3 + github.com/aws/aws-sdk-go-v2 v1.41.5 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 github.com/aws/aws-sdk-go-v2/config v1.32.11 - github.com/aws/aws-sdk-go-v2/credentials v1.19.11 - github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 - github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 + github.com/aws/aws-sdk-go-v2/credentials v1.19.14 + github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 + github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 github.com/aws/smithy-go v1.24.2 github.com/bytedance/sonic v1.15.0 
github.com/fasthttp/websocket v1.5.12 @@ -35,18 +35,18 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 // indirect - github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect 
github.com/buger/jsonparser v1.1.2 // indirect github.com/bytedance/gopkg v0.1.3 // indirect diff --git a/core/go.sum b/core/go.sum index a01765f093..685035b381 100644 --- a/core/go.sum +++ b/core/go.sum @@ -16,42 +16,38 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgv github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= -github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= -github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= +github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= -github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= -github.com/aws/aws-sdk-go-v2/credentials v1.19.11/go.mod h1:30yY2zqkMPdrvxBqzI9xQCM+WrlrZKSOpSJEsylVU+8= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 h1:INUvJxmhdEbVulJYHI061k4TVuS3jzzthNvjqvVvTKM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19/go.mod h1:FpZN2QISLdEBWkayloda+sZjVJL+e9Gl0k1SyTgcswU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19/go.mod 
h1:dMf8A5oAqr9/oxOfLkC/c2LU/uMcALP0Rgn2BD5LWn0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19/go.mod h1:+GWrYoaAsV7/4pNHpwh1kiNLXkKaSoppxQq9lbH8Ejw= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 h1:CjMzUs78RDDv4ROu3JnJn/Ig1r6ZD7/T2DXLLRpejic= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16/go.mod h1:uVW4OLBqbJXSHJYA9svT9BluSvvwbzLQ2Crf6UPzR3c= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 h1:XAq62tBTJP/85lFD5oqOOe7YYgWxY9LvWq8plyDvDVg= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 h1:DIBqIrJ7hv+e4CmIk2z3pyKT+3B6qVMgRsawHiR3qso= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7/go.mod 
h1:vLm00xmBke75UmpNvOcZQ/Q30ZFjbczeLFqGx5urmGo= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 h1:X1Tow7suZk9UCJHE1Iw9GMZJJl0dAnKXXP1NaSDHwmw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19/go.mod h1:/rARO8psX+4sfjUQXp5LLifjUt8DuATZ31WptNJTyQA= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 h1:NSbvS17MlI2lurYgXnCOLvCFX38sBW4eiVER7+kkgsU= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16/go.mod h1:SwT8Tmqd4sA6G1qaGdzWCJN99bUmPGHfRwwq3G5Qb+A= -github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 h1:SWTxh/EcUCDVqi/0s26V6pVUq0BBG7kx0tDTmF/hCgA= -github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0/go.mod h1:79S2BdqCJpScXZA2y+cpZuocWsjGjJINyXnOsf5DTz8= -github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 h1:Y2cAXlClHsXkkOvWZFXATr34b0hxxloeQu/pAZz2row= -github.com/aws/aws-sdk-go-v2/service/signin v1.0.7/go.mod h1:idzZ7gmDeqeNrSPkdbtMp9qWMgcBwykA7P7Rzh5DXVU= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 h1:iSsvB9EtQ09YrsmIc44Heqlx5ByGErqhPK1ZQLppias= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.12/go.mod h1:fEWYKTRGoZNl8tZ77i61/ccwOMJdGxwOhWCkp6TXAr0= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 h1:EnUdUqRP1CNzt2DkV67tJx6XDN4xlfBFm+bzeNOQVb0= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16/go.mod h1:Jic/xv0Rq/pFNCh3WwpH4BEqdbSAl+IyHro8LbibHD8= -github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 h1:XQTQTF75vnug2TXS8m7CVJfC2nniYPZnO1D4Np761Oo= -github.com/aws/aws-sdk-go-v2/service/sts v1.41.8/go.mod h1:Xgx+PR1NUOjNmQY+tRMnouRp83JRM8pRMw/vCaVhPkI= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 h1:rWyie/PxDRIdhNf4DzRk0lvjVOqFJuNnO8WwaIRVxzQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 
h1:JRaIgADQS/U6uXDqlPiefP32yXTda7Kqfx+LgspooZM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 h1:ZlvrNcHSFFWURB8avufQq9gFsheUgjVD9536obIknfM= +github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 h1:HwxWTbTrIHm5qY+CAEur0s/figc3qwvLWsNkF4RPToo= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= diff --git a/core/internal/llmtests/account.go b/core/internal/llmtests/account.go index 0632f16b6b..e4044e0e30 100644 --- a/core/internal/llmtests/account.go +++ b/core/internal/llmtests/account.go @@ -15,80 +15,87 @@ import ( const Concurrency = 4 +// Replicate test key names (see replicateProviderTestKeys). 
+// ListModels uses the deployments API via the list-models key; all other operations use the inference key. +const ( + ReplicateKeyNameListModels = "replicate-list-models-deployments" + ReplicateKeyNameInference = "replicate-inference" +) + // ProviderOpenAICustom represents the custom OpenAI provider for testing const ProviderOpenAICustom = schemas.ModelProvider("openai-custom") // TestScenarios defines the comprehensive test scenarios type TestScenarios struct { - TextCompletion bool - TextCompletionStream bool - SimpleChat bool - CompletionStream bool - MultiTurnConversation bool - ToolCalls bool - ToolCallsStreaming bool // Streaming tool calls functionality + TextCompletion bool + TextCompletionStream bool + SimpleChat bool + CompletionStream bool + MultiTurnConversation bool + ToolCalls bool + ToolCallsStreaming bool // Streaming tool calls functionality MultipleToolCalls bool MultipleToolCallsStreaming bool // Streaming multiple tool calls (some providers only return 1 tool call in streaming) End2EndToolCalling bool - AutomaticFunctionCall bool - ImageURL bool - ImageBase64 bool - MultipleImages bool - FileBase64 bool - FileURL bool - CompleteEnd2End bool - SpeechSynthesis bool // Text-to-speech functionality - SpeechSynthesisStream bool // Streaming text-to-speech functionality - Transcription bool // Speech-to-text functionality - TranscriptionStream bool // Streaming speech-to-text functionality - Embedding bool // Embedding functionality - Reasoning bool // Reasoning/thinking functionality via Responses API - PromptCaching bool // Prompt caching functionality - ListModels bool // List available models functionality - ImageGeneration bool // Image generation functionality - ImageGenerationStream bool // Streaming image generation functionality - ImageEdit bool // Image edit functionality - ImageEditStream bool // Streaming image edit functionality - ImageVariation bool // Image variation functionality - ImageVariationStream bool // Streaming image 
variation functionality (if supported) - VideoGeneration bool // Video generation functionality - VideoRetrieve bool // Video retrieve functionality - VideoRemix bool // Video remix functionality (OpenAI only) - VideoDownload bool // Video download functionality - VideoList bool // Video list functionality - VideoDelete bool // Video delete functionality - BatchCreate bool // Batch API create functionality - BatchList bool // Batch API list functionality - BatchRetrieve bool // Batch API retrieve functionality - BatchCancel bool // Batch API cancel functionality - BatchResults bool // Batch API results functionality - FileUpload bool // File API upload functionality - FileList bool // File API list functionality - FileRetrieve bool // File API retrieve functionality - FileDelete bool // File API delete functionality - FileContent bool // File API content download functionality - FileBatchInput bool // Whether batch create supports file-based input (InputFileID) - CountTokens bool // Count tokens functionality - ChatAudio bool // Chat completion with audio input/output functionality - StructuredOutputs bool // Structured outputs (JSON schema) functionality - WebSearchTool bool // Web search tool functionality - ContainerCreate bool // Container API create functionality - ContainerList bool // Container API list functionality - ContainerRetrieve bool // Container API retrieve functionality - ContainerDelete bool // Container API delete functionality - ContainerFileCreate bool // Container File API create functionality - ContainerFileList bool // Container File API list functionality - ContainerFileRetrieve bool // Container File API retrieve functionality - ContainerFileContent bool // Container File API content functionality - ContainerFileDelete bool // Container File API delete functionality - PassThroughExtraParams bool // Pass through extra params functionality - Rerank bool // Rerank functionality - PassthroughAPI bool // Raw HTTP passthrough API (Passthrough + 
PassthroughStream) - WebSocketResponses bool // WebSocket Responses API mode - Realtime bool // Realtime API (bidirectional audio/text) - Compaction bool // Server-side compaction (context management) - InterleavedThinking bool // Interleaved thinking between tool calls (beta) - FastMode bool // Fast mode for Opus 4.6 (beta: research preview) + AutomaticFunctionCall bool + ImageURL bool + ImageBase64 bool + MultipleImages bool + FileBase64 bool + FileURL bool + CompleteEnd2End bool + SpeechSynthesis bool // Text-to-speech functionality + SpeechSynthesisStream bool // Streaming text-to-speech functionality + Transcription bool // Speech-to-text functionality + TranscriptionStream bool // Streaming speech-to-text functionality + Embedding bool // Embedding functionality + Reasoning bool // Reasoning/thinking functionality via Responses API + PromptCaching bool // Prompt caching functionality + ListModels bool // List available models functionality + ImageGeneration bool // Image generation functionality + ImageGenerationStream bool // Streaming image generation functionality + ImageEdit bool // Image edit functionality + ImageEditStream bool // Streaming image edit functionality + ImageVariation bool // Image variation functionality + ImageVariationStream bool // Streaming image variation functionality (if supported) + VideoGeneration bool // Video generation functionality + VideoRetrieve bool // Video retrieve functionality + VideoRemix bool // Video remix functionality (OpenAI only) + VideoDownload bool // Video download functionality + VideoList bool // Video list functionality + VideoDelete bool // Video delete functionality + BatchCreate bool // Batch API create functionality + BatchList bool // Batch API list functionality + BatchRetrieve bool // Batch API retrieve functionality + BatchCancel bool // Batch API cancel functionality + BatchResults bool // Batch API results functionality + FileUpload bool // File API upload functionality + FileList bool // File 
API list functionality + FileRetrieve bool // File API retrieve functionality + FileDelete bool // File API delete functionality + FileContent bool // File API content download functionality + FileBatchInput bool // Whether batch create supports file-based input (InputFileID) + CountTokens bool // Count tokens functionality + ChatAudio bool // Chat completion with audio input/output functionality + StructuredOutputs bool // Structured outputs (JSON schema) functionality + WebSearchTool bool // Web search tool functionality + ContainerCreate bool // Container API create functionality + ContainerList bool // Container API list functionality + ContainerRetrieve bool // Container API retrieve functionality + ContainerDelete bool // Container API delete functionality + ContainerFileCreate bool // Container File API create functionality + ContainerFileList bool // Container File API list functionality + ContainerFileRetrieve bool // Container File API retrieve functionality + ContainerFileContent bool // Container File API content functionality + ContainerFileDelete bool // Container File API delete functionality + PassThroughExtraParams bool // Pass through extra params functionality + Rerank bool // Rerank functionality + PassthroughAPI bool // Raw HTTP passthrough API (Passthrough + PassthroughStream) + WebSocketResponses bool // WebSocket Responses API mode + Realtime bool // Realtime API (bidirectional audio/text) + Compaction bool // Server-side compaction (context management) + InterleavedThinking bool // Interleaved thinking between tool calls (beta) + FastMode bool // Fast mode for Opus 4.6 (beta: research preview) EagerInputStreaming bool // Fine-grained tool input streaming (Anthropic fine-grained-tool-streaming-2025-05-14) ServerToolsViaOpenAIEndpoint bool // Anthropic server-tool shapes in tools[] via /v1/chat/completions (web_search / web_fetch / code_execution) } @@ -175,6 +182,34 @@ func (account *ComprehensiveTestAccount) GetConfiguredProviders() 
([]schemas.Mod }, nil } +// replicateProviderTestKeys returns the two Replicate keys used by comprehensive tests: +func replicateProviderTestKeys() []schemas.Key { + return []schemas.Key{ + { + Name: ReplicateKeyNameListModels, + Value: *schemas.NewEnvVar("env.REPLICATE_API_KEY"), + Models: []string{"*"}, + Weight: 0, + UseForBatchAPI: bifrost.Ptr(false), + ReplicateKeyConfig: &schemas.ReplicateKeyConfig{UseDeploymentsEndpoint: true}, + }, + { + Name: ReplicateKeyNameInference, + Value: *schemas.NewEnvVar("env.REPLICATE_API_KEY"), + Models: []string{"*"}, + Weight: 1.0, + UseForBatchAPI: bifrost.Ptr(true), + ReplicateKeyConfig: nil, + }, + } +} + +// ReplicateDirectKeyForListModels returns the key used for Replicate ListModels (deployments endpoint). +// List-models tests set it on the context as schemas.BifrostContextKeyDirectKey so Bifrost passes only this key. +func ReplicateDirectKeyForListModels() schemas.Key { + return replicateProviderTestKeys()[0] +} + // GetKeysForProvider returns the API keys and associated models for a given provider. 
func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { switch providerKey { @@ -182,7 +217,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -191,7 +226,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), // Use GROQ API key for OpenAI-compatible endpoint - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -200,7 +235,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.ANTHROPIC_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -208,38 +243,38 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, case schemas.Bedrock: return []schemas.Key{ { - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, + Aliases: map[string]string{ + "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", + "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", + }, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"), SecretKey: *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"), SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"), Region: schemas.NewEnvVar(getEnvWithDefault("AWS_REGION", "us-east-1")), ARN: schemas.NewEnvVar("env.AWS_ARN"), - Deployments: map[string]string{ - "claude-3.7-sonnet": 
"us.anthropic.claude-3-7-sonnet-20250219-v1:0", - "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", - "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", - "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", - }, }, }, { - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, + Aliases: map[string]string{ + "claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", + "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", + }, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"), SecretKey: *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"), SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"), Region: schemas.NewEnvVar(getEnvWithDefault("AWS_REGION", "us-east-1")), ARN: schemas.NewEnvVar("env.AWS_BEDROCK_ARN"), - Deployments: map[string]string{ - "claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", - "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", - "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", - "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", - }, }, UseForBatchAPI: bifrost.Ptr(true), }, @@ -258,7 +293,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.COHERE_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -267,20 +302,20 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.AZURE_API_KEY"), - Models: []string{}, + Models: 
[]string{"*"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "gpt-4o", + "gpt-4o-backup": "gpt-4o-3", + "claude-opus-4-5": "claude-opus-4-5", + "o1": "o1", + "gpt-image-1": "gpt-image-1", + "text-embedding-ada-002": "text-embedding-ada-002", + "sora-2": "sora-2", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("env.AZURE_ENDPOINT"), - APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o", - "gpt-4o-backup": "gpt-4o-3", - "claude-opus-4-5": "claude-opus-4-5", - "o1": "o1", - "gpt-image-1": "gpt-image-1", - "text-embedding-ada-002": "text-embedding-ada-002", - "sora-2": "sora-2", - }, + Endpoint: *schemas.NewEnvVar("env.AZURE_ENDPOINT"), + APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), ClientID: schemas.NewEnvVar("env.AZURE_CLIENT_ID"), ClientSecret: schemas.NewEnvVar("env.AZURE_CLIENT_SECRET"), TenantID: schemas.NewEnvVar("env.AZURE_TENANT_ID"), @@ -289,16 +324,17 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, }, { Value: *schemas.NewEnvVar("env.AZURE_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "whisper": "whisper", + "whisper-1": "whisper", + "gpt-4o-mini-tts": "gpt-4o-mini-tts", + "gpt-4o-mini-audio-preview": "gpt-4o-mini-audio-preview", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("env.AZURE_ENDPOINT"), APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), - Deployments: map[string]string{ - "whisper": "whisper", - "gpt-4o-mini-tts": "gpt-4o-mini-tts", - "gpt-4o-mini-audio-preview": "gpt-4o-mini-audio-preview", - }, }, }, }, nil @@ -308,7 +344,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.VERTEX_API_KEY"), - Models: []string{"text-multilingual-embedding-002", "google/gemini-2.0-flash-001", "gemini-2.5-flash-image", 
"imagen-4.0-generate-001", "imagen-3.0-capability-001", "semantic-ranker-default@latest", "semantic-ranker-default-004"}, + Models: []string{"text-multilingual-embedding-002", "gemini-2.5-pro", "gemini-2.5-flash-image", "imagen-4.0-generate-001", "imagen-3.0-capability-001", "semantic-ranker-default@latest", "semantic-ranker-default-004"}, Weight: 1.0, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), @@ -332,15 +368,15 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, Value: *schemas.NewEnvVar("env.VERTEX_API_KEY"), Models: []string{"claude-sonnet-4-5", "claude-4.5-haiku", "claude-opus-4-5"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-4.5-haiku": "claude-haiku-4-5@20251001", + "claude-opus-4-5": "claude-opus-4-5", + }, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), Region: *schemas.NewEnvVar(getEnvWithDefault("VERTEX_REGION_ANTHROPIC", "us-east5")), AuthCredentials: *schemas.NewEnvVar("env.VERTEX_CREDENTIALS"), - Deployments: map[string]string{ - "claude-sonnet-4-5": "claude-sonnet-4-5", - "claude-4.5-haiku": "claude-haiku-4-5@20251001", - "claude-opus-4-5": "claude-opus-4-5", - }, }, UseForBatchAPI: bifrost.Ptr(true), }, @@ -349,7 +385,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.MISTRAL_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -358,7 +394,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.GROQ_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -367,7 +403,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { 
Value: *schemas.NewEnvVar("env.PARASAIL_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -376,7 +412,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.ELEVENLABS_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -385,7 +421,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.PERPLEXITY_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -394,7 +430,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.CEREBRAS_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -403,7 +439,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.GEMINI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -412,7 +448,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENROUTER_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -421,7 +457,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.HUGGING_FACE_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -430,7 +466,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.NEBIUS_API_KEY"), - Models: []string{}, + 
Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -439,25 +475,18 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.XAI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, }, nil case schemas.Replicate: - return []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.REPLICATE_API_KEY"), - Models: []string{}, - Weight: 1.0, - UseForBatchAPI: bifrost.Ptr(true), - }, - }, nil + return replicateProviderTestKeys(), nil case schemas.Runway: return []schemas.Key{ { Value: *schemas.NewEnvVar("env.RUNWAY_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -828,53 +857,54 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageVariationModel: "dall-e-2", ChatAudioModel: "gpt-4o-mini-audio-preview", Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - TextCompletionStream: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + TextCompletionStream: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: true, // OpenAI supports TTS - SpeechSynthesisStream: true, // OpenAI supports streaming TTS - Transcription: true, // OpenAI supports STT with Whisper - TranscriptionStream: true, // OpenAI supports streaming STT - ImageGeneration: true, // OpenAI supports image generation with DALL-E - ImageGenerationStream: true, // OpenAI supports streaming image generation - ImageEdit: true, // OpenAI supports image editing 
- ImageEditStream: true, // OpenAI supports streaming image editing - ImageVariation: true, // OpenAI supports image variation - ImageVariationStream: false, // OpenAI does not support streaming image variation - Embedding: true, - Reasoning: true, // OpenAI supports reasoning via o1 models - ListModels: true, - BatchCreate: true, // OpenAI supports batch API - BatchList: true, // OpenAI supports batch API - BatchRetrieve: true, // OpenAI supports batch API - BatchCancel: true, // OpenAI supports batch API - BatchResults: true, // OpenAI supports batch API - FileUpload: true, // OpenAI supports file API - FileList: true, // OpenAI supports file API - FileRetrieve: true, // OpenAI supports file API - FileDelete: true, // OpenAI supports file API - FileContent: true, // OpenAI supports file API - ChatAudio: true, // OpenAI supports chat audio - ContainerCreate: true, // OpenAI supports container API - ContainerList: true, // OpenAI supports container API - ContainerRetrieve: true, // OpenAI supports container API - ContainerDelete: true, // OpenAI supports container API - ContainerFileCreate: true, // OpenAI supports container file API - ContainerFileList: true, // OpenAI supports container file API - ContainerFileRetrieve: true, // OpenAI supports container file API - ContainerFileContent: true, // OpenAI supports container file API - ContainerFileDelete: true, // OpenAI supports container file API + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: true, // OpenAI supports TTS + SpeechSynthesisStream: true, // OpenAI supports streaming TTS + Transcription: true, // OpenAI supports STT with Whisper + TranscriptionStream: true, // OpenAI supports streaming STT + ImageGeneration: true, // OpenAI supports image generation with DALL-E + ImageGenerationStream: true, // OpenAI supports streaming image generation + ImageEdit: true, // OpenAI supports image 
editing + ImageEditStream: true, // OpenAI supports streaming image editing + ImageVariation: true, // OpenAI supports image variation + ImageVariationStream: false, // OpenAI does not support streaming image variation + Embedding: true, + Reasoning: true, // OpenAI supports reasoning via o1 models + ListModels: true, + BatchCreate: true, // OpenAI supports batch API + BatchList: true, // OpenAI supports batch API + BatchRetrieve: true, // OpenAI supports batch API + BatchCancel: true, // OpenAI supports batch API + BatchResults: true, // OpenAI supports batch API + FileUpload: true, // OpenAI supports file API + FileList: true, // OpenAI supports file API + FileRetrieve: true, // OpenAI supports file API + FileDelete: true, // OpenAI supports file API + FileContent: true, // OpenAI supports file API + ChatAudio: true, // OpenAI supports chat audio + ContainerCreate: true, // OpenAI supports container API + ContainerList: true, // OpenAI supports container API + ContainerRetrieve: true, // OpenAI supports container API + ContainerDelete: true, // OpenAI supports container API + ContainerFileCreate: true, // OpenAI supports container file API + ContainerFileList: true, // OpenAI supports container file API + ContainerFileRetrieve: true, // OpenAI supports container file API + ContainerFileContent: true, // OpenAI supports container file API + ContainerFileDelete: true, // OpenAI supports container file API }, Fallbacks: []schemas.Fallback{ {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, @@ -885,37 +915,38 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "claude-3-7-sonnet-20250219", TextModel: "", // Anthropic doesn't support text completion Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + 
ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - PromptCaching: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Anthropic does not support image editing - ImageEditStream: false, // Anthropic does not support streaming image editing - ImageVariation: false, // Anthropic does not support image variation - ImageVariationStream: false, // Anthropic does not support streaming image variation - ListModels: true, - BatchCreate: true, // Anthropic supports batch API - BatchList: true, // Anthropic supports batch API - BatchRetrieve: true, // Anthropic supports batch API - BatchCancel: true, // Anthropic supports batch API - BatchResults: true, // Anthropic supports batch API + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + PromptCaching: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Anthropic does not support image editing + ImageEditStream: false, // Anthropic does not support streaming image editing + ImageVariation: false, // Anthropic does not support image variation + ImageVariationStream: false, // Anthropic does not support streaming image variation + ListModels: true, + BatchCreate: true, // Anthropic supports batch API + BatchList: true, // Anthropic supports batch API + BatchRetrieve: true, // 
Anthropic supports batch API + BatchCancel: true, // Anthropic supports batch API + BatchResults: true, // Anthropic supports batch API }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -928,42 +959,43 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageEditModel: "amazon.titan-image-generator-v1", ImageVariationModel: "amazon.titan-image-generator-v1", Scenarios: TestScenarios{ - TextCompletion: false, // Not supported for Claude - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported for Claude + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - PromptCaching: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: true, // Bedrock supports image editing - ImageEditStream: false, // Bedrock does not support streaming image editing - ImageVariation: true, // Bedrock supports image variation - ImageVariationStream: false, // Bedrock does not support streaming image variation - ListModels: true, - BatchCreate: true, // Bedrock supports batch via Model Invocation Jobs (requires S3 config) - BatchList: true, // Bedrock supports listing batch jobs - BatchRetrieve: true, // Bedrock supports retrieving batch jobs - BatchCancel: true, // Bedrock supports stopping batch jobs - BatchResults: true, // Bedrock batch results via S3 - FileUpload: true, // Bedrock file upload to S3 (requires S3 config) - FileList: true, // Bedrock file list from S3 (requires S3 
config) - FileRetrieve: true, // Bedrock file retrieve from S3 (requires S3 config) - FileDelete: true, // Bedrock file delete from S3 (requires S3 config) - FileContent: true, // Bedrock file content from S3 (requires S3 config) + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + PromptCaching: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: true, // Bedrock supports image editing + ImageEditStream: false, // Bedrock does not support streaming image editing + ImageVariation: true, // Bedrock supports image variation + ImageVariationStream: false, // Bedrock does not support streaming image variation + ListModels: true, + BatchCreate: true, // Bedrock supports batch via Model Invocation Jobs (requires S3 config) + BatchList: true, // Bedrock supports listing batch jobs + BatchRetrieve: true, // Bedrock supports retrieving batch jobs + BatchCancel: true, // Bedrock supports stopping batch jobs + BatchResults: true, // Bedrock batch results via S3 + FileUpload: true, // Bedrock file upload to S3 (requires S3 config) + FileList: true, // Bedrock file list from S3 (requires S3 config) + FileRetrieve: true, // Bedrock file retrieve from S3 (requires S3 config) + FileDelete: true, // Bedrock file delete from S3 (requires S3 config) + FileContent: true, // Bedrock file content from S3 (requires S3 config) }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -974,31 +1006,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "command-a-03-2025", TextModel: "", // Cohere focuses on chat Scenarios: TestScenarios{ - TextCompletion: false, // Not typical for Cohere - SimpleChat: true, - CompletionStream: true, 
- MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical for Cohere + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: false, // May not support automatic - ImageURL: false, // Check if supported - ImageBase64: false, // Check if supported - MultipleImages: false, // Check if supported - CompleteEnd2End: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Cohere does not support image editing - ImageEditStream: false, // Cohere does not support streaming image editing - ImageVariation: false, // Cohere does not support image variation - ImageVariationStream: false, // Cohere does not support streaming image variation - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: true, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: false, // May not support automatic + ImageURL: false, // Check if supported + ImageBase64: false, // Check if supported + MultipleImages: false, // Check if supported + CompleteEnd2End: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Cohere does not support image editing + ImageEditStream: false, // Cohere does not support streaming image editing + ImageVariation: false, // Cohere does not support image variation + ImageVariationStream: false, // Cohere does not support streaming image variation + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ 
-1014,42 +1047,43 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageGenerationModel: "gpt-image-1", ImageEditModel: "dall-e-2", Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: true, // Supported via gpt-4o-mini-tts - SpeechSynthesisStream: true, // Supported via gpt-4o-mini-tts - Transcription: true, // Supported via whisper-1 - TranscriptionStream: false, // Not properly supported yet by Azure - Embedding: true, - ImageGeneration: false, // Skipped for Azure - ImageGenerationStream: false, // Skipped for Azure - ImageEdit: true, // Azure supports image editing - ImageEditStream: true, // Azure supports streaming image editing - ImageVariation: false, // Azure does not support image variation - ImageVariationStream: false, // Azure does not support streaming image variation - ListModels: true, - BatchCreate: true, // Azure supports batch API - BatchList: true, // Azure supports batch API - BatchRetrieve: true, // Azure supports batch API - BatchCancel: true, // Azure supports batch API - BatchResults: true, // Azure supports batch API - FileUpload: true, // Azure supports file API - FileList: true, // Azure supports file API - FileRetrieve: true, // Azure supports file API - FileDelete: true, // Azure supports file API - FileContent: true, // Azure supports file API - ChatAudio: true, // Azure supports chat audio + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, 
+ SpeechSynthesis: true, // Supported via gpt-4o-mini-tts + SpeechSynthesisStream: true, // Supported via gpt-4o-mini-tts + Transcription: true, // Supported via whisper-1 + TranscriptionStream: false, // Not properly supported yet by Azure + Embedding: true, + ImageGeneration: false, // Skipped for Azure + ImageGenerationStream: false, // Skipped for Azure + ImageEdit: true, // Azure supports image editing + ImageEditStream: true, // Azure supports streaming image editing + ImageVariation: false, // Azure does not support image variation + ImageVariationStream: false, // Azure does not support streaming image variation + ListModels: true, + BatchCreate: true, // Azure supports batch API + BatchList: true, // Azure supports batch API + BatchRetrieve: true, // Azure supports batch API + BatchCancel: true, // Azure supports batch API + BatchResults: true, // Azure supports batch API + FileUpload: true, // Azure supports file API + FileList: true, // Azure supports file API + FileRetrieve: true, // Azure supports file API + FileDelete: true, // Azure supports file API + FileContent: true, // Azure supports file API + ChatAudio: true, // Azure supports chat audio }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1063,31 +1097,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageGenerationModel: "imagen-4.0-generate-001", ImageEditModel: "imagen-4.0-generate-001", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - ImageGeneration: true, - 
ImageGenerationStream: false, - ImageEdit: true, // Vertex supports image editing - ImageEditStream: false, // Vertex does not support streaming image editing - ImageVariation: false, // Vertex does not support image variation - ImageVariationStream: false, // Vertex does not support streaming image variation - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: true, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: true, // Vertex supports image editing + ImageEditStream: false, // Vertex does not support streaming image editing + ImageVariation: false, // Vertex does not support image variation + ImageVariationStream: false, // Vertex does not support streaming image variation + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1099,30 +1134,31 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TextModel: "", // Mistral focuses on chat TranscriptionModel: "voxtral-mini-latest", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not 
supported - SpeechSynthesisStream: false, // Not supported - Transcription: true, // Supported via voxtral-mini-latest - TranscriptionStream: true, // Supported via voxtral-mini-latest - Embedding: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Mistral does not support image editing - ImageEditStream: false, // Mistral does not support streaming image editing - ImageVariation: false, // Mistral does not support image variation - ImageVariationStream: false, // Mistral does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: true, // Supported via voxtral-mini-latest + TranscriptionStream: true, // Supported via voxtral-mini-latest + Embedding: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Mistral does not support image editing + ImageEditStream: false, // Mistral does not support streaming image editing + ImageVariation: false, // Mistral does not support image variation + ImageVariationStream: false, // Mistral does not support streaming image variation + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1133,31 +1169,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "llama3.2", TextModel: "", // Ollama focuses on chat Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, 
- ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Ollama does not support image editing - ImageEditStream: false, // Ollama does not support streaming image editing - ImageVariation: false, // Ollama does not support image variation - ImageVariationStream: false, // Ollama does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Ollama does not support image editing + ImageEditStream: false, // Ollama does not support streaming image editing + ImageVariation: false, // Ollama does not support image variation + ImageVariationStream: false, // Ollama does not support streaming image variation + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1168,31 +1205,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "llama-3.3-70b-versatile", TextModel: "", // Groq doesn't support text completion Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, 
MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Groq does not support image editing - ImageEditStream: false, // Groq does not support streaming image editing - ImageVariation: false, // Groq does not support image variation - ImageVariationStream: false, // Groq does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Groq does not support image editing + ImageEditStream: false, // Groq does not support streaming image editing + ImageVariation: false, // Groq does not support image variation + ImageVariationStream: false, // Groq does not support streaming image variation + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1233,31 +1271,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "llama-3.3-70b-versatile", TextModel: "", // Custom OpenAI instance doesn't support text completion Scenarios: TestScenarios{ - TextCompletion: false, - SimpleChat: true, // Enable simple chat for testing - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, + SimpleChat: true, // Enable simple chat for testing + CompletionStream: 
true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, - ImageGeneration: false, // ProviderOpenAICustom does not support image generation - ImageGenerationStream: false, // ProviderOpenAICustom does not support streaming image generation - ImageEdit: false, // ProviderOpenAICustom does not support image editing - ImageEditStream: false, // ProviderOpenAICustom does not support streaming image editing - ImageVariation: false, // ProviderOpenAICustom does not support image variation - ImageVariationStream: false, // ProviderOpenAICustom does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + ImageGeneration: false, // ProviderOpenAICustom does not support image generation + ImageGenerationStream: false, // ProviderOpenAICustom does not support streaming image generation + ImageEdit: false, // ProviderOpenAICustom does not support image editing + ImageEditStream: false, // ProviderOpenAICustom does not support streaming image editing + ImageVariation: false, // ProviderOpenAICustom does not support image variation + ImageVariationStream: false, // ProviderOpenAICustom does not support streaming image variation + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: 
schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1273,41 +1312,42 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageGenerationModel: "imagen-4.0-generate-001", ImageEditModel: "imagen-4.0-generate-001", Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: true, - SpeechSynthesisStream: true, - Transcription: true, - TranscriptionStream: true, - Embedding: true, - ImageGeneration: true, - ImageGenerationStream: false, - ImageEdit: true, // Gemini supports image editing - ImageEditStream: false, // Gemini does not support streaming image editing - ImageVariation: false, // Gemini does not support image variation - ImageVariationStream: false, // Gemini does not support streaming image variation - ListModels: true, - BatchCreate: true, - BatchList: true, - BatchRetrieve: true, - BatchCancel: true, - BatchResults: true, - FileUpload: true, - FileList: true, - FileRetrieve: true, - FileDelete: true, - FileContent: false, // Gemini doesn't support direct content download + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: true, + SpeechSynthesisStream: true, + Transcription: true, + TranscriptionStream: true, + Embedding: true, + ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: true, // Gemini supports image editing + ImageEditStream: false, // Gemini does not support streaming image editing + ImageVariation: false, // Gemini does not support 
image variation + ImageVariationStream: false, // Gemini does not support streaming image variation + ListModels: true, + BatchCreate: true, + BatchList: true, + BatchRetrieve: true, + BatchCancel: true, + BatchResults: true, + FileUpload: true, + FileList: true, + FileRetrieve: true, + FileDelete: true, + FileContent: false, // Gemini doesn't support direct content download }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1318,31 +1358,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "openai/gpt-4o", TextModel: "google/gemini-2.5-flash", Scenarios: TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // OpenRouter does not support image editing - ImageEditStream: false, // OpenRouter does not support streaming image editing - ImageVariation: false, // OpenRouter does not support image variation - ImageVariationStream: false, // OpenRouter does not support streaming image variation - SpeechSynthesis: false, - SpeechSynthesisStream: false, - Transcription: false, - TranscriptionStream: false, - Embedding: false, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // OpenRouter does not support image editing + ImageEditStream: false, // OpenRouter does not support streaming image editing + ImageVariation: false, // 
OpenRouter does not support image variation + ImageVariationStream: false, // OpenRouter does not support streaming image variation + SpeechSynthesis: false, + SpeechSynthesisStream: false, + Transcription: false, + TranscriptionStream: false, + Embedding: false, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1396,31 +1437,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TextModel: "", // XAI focuses on chat ImageGenerationModel: "grok-2-image", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, // Not supported - ImageGeneration: true, - ImageGenerationStream: false, - ImageEdit: false, // XAI does not support image editing - ImageEditStream: false, // XAI does not support streaming image editing - ImageVariation: false, // XAI does not support image variation - ImageVariationStream: false, // XAI does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, // Not supported + 
ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: false, // XAI does not support image editing + ImageEditStream: false, // XAI does not support streaming image editing + ImageVariation: false, // XAI does not support image variation + ImageVariationStream: false, // XAI does not support streaming image variation + ListModels: true, }, }, { @@ -1429,27 +1471,28 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TextModel: "openai/gpt-4.1-mini", ImageGenerationModel: "black-forest-labs/flux-dev", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, // Not supported - ListModels: true, - ImageGeneration: true, - ImageGenerationStream: false, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, // Not supported + ListModels: true, + ImageGeneration: true, + ImageGenerationStream: false, }, }, { Provider: schemas.VLLM, @@ -1458,27 +1501,28 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ EmbeddingModel: "Qwen/Qwen3-Embedding-0.6B", TranscriptionModel: "openai/whisper-small", Scenarios: 
TestScenarios{ - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: true, // VLLM supports transcription - TranscriptionStream: true, // VLLM supports transcription streaming - Embedding: true, // VLLM supports embedding - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // VLLM does not support image editing - ImageEditStream: false, // VLLM does not support streaming image editing - ImageVariation: false, // VLLM does not support image variation - ImageVariationStream: false, // VLLM does not support streaming image variation - ListModels: true, - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: true, // VLLM supports transcription + TranscriptionStream: true, // VLLM supports transcription streaming + Embedding: true, // VLLM supports embedding + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // VLLM does not support image editing + ImageEditStream: false, // VLLM does not support streaming image editing + ImageVariation: false, // VLLM does not support image variation + ImageVariationStream: false, // VLLM does not support streaming image variation + ListModels: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, + End2EndToolCalling: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, diff --git a/core/internal/llmtests/complete_end_to_end.go b/core/internal/llmtests/complete_end_to_end.go index 607ca1ca86..399c79845a 100644 --- a/core/internal/llmtests/complete_end_to_end.go +++ 
b/core/internal/llmtests/complete_end_to_end.go @@ -70,7 +70,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C ToolChoice: &schemas.ChatToolChoice{ ChatToolChoiceStr: bifrost.Ptr(string(schemas.ChatToolChoiceTypeRequired)), }, - MaxCompletionTokens: bifrost.Ptr(150), + MaxCompletionTokens: bifrost.Ptr(400), }, Fallbacks: testConfig.Fallbacks, } @@ -88,7 +88,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C ToolChoice: &schemas.ResponsesToolChoice{ ResponsesToolChoiceStr: bifrost.Ptr(string(schemas.ResponsesToolChoiceTypeRequired)), }, - MaxOutputTokens: bifrost.Ptr(150), + MaxOutputTokens: bifrost.Ptr(400), }, } return client.ResponsesRequest(bfCtx, responsesReq) @@ -205,7 +205,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: testConfig.ChatModel, Input: chatConversationHistory, Params: &schemas.ChatParameters{ - MaxCompletionTokens: bifrost.Ptr(200), + MaxCompletionTokens: bifrost.Ptr(400), }, Fallbacks: testConfig.Fallbacks, } @@ -219,7 +219,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: testConfig.ChatModel, Input: responsesConversationHistory, Params: &schemas.ResponsesParameters{ - MaxOutputTokens: bifrost.Ptr(200), + MaxOutputTokens: bifrost.Ptr(400), }, } return client.ResponsesRequest(bfCtx, responsesReq) @@ -343,7 +343,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: model, Input: chatConversationHistory, Params: &schemas.ChatParameters{ - MaxCompletionTokens: bifrost.Ptr(200), + MaxCompletionTokens: bifrost.Ptr(400), }, Fallbacks: testConfig.Fallbacks, } @@ -357,7 +357,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: model, Input: responsesConversationHistory, Params: &schemas.ResponsesParameters{ - MaxOutputTokens: bifrost.Ptr(200), + MaxOutputTokens: bifrost.Ptr(400), }, } return 
client.ResponsesRequest(bfCtx, responsesReq) diff --git a/core/internal/llmtests/image_edit.go b/core/internal/llmtests/image_edit.go index 56ad66d502..deed0bd820 100644 --- a/core/internal/llmtests/image_edit.go +++ b/core/internal/llmtests/image_edit.go @@ -364,8 +364,8 @@ func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context t.Error("❌ ExtraFields.Provider is empty") } - if imageEditResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageEditResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } // Validate RequestType is ImageEditRequest @@ -374,7 +374,7 @@ func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } t.Logf("✅ Image edit successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.ModelRequested, len(imageEditResponse.Data)) + imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.OriginalModelRequested, len(imageEditResponse.Data)) }) } diff --git a/core/internal/llmtests/image_generation.go b/core/internal/llmtests/image_generation.go index 81a0626978..1516ff0088 100644 --- a/core/internal/llmtests/image_generation.go +++ b/core/internal/llmtests/image_generation.go @@ -145,12 +145,12 @@ func RunImageGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.C t.Error("❌ ExtraFields.Provider is empty") } - if imageGenerationResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageGenerationResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } t.Logf("✅ Image generation successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.ModelRequested, 
len(imageGenerationResponse.Data)) + imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.OriginalModelRequested, len(imageGenerationResponse.Data)) }) } diff --git a/core/internal/llmtests/image_variation.go b/core/internal/llmtests/image_variation.go index 0aca33a63f..d0c4d18e78 100644 --- a/core/internal/llmtests/image_variation.go +++ b/core/internal/llmtests/image_variation.go @@ -162,8 +162,8 @@ func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co t.Error("❌ ExtraFields.Provider is empty") } - if imageVariationResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageVariationResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } // Validate RequestType is ImageVariationRequest @@ -172,7 +172,7 @@ func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co } t.Logf("✅ Image variation successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.ModelRequested, len(imageVariationResponse.Data)) + imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.OriginalModelRequested, len(imageVariationResponse.Data)) }) } diff --git a/core/internal/llmtests/list_models.go b/core/internal/llmtests/list_models.go index f0b5bcf7c5..3c2133aef5 100644 --- a/core/internal/llmtests/list_models.go +++ b/core/internal/llmtests/list_models.go @@ -9,6 +9,17 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +// listModelsBifrostContext returns a context for ListModels. For Replicate, sets BifrostContextKeyDirectKey +// so only the deployments key is used (see replicateProviderTestKeys in account.go). 
That key must not use an +// empty Models allowlist, or ListModelsPipeline.ShouldEarlyExit returns no models before the API runs. +func listModelsBifrostContext(parent context.Context, provider schemas.ModelProvider) *schemas.BifrostContext { + bfCtx := schemas.NewBifrostContext(parent, schemas.NoDeadline) + if provider == schemas.Replicate { + bfCtx.SetValue(schemas.BifrostContextKeyDirectKey, ReplicateDirectKeyForListModels()) + } + return bfCtx +} + // RunListModelsTest executes the list models test scenario func RunListModelsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { if !testConfig.Scenarios.ListModels { @@ -59,7 +70,7 @@ func RunListModelsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex } response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModels", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + bfCtx := listModelsBifrostContext(ctx, testConfig.Provider) return client.ListModelsRequest(bfCtx, request) }) @@ -154,7 +165,7 @@ func RunListModelsResponseMarshalTest(t *testing.T, client *bifrost.Bifrost, ctx } response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModelsResponseMarshal", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + bfCtx := listModelsBifrostContext(ctx, testConfig.Provider) return client.ListModelsRequest(bfCtx, request) }) @@ -293,7 +304,7 @@ func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx cont } response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModelsPagination", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + bfCtx := 
listModelsBifrostContext(ctx, testConfig.Provider)
 		return client.ListModelsRequest(bfCtx, request)
 	})
@@ -336,7 +347,7 @@ func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx cont
 	}
 	nextPageResponse, nextPageErr := WithListModelsTestRetry(t, listModelsRetryConfig, nextPageRetryContext, expectations, "ListModelsPagination_NextPage", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
-		bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
+		bfCtx := listModelsBifrostContext(ctx, testConfig.Provider)
 		return client.ListModelsRequest(bfCtx, nextPageRequest)
 	})
diff --git a/core/internal/llmtests/realtime.go b/core/internal/llmtests/realtime.go
index 821aeba9eb..400f5f9cda 100644
--- a/core/internal/llmtests/realtime.go
+++ b/core/internal/llmtests/realtime.go
@@ -43,7 +43,7 @@ func RunRealtimeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context,
 	bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
 	defer bfCtx.Cancel()
 
-	key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.RealtimeModel)
+	key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.RealtimeRequest, testConfig.Provider, testConfig.RealtimeModel)
 	if err != nil {
 		t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err)
 	}
diff --git a/core/internal/llmtests/response_validation.go b/core/internal/llmtests/response_validation.go
index bcc419abd3..bc75dd07df 100644
--- a/core/internal/llmtests/response_validation.go
+++ b/core/internal/llmtests/response_validation.go
@@ -859,7 +859,7 @@ func validateResponsesBasicStructure(response *schemas.BifrostResponsesResponse,
 	}
 	provider := response.ExtraFields.Provider
-	model := response.ExtraFields.ModelDeployment
+	model := response.ExtraFields.ResolvedModelUsed
 
 	// Verify top level status is present for OpenAI and Azure with non-Claude models
 	if provider != "" && (provider == schemas.OpenAI || provider == schemas.Azure) &&
 		!strings.Contains(strings.ToLower(model), "claude") {
@@ -988,8 +988,7 @@ func validateResponsesTechnicalFields(t *testing.T, response *schemas.BifrostRes
 
 	// Check model field
 	if expectations.ShouldHaveModel {
-		if strings.TrimSpace(response.Model) == "" &&
-			strings.TrimSpace(response.ExtraFields.ModelDeployment) == "" {
+		if strings.TrimSpace(response.Model) == "" {
 			result.Passed = false
 			result.Errors = append(result.Errors, fmt.Sprintf("Expected model field but not present or empty (provider: %s)", response.ExtraFields.Provider))
 		}
diff --git a/core/internal/llmtests/speech_synthesis.go b/core/internal/llmtests/speech_synthesis.go
index 4e08d6e2c8..aae66423a3 100644
--- a/core/internal/llmtests/speech_synthesis.go
+++ b/core/internal/llmtests/speech_synthesis.go
@@ -239,8 +239,8 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c
 		t.Fatalf("HD audio data too small: got %d bytes, expected at least 5000", audioSize)
 	}
 
-	if speechResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel {
-		t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.ModelRequested)
+	if speechResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
+		t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.OriginalModelRequested)
 	}
 
 	t.Logf("✅ HD speech synthesis successful: %d bytes generated", len(speechResponse.Audio))
@@ -344,8 +344,8 @@ func validateSpeechSynthesisSpecific(t *testing.T, response *schemas.BifrostSpee
 		t.Fatalf("Audio data too small: got %d bytes, expected at least %d", audioSize, expectMinBytes)
 	}
 
-	if expectedModel != "" && response.ExtraFields.ModelRequested != expectedModel {
-		t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.ModelRequested)
+	if expectedModel != "" && response.ExtraFields.OriginalModelRequested != expectedModel {
+		t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.OriginalModelRequested)
 	}
 
 	t.Logf("✅ Audio validation passed: %d bytes generated", audioSize)
diff --git a/core/internal/llmtests/speech_synthesis_stream.go b/core/internal/llmtests/speech_synthesis_stream.go
index 87268f3c17..8b7bdc8efb 100644
--- a/core/internal/llmtests/speech_synthesis_stream.go
+++ b/core/internal/llmtests/speech_synthesis_stream.go
@@ -184,8 +184,8 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con
 			if response.BifrostSpeechStreamResponse.Type != "" && (response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDelta && response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDone) {
 				t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostSpeechStreamResponse.Type)
 			}
-			if response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel {
-				t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested)
+			if response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
+				t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested)
 			}
 		}
 
@@ -348,8 +348,8 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost,
 			t.Logf("✅ HD chunk %d: %d bytes", chunkCount, chunkSize)
 		}
 
-		if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel {
-			t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested)
+		if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel {
+			t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested)
 		}
 	}
 
diff --git a/core/internal/llmtests/transcription_stream.go b/core/internal/llmtests/transcription_stream.go
index dfc80fc533..a28239c00f 100644
--- a/core/internal/llmtests/transcription_stream.go
+++ b/core/internal/llmtests/transcription_stream.go
@@ -242,8 +242,12 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte
 			if response.BifrostTranscriptionStreamResponse.Type != schemas.TranscriptionStreamResponseTypeDelta {
 				t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostTranscriptionStreamResponse.Type)
 			}
-			if response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != testConfig.TranscriptionModel {
-				t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested)
+			gotModel := response.BifrostTranscriptionStreamResponse.ExtraFields.OriginalModelRequested
+			if gotModel == "" {
+				t.Fatal("❌ Stream chunk missing extra_fields.original_model_requested")
+			}
+			if gotModel != testConfig.TranscriptionModel {
+				t.Fatalf("❌ Unexpected original_model_requested in stream: got %q want %q", gotModel, testConfig.TranscriptionModel)
 			}
 
 			lastResponse = DeepCopyBifrostStreamChunk(response)
diff --git a/core/internal/llmtests/video.go b/core/internal/llmtests/video.go
index c622edf6b4..8ac2d6e396 100644
--- a/core/internal/llmtests/video.go
+++ b/core/internal/llmtests/video.go
@@ -48,8 +48,8 @@ func RunVideoGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.C
 	if resp.ExtraFields.Provider == "" {
 		t.Fatal("❌ Video generation extra_fields.provider is empty")
 	}
-	if resp.ExtraFields.ModelRequested == "" {
-		t.Fatal("❌ Video generation extra_fields.model_requested is empty")
+	if resp.ExtraFields.OriginalModelRequested == "" {
+		t.Fatal("❌ Video generation extra_fields.original_model_requested is empty")
 	}
 
 	t.Logf("✅ Video generation created job: id=%s status=%s", resp.ID, resp.Status)
diff --git a/core/internal/llmtests/websocket_responses.go b/core/internal/llmtests/websocket_responses.go
index 420a049fb7..966463dade 100644
--- a/core/internal/llmtests/websocket_responses.go
+++ b/core/internal/llmtests/websocket_responses.go
@@ -38,7 +38,7 @@ func RunWebSocketResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx contex
 	bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
 	defer bfCtx.Cancel()
 
-	key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.ChatModel)
+	key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.WebSocketResponsesRequest, testConfig.Provider, testConfig.ChatModel)
 	if err != nil {
 		t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err)
 	}
diff --git a/core/internal/mcptests/agent_test_helpers.go b/core/internal/mcptests/agent_test_helpers.go
index d19d953ca0..85512dcce6 100644
--- a/core/internal/mcptests/agent_test_helpers.go
+++ b/core/internal/mcptests/agent_test_helpers.go
@@ -131,11 +131,11 @@ func SetupAgentTest(t *testing.T, config AgentTestConfig) (*mcp.MCPManager, *Dyn
 	// Create context with filtering
 	baseCtx := context.Background()
-	if len(config.ClientFiltering) > 0 {
-		baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, config.ClientFiltering)
+	if config.ClientFiltering != nil {
+		baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering)
 	}
-	if len(config.ToolFiltering) > 0 {
-		baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, config.ToolFiltering)
+	if config.ToolFiltering != nil {
+		baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering)
 	}
 
 	ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 
@@ -192,11 +192,11 @@ func SetupAgentTestWithClients(t *testing.T, config AgentTestConfig, customClien
 	// Create context with filtering
 	baseCtx := context.Background()
-	if len(config.ClientFiltering) > 0 {
-		baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, config.ClientFiltering)
+	if config.ClientFiltering != nil {
+		baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering)
 	}
-	if len(config.ToolFiltering) > 0 {
-		baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, config.ToolFiltering)
+	if config.ToolFiltering != nil {
+		baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering)
 	}
 
 	ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 
diff --git a/core/internal/mcptests/codemode_stdio_test.go b/core/internal/mcptests/codemode_stdio_test.go
index 8fe5841a82..aab3a15172 100644
--- a/core/internal/mcptests/codemode_stdio_test.go
+++ b/core/internal/mcptests/codemode_stdio_test.go
@@ -56,27 +56,27 @@ func setupCodeModeWithSTDIOServers(t *testing.T, serverNames ...string) (*mcp.MC
 		config = GetTemperatureMCPClientConfig(bifrostRoot)
 		config.IsCodeModeClient = true
 		config.ID = "temperature-client" // Match test expectations
-		config.Name = "temperature" // Use lowercase to match test code
+		config.Name = "temperature"      // Use lowercase to match test code
 		config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"}
 	case "go-test-server":
 		config = GetGoTestServerConfig(bifrostRoot)
 		config.ID = "goTestServer-client" // Match test expectations
-		config.Name = "goTestServer" // Use camelCase to match test code
+		config.Name = "goTestServer"      // Use camelCase to match test code
 		config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"}
 	case "edge-case-server":
 		config = GetEdgeCaseServerConfig(bifrostRoot)
 		config.ID = "edgeCaseServer-client" // Match test expectations
-		config.Name = "edgeCaseServer" // Use camelCase to match test code
+		config.Name = "edgeCaseServer"      // Use camelCase to match test code
 		config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"}
 	case "error-test-server":
 		config = GetErrorTestServerConfig(bifrostRoot)
 		config.ID = "errorTestServer-client" // Match test expectations
-		config.Name = "errorTestServer" // Use camelCase to match test code
+		config.Name = "errorTestServer"      // Use camelCase to match test code
 		config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"}
 	case "parallel-test-server":
 		config = GetParallelTestServerConfig(bifrostRoot)
 		config.ID = "parallelTestServer-client" // Match test expectations
-		config.Name = "parallelTestServer" // Use camelCase to match test code
+		config.Name = "parallelTestServer"      // Use camelCase to match test code
 		config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"}
 	case "test-tools-server":
 		// test-tools-server doesn't have a fixture, set up manually
@@ -367,9 +367,9 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) {
 		expectedError    string
 	}{
 		{
-			name: "allow_only_test_tools_server",
-			includeClients: []string{"testToolsServer"},
-			code: `result = testToolsServer.echo(message="allowed")`,
+			name:             "allow_only_test_tools_server",
+			includeClients:   []string{"testToolsServer"},
+			code:             `result = testToolsServer.echo(message="allowed")`,
 			shouldSucceed:    true,
 			expectedInResult: "allowed",
 		},
@@ -377,13 +377,13 @@ {
 			name:           "block_test_tools_server",
 			includeClients: []string{"temperature"},
 			code:           `result = testToolsServer.echo(message="blocked")`,
-			shouldSucceed: false,
-			expectedError: "undefined: testToolsServer",
+			shouldSucceed:  false,
+			expectedError:  "undefined: testToolsServer",
 		},
 		{
-			name: "allow_only_temperature_server",
-			includeClients: []string{"temperature"},
-			code: `result = temperature.get_temperature(location="Paris")`,
+			name:             "allow_only_temperature_server",
+			includeClients:   []string{"temperature"},
+			code:             `result = temperature.get_temperature(location="Paris")`,
 			shouldSucceed:    true,
 			expectedInResult: "Paris",
 		},
@@ -391,8 +391,8 @@ {
 			name:           "block_temperature_server",
 			includeClients: []string{"testToolsServer"},
 			code:           `result = temperature.get_temperature(location="blocked")`,
-			shouldSucceed: false,
-			expectedError: "undefined: temperature",
+			shouldSucceed:  false,
+			expectedError:  "undefined: temperature",
 		},
 		{
 			name: "allow_both_servers",
@@ -409,7 +409,7 @@ result = {"echo": echo, "temp": temp}`,
 		t.Run(tc.name, func(t *testing.T) {
 			// Create context with client filtering
 			baseCtx := context.Background()
-			baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients)
+			baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients)
 			ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 
 			// Verify filtering is applied at tool listing level
@@ -524,7 +524,7 @@ result = {"echo": echo, "calc": calc}`,
 		t.Run(tc.name, func(t *testing.T) {
 			// Create context with tool filtering
 			baseCtx := context.Background()
-			baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, tc.includeTools)
+			baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, tc.includeTools)
 			ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 
 			// Verify filtering is applied
@@ -622,10 +622,10 @@ result = {"echo": echo, "temp": temp}`,
 			// Create context with both client and tool filtering
 			baseCtx := context.Background()
 			if tc.includeClients != nil {
-				baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients)
+				baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients)
 			}
 			if tc.includeTools != nil {
-				baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, tc.includeTools)
+				baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, tc.includeTools)
 			}
 			ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 
@@ -1692,7 +1692,7 @@ result = {"count": 3}`,
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
 			baseCtx := context.Background()
-			baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients)
+			baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients)
 			ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 
 			toolCall := CreateExecuteToolCodeCall(fmt.Sprintf("call-%s", tc.name), tc.code)
diff --git a/core/internal/mcptests/concurrency_advanced_test.go b/core/internal/mcptests/concurrency_advanced_test.go
index a1c3823831..e3c5793df4 100644
--- a/core/internal/mcptests/concurrency_advanced_test.go
+++ b/core/internal/mcptests/concurrency_advanced_test.go
@@ -10,7 +10,6 @@ import (
 	"testing"
 	"time"
 
-	"github.com/maximhq/bifrost/core/mcp"
 	"github.com/maximhq/bifrost/core/schemas"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -533,14 +532,14 @@ func TestConcurrent_FilteringChanges(t *testing.T) {
 			if id%2 == 0 {
 				// Even: allow all tools
 				baseCtx := context.Background()
-				baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, []string{"*"})
-				baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, []string{"bifrostInternal-*"})
+				baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, []string{"*"})
+				baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, []string{"bifrostInternal-*"})
 				ctx = schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 			} else {
 				// Odd: allow only echo
 				baseCtx := context.Background()
-				baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, []string{"*"})
-				baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, []string{"bifrostInternal-echo"})
+				baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, []string{"*"})
+				baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, []string{"bifrostInternal-echo"})
 				ctx = schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 			}
 
diff --git a/core/internal/mcptests/fixtures.go b/core/internal/mcptests/fixtures.go
index 88b00a9f70..f760ae5ac0 100644
--- a/core/internal/mcptests/fixtures.go
+++ b/core/internal/mcptests/fixtures.go
@@ -1422,7 +1422,7 @@ func (a *testAccount) GetKeysForProvider(ctx context.Context, providerKey schema
 	return []schemas.Key{
 		{
 			Value:  *schemas.NewEnvVar(apiKey),
-			Models: []string{}, // Empty means all models
+			Models: schemas.WhiteList{"*"},
 			Weight: 1.0,
 		},
 	}, nil
@@ -1460,6 +1460,17 @@ func setupBifrost(t *testing.T) *bifrost.Bifrost {
 	return bifrostInstance
 }
 
+// noopPluginPipeline is a passthrough pipeline used in tests that don't need plugin hooks.
+type noopPluginPipeline struct{}
+
+func (n *noopPluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int) {
+	return req, nil, 0
+}
+
+func (n *noopPluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError) {
+	return mcpResp, bifrostErr
+}
+
 // setupMCPManager creates an MCP manager for testing
 func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mcp.MCPManager {
 	t.Helper()
@@ -1472,9 +1483,14 @@ func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mc
 		clientConfigPtrs[i] = &clientConfigs[i]
 	}
 
-	// Create MCP config
+	// Create MCP config with a no-op plugin pipeline so that codemode tool calls
+	// work correctly even when no Bifrost instance is attached.
 	mcpConfig := &schemas.MCPConfig{
 		ClientConfigs: clientConfigPtrs,
+		PluginPipelineProvider: func() interface{} {
+			return &noopPluginPipeline{}
+		},
+		ReleasePluginPipeline: func(pipeline interface{}) {},
 	}
 
 	// Create Starlark CodeMode
@@ -1984,10 +2000,10 @@ func AssertExecutionTimeUnder(t *testing.T, fn func(), maxDuration time.Duration
 func CreateTestContextWithMCPFilter(includeClients []string, includeTools []string) *schemas.BifrostContext {
 	baseCtx := context.Background()
 	if includeClients != nil {
-		baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, includeClients)
+		baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, includeClients)
 	}
 	if includeTools != nil {
-		baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, includeTools)
+		baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, includeTools)
 	}
 	return schemas.NewBifrostContext(baseCtx, schemas.NoDeadline)
 }
diff --git a/core/internal/mcptests/tool_filtering_test.go b/core/internal/mcptests/tool_filtering_test.go
index eb8b370a28..15fde03d75 100644
--- a/core/internal/mcptests/tool_filtering_test.go
+++ b/core/internal/mcptests/tool_filtering_test.go
@@ -160,7 +160,7 @@ func TestToolsToExecute_ExplicitList(t *testing.T) {
 	// Verify configuration was set correctly
 	clients := manager.GetClients()
 	require.Len(t, clients, 1)
-	assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
+	assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
 }
 
 func TestToolsToExecute_SingleTool(t *testing.T) {
@@ -178,10 +178,10 @@ func TestToolsToExecute_SingleTool(t *testing.T) {
 	// Verify configuration
 	clients := manager.GetClients()
 	require.Len(t, clients, 1)
-	assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
+	assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
 
 	// Verify it's not allow-all
-	assert.NotEqual(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToExecute, "should not be wildcard")
+	assert.NotEqual(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute, "should not be wildcard")
 }
 
 // =============================================================================
@@ -204,8 +204,8 @@ func TestToolsToAutoExecute_Basic(t *testing.T) {
 	// Verify the client was created with correct configuration
 	clients := manager.GetClients()
 	require.Len(t, clients, 1)
-	assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToExecute)
-	assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToAutoExecute)
+	assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute)
+	assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToAutoExecute)
 }
 
 func TestToolsToAutoExecute_NotInExecuteList(t *testing.T) {
@@ -224,8 +224,8 @@ func TestToolsToAutoExecute_NotInExecuteList(t *testing.T) {
 	// Verify configuration
 	clients := manager.GetClients()
 	require.Len(t, clients, 1)
-	assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
-	assert.Equal(t, []string{"hash"}, clients[0].ExecutionConfig.ToolsToAutoExecute)
+	assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
+	assert.Equal(t, schemas.WhiteList{"hash"}, clients[0].ExecutionConfig.ToolsToAutoExecute)
 	assert.NotEqual(t, clients[0].ExecutionConfig.ToolsToExecute, clients[0].ExecutionConfig.ToolsToAutoExecute)
 }
 
@@ -245,7 +245,7 @@ func TestToolsToAutoExecute_Wildcard(t *testing.T) {
 	// Verify configuration
 	clients := manager.GetClients()
 	require.Len(t, clients, 1)
-	assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToAutoExecute)
+	assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToAutoExecute)
 }
 
 // =============================================================================
@@ -267,7 +267,7 @@ func TestContextFilteringRestrictsWildcard(t *testing.T) {
 	// Verify client configuration allows all
 	clients := manager.GetClients()
 	require.Len(t, clients, 1)
-	assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToExecute)
+	assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute)
 
 	// Context restricts to only specific tools (verify context works separately)
 	ctx := CreateTestContextWithMCPFilter(nil, []string{"encode"})
@@ -305,9 +305,9 @@ func TestFilteringMultipleClients_DifferentRules(t *testing.T) {
 	// Find and verify each client
 	for _, client := range clients {
 		if client.ExecutionConfig.ID == "stdio-client-1" {
-			assert.Equal(t, []string{"encode"}, client.ExecutionConfig.ToolsToExecute)
+			assert.Equal(t, schemas.WhiteList{"encode"}, client.ExecutionConfig.ToolsToExecute)
 		} else if client.ExecutionConfig.ID == "stdio-client-2" {
-			assert.Equal(t, []string{"*"}, client.ExecutionConfig.ToolsToExecute)
+			assert.Equal(t, schemas.WhiteList{"*"}, client.ExecutionConfig.ToolsToExecute)
 		}
 	}
 }
@@ -331,7 +331,7 @@ func TestFilteringChangesAfterClientEdit(t *testing.T) {
 	// Verify initial configuration
 	clients := manager.GetClients()
 	require.Len(t, clients, 1)
-	assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
+	assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
 
 	// Edit client to only allow second tool
 	clientConfig.ToolsToExecute = []string{"hash"}
@@ -341,6 +341,6 @@ func TestFilteringChangesAfterClientEdit(t *testing.T) {
 	// Verify configuration changed
 	clients = manager.GetClients()
 	require.Len(t, clients, 1)
-	assert.Equal(t, []string{"hash"}, clients[0].ExecutionConfig.ToolsToExecute)
-	assert.NotEqual(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
+	assert.Equal(t, schemas.WhiteList{"hash"}, clients[0].ExecutionConfig.ToolsToExecute)
+	assert.NotEqual(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute)
 }
diff --git a/core/keyselectors/weightedrandom.go b/core/keyselectors/weightedrandom.go
new file mode 100644
index 0000000000..a3b8994d0d
--- /dev/null
+++ b/core/keyselectors/weightedrandom.go
@@ -0,0 +1,35 @@
+package keyselectors
+
+import (
+	"math/rand"
+
+	"github.com/maximhq/bifrost/core/schemas"
+)
+
+func WeightedRandom(ctx *schemas.BifrostContext, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
+	// Use a weighted random selection based on key weights
+	totalWeight := 0
+	for _, key := range keys {
+		totalWeight += int(key.Weight * 100) // Scale float weights to integers (two-decimal resolution)
+	}
+
+	// If all keys have zero weight, fall back to uniform random selection
+	if totalWeight == 0 {
+		return keys[rand.Intn(len(keys))], nil
+	}
+
+	// Use the global thread-safe random source (auto-seeded since Go 1.20)
+	randomValue := rand.Intn(totalWeight)
+
+	// Select key based on weight
+	currentWeight := 0
+	for _, key := range keys {
+		currentWeight += int(key.Weight * 100)
+		if randomValue < currentWeight {
+			return key, nil
+		}
+	}
+
+	// Fallback to first key if something goes wrong
+	return keys[0], nil
+}
diff --git a/core/mcp/agent.go b/core/mcp/agent.go
index fe4481ad7a..96d16ec24e 100644
--- a/core/mcp/agent.go
+++ b/core/mcp/agent.go
@@ -1,6 +1,7 @@
 package mcp
 
 import (
+	"errors"
 	"fmt"
 	"strings"
 	"sync"
@@ -10,7 +11,6 @@ import (
 	"github.com/maximhq/bifrost/core/schemas"
 )
 
-
 type AgentModeExecutor struct {
 	logger schemas.Logger
 }
@@ -40,7 +40,7 @@ func (a *AgentModeExecutor) ExecuteAgentForChatRequest(
 	makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
 	fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
 	executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
-	clientManager        ClientManager,
+	clientManager ClientManager,
 ) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
 	// Create adapter for Chat API
 	adapter := &chatAPIAdapter{
@@ -143,7 +143,7 @@ func (a *AgentModeExecutor) executeAgent(
 	adapter agentAPIAdapter,
 	fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
 	executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
-	clientManager        ClientManager,
+	clientManager ClientManager,
 ) (interface{}, *schemas.BifrostError) {
 	// Get initial response from adapter
 	currentResponse := adapter.getInitialResponse()
@@ -157,6 +157,9 @@ func (a *AgentModeExecutor) executeAgent(
 	allExecutedToolResults := make([]*schemas.ChatMessage, 0)
 	allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
 
+	// Accumulate token usage across all LLM calls in the agent loop
+	accumulatedUsage := adapter.extractUsage(currentResponse)
+
 	originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
 	if ok {
 		ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID)
@@ -207,14 +210,8 @@ func (a *AgentModeExecutor) executeAgent(
 				continue
 			}
 
-			// Step 1: Convert literal \n escape sequences to actual newlines for parsing
-			codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n")
-			if len(codeWithNewlines) != len(code) {
-				a.logger.Debug("%s Converted literal \\n escape sequences to actual newlines", CodeModeLogPrefix)
-			}
-
-			// Step 2: Extract tool calls from code during AST formation
-			extractedToolCalls, err := extractToolCallsFromCode(codeWithNewlines)
+			// Step 1: Extract tool calls from the original source code during validation
+			extractedToolCalls, err := extractToolCallsFromCode(code)
 			if err != nil {
 				a.logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)
 				nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall)
@@ -289,6 +286,8 @@ func (a *AgentModeExecutor) executeAgent(
 		wg := sync.WaitGroup{}
 		wg.Add(len(autoExecutableTools))
 		channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools))
+		var authRequiredErr *schemas.MCPUserOAuthRequiredError
+		var authRequiredOnce sync.Once
 		for _, toolCall := range autoExecutableTools {
 			go func(toolCall schemas.ChatAssistantMessageToolCall) {
 				defer wg.Done()
@@ -305,6 +304,15 @@ func (a *AgentModeExecutor) executeAgent(
 
 				mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest)
 				if toolErr != nil {
+					// Check if this is a per-user OAuth auth-required error
+					var oauthErr *schemas.MCPUserOAuthRequiredError
+					if errors.As(toolErr, &oauthErr) {
+						authRequiredOnce.Do(func() {
+							authRequiredErr = oauthErr
+						})
+						channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
+						return
+					}
 					a.logger.Warn("Tool execution failed: %v", toolErr)
 					channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
 				} else if mcpResponse != nil && mcpResponse.ChatMessage != nil {
@@ -321,6 +329,23 @@ func (a *AgentModeExecutor) executeAgent(
 		wg.Wait()
 		close(channelToolResults)
 
+		// If any tool required per-user OAuth, stop the agent loop and return the error
+		if authRequiredErr != nil {
+			statusCode := 401
+			errType := "mcp_auth_required"
+			return nil, &schemas.BifrostError{
+				IsBifrostError: true,
+				StatusCode:     &statusCode,
+				Error: &schemas.ErrorField{
+					Message: authRequiredErr.Message,
+					Type:    &errType,
+				},
+				ExtraFields: schemas.BifrostErrorExtraFields{
+					MCPAuthRequired: authRequiredErr,
+				},
+			}
+		}
+
 		// Collect tool results
 		executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools))
 		for toolResult := range channelToolResults {
@@ -342,6 +367,8 @@ func (a *AgentModeExecutor) executeAgent(
 		if depth == 1 && len(allExecutedToolResults) == 0 {
 			return currentResponse, nil
 		}
+		// Apply accumulated usage before building the final response
+		adapter.applyUsage(currentResponse, accumulatedUsage)
 		// Create response with all executed tool results from all iterations, and non-auto-executable tool calls
 		return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil
 	}
@@ -364,11 +391,127 @@ func (a *AgentModeExecutor) executeAgent(
 		}
 
 		currentResponse = response
+		accumulatedUsage = mergeUsage(accumulatedUsage, adapter.extractUsage(currentResponse))
 	}
 
+	adapter.applyUsage(currentResponse, accumulatedUsage)
 	return currentResponse, nil
 }
 
+// mergeUsage sums token counts and costs from two BifrostLLMUsage values.
+// Detail sub-fields are summed when both are present; if only one is non-nil it is kept as-is.
+func mergeUsage(base, add *schemas.BifrostLLMUsage) *schemas.BifrostLLMUsage {
+	if add == nil {
+		return base
+	}
+	if base == nil {
+		return add
+	}
+
+	merged := &schemas.BifrostLLMUsage{
+		PromptTokens:     base.PromptTokens + add.PromptTokens,
+		CompletionTokens: base.CompletionTokens + add.CompletionTokens,
+		TotalTokens:      base.TotalTokens + add.TotalTokens,
+	}
+
+	// Merge prompt token details
+	if base.PromptTokensDetails != nil || add.PromptTokensDetails != nil {
+		bd := base.PromptTokensDetails
+		ad := add.PromptTokensDetails
+		if bd == nil {
+			bd = &schemas.ChatPromptTokensDetails{}
+		}
+		if ad == nil {
+			ad = &schemas.ChatPromptTokensDetails{}
+		}
+		merged.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
+			TextTokens:        bd.TextTokens + ad.TextTokens,
+			AudioTokens:       bd.AudioTokens + ad.AudioTokens,
+			ImageTokens:       bd.ImageTokens + ad.ImageTokens,
+			CachedReadTokens:  bd.CachedReadTokens + ad.CachedReadTokens,
+			CachedWriteTokens: bd.CachedWriteTokens + ad.CachedWriteTokens,
+		}
+	}
+
+	// Merge completion token details
+	if base.CompletionTokensDetails != nil || add.CompletionTokensDetails != nil {
+		bd := base.CompletionTokensDetails
+		ad := add.CompletionTokensDetails
+		if bd == nil {
+			bd = &schemas.ChatCompletionTokensDetails{}
+		}
+		if ad == nil {
+			ad = &schemas.ChatCompletionTokensDetails{}
+		}
+		merged.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
+			TextTokens:               bd.TextTokens + ad.TextTokens,
+			AcceptedPredictionTokens: bd.AcceptedPredictionTokens + ad.AcceptedPredictionTokens,
+			AudioTokens:              bd.AudioTokens + ad.AudioTokens,
+			ReasoningTokens:          bd.ReasoningTokens + ad.ReasoningTokens,
+			RejectedPredictionTokens: bd.RejectedPredictionTokens + ad.RejectedPredictionTokens,
+		}
+		if bd.CitationTokens != nil || ad.CitationTokens != nil {
+			bct := 0
+			act := 0
+			if bd.CitationTokens != nil {
+				bct = *bd.CitationTokens
+			}
+			if ad.CitationTokens != nil {
+				act = *ad.CitationTokens
+			}
+			sum := bct + act
+			merged.CompletionTokensDetails.CitationTokens = &sum
+		}
+		if bd.NumSearchQueries != nil || ad.NumSearchQueries != nil {
+			bnsq := 0
+			ansq := 0
+			if bd.NumSearchQueries != nil {
+				bnsq = *bd.NumSearchQueries
+			}
+			if ad.NumSearchQueries != nil {
+				ansq = *ad.NumSearchQueries
+			}
+			sum := bnsq + ansq
+			merged.CompletionTokensDetails.NumSearchQueries = &sum
+		}
+		if bd.ImageTokens != nil || ad.ImageTokens != nil {
+			bit := 0
+			ait := 0
+			if bd.ImageTokens != nil {
+				bit = *bd.ImageTokens
+			}
+			if ad.ImageTokens != nil {
+				ait = *ad.ImageTokens
+			}
+			sum := bit + ait
+			merged.CompletionTokensDetails.ImageTokens = &sum
+		}
+	}
+
+	// Merge cost
+	if base.Cost != nil || add.Cost != nil {
+		bc := base.Cost
+		ac := add.Cost
+		if bc == nil {
+			bc = &schemas.BifrostCost{}
+		}
+		if ac == nil {
+			ac = &schemas.BifrostCost{}
+		}
+		merged.Cost = &schemas.BifrostCost{
+			InputTokensCost:     bc.InputTokensCost + ac.InputTokensCost,
+			OutputTokensCost:    bc.OutputTokensCost + ac.OutputTokensCost,
+			ReasoningTokensCost: bc.ReasoningTokensCost + ac.ReasoningTokensCost,
+			CitationTokensCost:  bc.CitationTokensCost + ac.CitationTokensCost,
+			SearchQueriesCost:   bc.SearchQueriesCost + ac.SearchQueriesCost,
+			RequestCost:         bc.RequestCost + ac.RequestCost,
+			TotalCost:           bc.TotalCost + ac.TotalCost,
+		}
+	}
+
+	return merged
+}
+
 // extractToolCalls extracts all tool calls from a chat response.
 // It iterates through all choices in the response and collects tool calls
 // from assistant messages.
@@ -460,25 +603,23 @@ func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager C
 		// Get auto-executable tools from config
 		toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute
-		if len(toolsToAutoExecute) == 0 {
+		if toolsToAutoExecute.IsEmpty() {
 			// No auto-executable tools configured for this client
 			continue
 		}
 
 		// Parse tool names (as they appear in JavaScript code)
 		autoExecutableTools := []string{}
-		for _, originalToolName := range toolsToAutoExecute {
-			// Handle wildcard "*" - means all tools are auto-executable
-			if originalToolName == "*" {
-				autoExecutableTools = append(autoExecutableTools, "*")
-				continue
+		if toolsToAutoExecute.IsUnrestricted() {
+			autoExecutableTools = append(autoExecutableTools, "*")
+		} else {
+			for _, originalToolName := range toolsToAutoExecute {
+				// Replace - with _ for code mode compatibility, then parse for JS compatibility
+				toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_")
+				parsedToolName := parseToolName(toolNameForCode)
+				autoExecutableTools = append(autoExecutableTools, parsedToolName)
 			}
-			// Replace - with _ for code mode compatibility, then parse for JS compatibility
-			toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_")
-			parsedToolName := parseToolName(toolNameForCode)
-			autoExecutableTools = append(autoExecutableTools, parsedToolName)
 		}
-
 		// Add to map if there are auto-executable tools
 		if len(autoExecutableTools) > 0 {
 			allowedTools[clientName] = autoExecutableTools
diff --git a/core/mcp/agentadaptors.go b/core/mcp/agentadaptors.go
index 6986cd9798..7b78df4389 100644
--- a/core/mcp/agentadaptors.go
+++ b/core/mcp/agentadaptors.go
@@ -59,6 +59,12 @@ type agentAPIAdapter interface {
 		executedToolCalls []schemas.ChatAssistantMessageToolCall,
 		nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall,
 	) interface{}
+
+	// extractUsage returns the token usage from a response as BifrostLLMUsage.
+	extractUsage(response interface{}) *schemas.BifrostLLMUsage
+
+	// applyUsage sets accumulated usage on the response in place.
+	applyUsage(response interface{}, usage *schemas.BifrostLLMUsage)
 }
 
 // chatAPIAdapter implements agentAPIAdapter for Chat API
@@ -175,6 +181,14 @@ func (c *chatAPIAdapter) createResponseWithExecutedTools(
 	)
 }
 
+func (c *chatAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage {
+	return response.(*schemas.BifrostChatResponse).Usage
+}
+
+func (c *chatAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) {
+	response.(*schemas.BifrostChatResponse).Usage = usage
+}
+
 // createChatResponseWithExecutedToolsAndNonAutoExecutableCalls creates a chat response
 // that includes executed tool results and non-auto-executable tool calls. The response
 // contains a formatted text summary of executed tool results and includes the non-auto-executable
@@ -390,6 +404,14 @@ func (r *responsesAPIAdapter) createResponseWithExecutedTools(
 	)
 }
 
+func (r *responsesAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage {
+	return response.(*schemas.BifrostResponsesResponse).Usage.ToBifrostLLMUsage()
+}
+
+func (r *responsesAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) {
+	response.(*schemas.BifrostResponsesResponse).Usage = usage.ToResponsesResponseUsage()
+}
+
 // createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls creates a responses response
 // that includes executed tool results and non-auto-executable tool calls. The response
 // contains a formatted text summary of executed tool results and includes the non-auto-executable
diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go
index 36b12243da..b6bd442c20 100644
--- a/core/mcp/clientmanager.go
+++ b/core/mcp/clientmanager.go
@@ -118,6 +118,33 @@ func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error {
 	// This is to avoid deadlocks when the connection attempt is made
 	m.mu.Unlock()
 
+	// Per-user OAuth: skip persistent connection. Auth is per-request at runtime.
+	// The admin verifies the configuration via a sample login before this is called,
+	// and tools are populated separately via SetClientTools().
+	if configCopy.AuthType == schemas.MCPAuthTypePerUserOauth {
+		m.mu.Lock()
+		if client, exists := m.clientMap[config.ID]; exists {
+			if config.ConnectionString != nil {
+				url := config.ConnectionString.GetValue()
+				client.ConnectionInfo.ConnectionURL = &url
+			}
+			// Restore discovered tools from config (persisted in DB across restarts)
+			if len(config.DiscoveredTools) > 0 {
+				for toolName, tool := range config.DiscoveredTools {
+					client.ToolMap[toolName] = tool
+				}
+				client.ToolNameMapping = config.DiscoveredToolNameMapping
+				client.State = schemas.MCPConnectionStateConnected
+				m.logger.Info("%s Per-user OAuth MCP client '%s' restored with %d tools", MCPLogPrefix, config.Name, len(config.DiscoveredTools))
+			} else {
+				client.State = schemas.MCPConnectionStatePendingTools
+				m.logger.Info("%s Per-user OAuth MCP client '%s' registered (connection deferred to runtime)", MCPLogPrefix, config.Name)
+			}
+		}
+		m.mu.Unlock()
+		return nil
+	}
+
 	// Connect using the copied config
 	if err := m.connectToMCPClient(configCopy); err != nil {
 		// Clean up the failed entry — this is a user-initiated action (UI/API),
@@ -131,6 +158,92 @@ func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error {
 	return nil
 }
 
+// VerifyPerUserOAuthConnection creates a temporary MCP connection using the
+// provided access
token to verify the server is reachable and discover available +// tools. The connection is closed after verification. This is used during +// per-user OAuth client setup when the admin does a test login to validate the +// OAuth configuration before saving the MCP client. +// +// Parameters: +// - config: MCP client configuration (connection URL, name, etc.) +// - accessToken: temporary OAuth access token from the admin's test login +// +// Returns: +// - map[string]schemas.ChatTool: discovered tools keyed by prefixed name +// - map[string]string: tool name mapping (sanitized → original MCP name) +// - error: any error during verification +func (m *MCPManager) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) { + if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" { + return nil, nil, fmt.Errorf("connection URL is required for per-user OAuth verification") + } + + // Create HTTP transport with the admin's temporary Bearer token + headers := map[string]string{ + "Authorization": "Bearer " + accessToken, + } + httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers)) + if err != nil { + return nil, nil, fmt.Errorf("failed to create HTTP transport for verification: %w", err) + } + + // Create temporary MCP client + tempClient := client.NewClient(httpTransport) + ctx, cancel := context.WithTimeout(ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + + // Start transport + if err := tempClient.Start(ctx); err != nil { + return nil, nil, fmt.Errorf("failed to start MCP connection for verification: %w", err) + } + defer tempClient.Close() + + // Initialize MCP handshake + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + 
Name: fmt.Sprintf("Bifrost-%s-verify", config.Name), + Version: "1.0.0", + }, + }, + } + if _, err := tempClient.Initialize(ctx, initRequest); err != nil { + return nil, nil, fmt.Errorf("failed to initialize MCP connection for verification: %w", err) + } + + // Discover tools + tools, toolNameMapping, err := retrieveExternalTools(ctx, tempClient, config.Name, m.logger) + if err != nil { + return nil, nil, fmt.Errorf("failed to discover tools during verification: %w", err) + } + + m.logger.Info("%s Per-user OAuth verification succeeded for '%s': discovered %d tools", MCPLogPrefix, config.Name, len(tools)) + return tools, toolNameMapping, nil +} + +// SetClientTools updates the tool map and name mapping for an existing client. +// This is used to populate tools discovered during per-user OAuth verification, +// where tool discovery happens separately from client creation. +// +// Parameters: +// - clientID: ID of the client to update +// - tools: discovered tools keyed by prefixed name +// - toolNameMapping: mapping from sanitized tool names to original MCP names +func (m *MCPManager) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { + m.mu.Lock() + defer m.mu.Unlock() + + if client, exists := m.clientMap[clientID]; exists { + for toolName, tool := range tools { + client.ToolMap[toolName] = tool + } + client.ToolNameMapping = toolNameMapping + client.State = schemas.MCPConnectionStateConnected + m.logger.Debug("%s Set %d tools on client '%s'", MCPLogPrefix, len(tools), client.Name) + } +} + // RemoveClient removes an MCP client from the manager. // It handles cleanup for all transport types (HTTP, STDIO, SSE). 
// @@ -243,13 +356,15 @@ func (m *MCPManager) UpdateClient(id string, updatedConfig *schemas.MCPClientCon ConfigHash: client.ExecutionConfig.ConfigHash, ToolPricing: maps.Clone(client.ExecutionConfig.ToolPricing), // Updatable fields - copy from updated config with proper cloning - Name: updatedConfig.Name, - IsCodeModeClient: updatedConfig.IsCodeModeClient, - Headers: maps.Clone(updatedConfig.Headers), - ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute), - ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute), - IsPingAvailable: updatedConfig.IsPingAvailable, - ToolSyncInterval: updatedConfig.ToolSyncInterval, + Name: updatedConfig.Name, + IsCodeModeClient: updatedConfig.IsCodeModeClient, + Headers: maps.Clone(updatedConfig.Headers), + ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute), + ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute), + AllowedExtraHeaders: slices.Clone(updatedConfig.AllowedExtraHeaders), + IsPingAvailable: updatedConfig.IsPingAvailable, + ToolSyncInterval: updatedConfig.ToolSyncInterval, + AllowOnAllVirtualKeys: updatedConfig.AllowOnAllVirtualKeys, } // Atomically replace the config pointer @@ -663,7 +778,11 @@ func (m *MCPManager) connectToMCPClient(config *schemas.MCPClientConfig) error { } // Start health monitoring for the client - monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, config.IsPingAvailable, m.logger) + isPingAvailable := true + if config.IsPingAvailable != nil { + isPingAvailable = *config.IsPingAvailable + } + monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, isPingAvailable, m.logger) m.healthMonitorManager.StartMonitoring(monitor) // Start tool syncing for the client (skip for internal bifrost client) diff --git a/core/mcp/codemode.go b/core/mcp/codemode.go index e81c984195..fa11e52d0b 100644 --- a/core/mcp/codemode.go +++ b/core/mcp/codemode.go @@ -3,7 +3,6 @@ package mcp import ( - "context" "sync" "time" @@ -31,7 +30,7 @@ 
type CodeMode interface { // ExecuteTool handles a code mode tool call by name. // Returns the response message and any error that occurred. - ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) + ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) // IsCodeModeTool returns true if the given tool name is a code mode tool. IsCodeModeTool(toolName string) bool diff --git a/core/mcp/codemode/starlark/executecode.go b/core/mcp/codemode/starlark/executecode.go index da497d8f6f..d2d9435764 100644 --- a/core/mcp/codemode/starlark/executecode.go +++ b/core/mcp/codemode/starlark/executecode.go @@ -5,7 +5,6 @@ package starlark import ( "context" "fmt" - "net/http" "strings" "time" @@ -13,9 +12,11 @@ import ( "github.com/mark3labs/mcp-go/mcp" codemcp "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/mcp/utils" "github.com/maximhq/bifrost/core/schemas" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" + "go.starlark.net/syntax" ) // ExecutionResult represents the result of code execution @@ -52,8 +53,11 @@ type ExecutionEnvironment struct { func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { executeToolCodeProps := schemas.NewOrderedMapFromPairs( schemas.KV("code", map[string]interface{}{ - "type": "string", - "description": "Python code to execute. The code runs in a Starlark interpreter (Python subset). Tool calls are synchronous - no async/await needed. For loops/conditionals, wrap in a function. Use print() for logging. ALWAYS retry if code fails. Example: def main():\n items = server.list_items()\n for item in items:\n print(item)\nresult = main()", + "type": "string", + "description": "Python (Starlark) code to execute. Tool calls are synchronous: result = server.tool(param=\"value\"). " + + "Use print() for logging. Assign to 'result' variable to return a value. 
" + + "Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations. " + + "Example: items = server.list_items()\nfor item in items:\n print(item[\"name\"])\nresult = items", }), ) return schemas.ChatTool{ @@ -61,36 +65,36 @@ func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { Function: &schemas.ChatToolFunction{ Name: codemcp.ToolTypeExecuteToolCode, Description: schemas.Ptr( - "Executes Python code inside a sandboxed Starlark interpreter with access to all connected MCP servers' tools. " + - "All connected servers are exposed as global objects named after their configuration keys, and each server " + - "provides functions for every tool available on that server. The canonical usage pattern is: " + - "result = .(param=\"value\"). Both and should be discovered " + - "using listToolFiles and readToolFile. " + - - "IMPORTANT WORKFLOW: Always follow this order — first use listToolFiles to see available servers and tools, " + - "then use readToolFile to understand the tool definitions and their parameters, and finally use executeToolCode " + - "to execute your code. " + + "Executes Python code in a sandboxed Starlark interpreter with MCP server tool access. " + + "Servers are exposed as global objects: result = serverName.toolName(param=\"value\"). " + + "This is the final step of the four-tool code mode workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "If you have not already read a tool's .pyi stub in this conversation, do that before writing code. " + + "Do NOT guess callable tool names from natural language or stale assumptions; use the exact identifier returned by listToolFiles/readToolFile. " + + + "STARLARK DIFFERENCES FROM PYTHON — READ BEFORE WRITING CODE: " + + "1. 
NO try/except/finally/raise — error handling is not supported, and tool failures cannot be caught inside Starlark. " + + "2. NO classes — use dicts and functions. " + + "3. NO imports, direct network access, or direct filesystem access — use MCP tools instead. " + + "4. NO is operator — use == for comparison. " + + "5. NO f-strings — use % formatting: \"Hello %s, count=%d\" % (name, n). " + + "6. Each executeToolCode call runs in a FRESH ISOLATED SCOPE — no variables, functions, or state persist between calls. Re-fetch data or store it via MCP tools (e.g., SQLite, FileSystem) if needed across calls. " + "SYNTAX NOTES: " + - "• Tool calls are synchronous - NO async/await needed, just call directly: result = server.tool(arg=\"value\") " + + "• Synchronous calls — NO async/await: result = server.tool(arg=\"value\") " + "• Use keyword arguments: server.tool(param=\"value\") NOT server.tool({\"param\": \"value\"}) " + "• Access dict values with brackets: result[\"key\"] NOT result.key " + - "• Use print() for logging (not console.log) " + - "• List comprehensions work: [x for x in items if x[\"active\"]] " + - "• To return a value, assign to 'result' variable: result = computed_value " + - "• CRITICAL: for/if/while at top level MUST be inside a function - def main(): ... then result = main() " + - - "RETRY POLICY: ALWAYS retry if a code block fails. Analyze the error, adjust your code, and retry. " + - - "The environment is intentionally minimal: " + - "• No imports needed or supported " + - "• No network APIs (use MCP tools for external interactions) " + - "• No file system access (use MCP tools) " + - "• No classes (use dicts and functions) " + - "• Deterministic execution (no random, no time) " + - - "Long-running operations are interrupted via execution timeout. 
" + - "This tool is designed specifically for orchestrating MCP tool calls and lightweight computation.", + "• Use print() for logging/debugging " + + "• List comprehensions: [x for x in items if x[\"active\"]] " + + "• String escapes work normally: \"line1\\nline2\" produces a newline " + + "• Triple-quoted strings for multiline: \"\"\"multi\\nline\"\"\" " + + "• chr(10) for newline character, chr(9) for tab " + + "• To return a value, assign to 'result': result = computed_value " + + "• MCP tool calls are timeout-limited; avoid long or infinite loops " + + + "AVAILABLE BUILTINS: print, len, range, enumerate, zip, sorted, reversed, min, max, " + + "int, float, str, bool, list, dict, tuple, set, hasattr, getattr, type, chr, ord, any, all, hash, repr. " + + + "RETRY POLICY: Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations.", ), Parameters: &schemas.ToolFunctionParameters{ @@ -103,7 +107,7 @@ func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { } // handleExecuteToolCode handles the executeToolCode tool call. -func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { +func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { toolName := "unknown" if toolCall.Function.Name != nil { toolName = *toolCall.Function.Name @@ -197,16 +201,13 @@ func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall s } // executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings. 
-func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) ExecutionResult { +func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult { logs := []string{} s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix) - // Step 1: Convert literal \n escape sequences to actual newlines - codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") - - // Step 2: Handle empty code - trimmedCode := strings.TrimSpace(codeWithNewlines) + // Step 1: Handle empty code + trimmedCode := strings.TrimSpace(code) if trimmedCode == "" { return ExecutionResult{ Result: nil, @@ -218,7 +219,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } } - // Step 3: Build tool bindings for all connected servers + // Step 2: Build tool bindings for all connected servers availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) serverKeys := make([]string, 0, len(availableToolsPerClient)) predeclared := starlark.StringDict{} @@ -254,9 +255,8 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } originalToolName := tool.Function.Name - unprefixedToolName := stripClientPrefix(originalToolName, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - parsedToolName := parseToolName(unprefixedToolName) + parsedToolName := getCanonicalToolName(clientName, originalToolName) + compatibilityAlias := getCompatibilityToolAlias(clientName, originalToolName) s.logger.Debug("%s [%s] Binding tool: %s -> %s", codemcp.CodeModeLogPrefix, clientName, originalToolName, parsedToolName) @@ -298,6 +298,13 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi }) structMembers[parsedToolName] = toolFunc + + if compatibilityAlias != parsedToolName && isValidStarlarkIdentifier(compatibilityAlias) { + if _, exists := structMembers[compatibilityAlias]; !exists { + structMembers[compatibilityAlias] = toolFunc + 
s.logger.Debug("%s [%s] Added compatibility alias: %s -> %s", codemcp.CodeModeLogPrefix, clientName, compatibilityAlias, parsedToolName) + } + } } // Create a struct for this server @@ -312,7 +319,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi s.logger.Debug("%s No servers available for code mode execution", codemcp.CodeModeLogPrefix) } - // Step 4: Create Starlark thread with print function and timeout + // Step 3: Create Starlark thread with print function and timeout toolExecutionTimeout := s.getToolExecutionTimeout() timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) defer cancel() @@ -324,11 +331,26 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi }, } - // Set up cancellation check + // Set up cancellation check — watch the context and cancel the Starlark + // thread so that infinite loops and other long-running scripts are interrupted + // when the execution timeout fires. thread.SetLocal("context", timeoutCtx) + go func() { + <-timeoutCtx.Done() + thread.Cancel(timeoutCtx.Err().Error()) + }() + + // Step 4: Configure Starlark dialect options for a Python-like experience + starlarkOpts := &syntax.FileOptions{ + TopLevelControl: true, // allow if/for/while at top level (not just inside functions) + While: true, // enable while loops + Set: true, // enable set() builtin + GlobalReassign: true, // allow reassignment to top-level names + Recursion: true, // allow recursive functions + } // Step 5: Execute the code - globals, err := starlark.ExecFile(thread, "code.star", trimmedCode, predeclared) + globals, err := starlark.ExecFileOptions(starlarkOpts, thread, "code.star", trimmedCode, predeclared) if err != nil { errorMessage := err.Error() @@ -372,7 +394,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } // callMCPTool calls an MCP tool and returns the result. 
-func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { +func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { // Get available tools per client availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) @@ -400,29 +422,25 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName // Strip the client name prefix from tool name before calling MCP server originalToolName := stripClientPrefix(toolName, clientName) - // Get BifrostContext for plugin pipeline - var bifrostCtx *schemas.BifrostContext - var ok bool - if bifrostCtx, ok = ctx.(*schemas.BifrostContext); !ok { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string) + if !ok { + originalRequestID = "" } - originalRequestID, _ := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string) - // Generate new request ID for this nested tool call var newRequestID string if s.fetchNewRequestIDFunc != nil { - newRequestID = s.fetchNewRequestIDFunc(bifrostCtx) + newRequestID = s.fetchNewRequestIDFunc(ctx) } else { newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName) } // Create new child context - deadline, hasDeadline := bifrostCtx.Deadline() + deadline, hasDeadline := ctx.Deadline() if !hasDeadline { deadline = schemas.NoDeadline } - nestedCtx := schemas.NewBifrostContext(bifrostCtx, deadline) + nestedCtx := schemas.NewBifrostContext(ctx, deadline) nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID) if originalRequestID != "" { nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID) @@ -451,13 +469,17 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, 
clientName, toolName // Check if plugin pipeline is available if s.pluginPipelineProvider == nil { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + // Should never happen, but just in case + s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("plugin pipeline provider is nil") } // Get plugin pipeline and run hooks pipeline := s.pluginPipelineProvider() if pipeline == nil { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + // Should never happen, but just in case + s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("plugin pipeline is nil") } defer s.releasePluginPipeline(pipeline) @@ -515,14 +537,7 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName Name: toolNameToCall, Arguments: args, }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers + Header: utils.GetHeadersForToolExecution(nestedCtx, client), } toolExecutionTimeout := s.getToolExecutionTimeout() @@ -604,57 +619,3 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName return nil, fmt.Errorf("plugin post-hooks returned invalid response") } - -// callMCPToolDirect executes an MCP tool call directly without plugin hooks. 
-func (s *StarlarkCodeMode) callMCPToolDirect(ctx context.Context, client *schemas.MCPClientState, originalToolName, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsCall), - }, - Params: mcp.CallToolParams{ - Name: originalToolName, - Arguments: args, - }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers - } - - toolExecutionTimeout := s.getToolExecutionTimeout() - toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) - defer cancel() - - logToolName := stripClientPrefix(toolName, clientName) - logToolName = strings.ReplaceAll(logToolName, "-", "_") - - toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) - if callErr != nil { - s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, logToolName, callErr) - appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, logToolName, callErr)) - return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, logToolName, callErr) - } - - rawResult := extractTextFromMCPResponse(toolResponse, toolName) - - if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { - errorMsg := after - s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, logToolName, errorMsg) - appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, logToolName, errorMsg)) - return nil, fmt.Errorf("%s", errorMsg) - } - - var finalResult interface{} - if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { - finalResult = rawResult - } - - resultStr := formatResultForLog(finalResult) - appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr)) - - return finalResult, nil -} 
diff --git a/core/mcp/codemode/starlark/getdocs.go b/core/mcp/codemode/starlark/getdocs.go index 61e4b1dc86..ea622bf7cf 100644 --- a/core/mcp/codemode/starlark/getdocs.go +++ b/core/mcp/codemode/starlark/getdocs.go @@ -71,8 +71,6 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem var matchedTool *schemas.ChatTool serverNameLower := strings.ToLower(serverName) - toolNameLower := strings.ToLower(toolName) - for clientName, tools := range availableToolsPerClient { client := s.clientManager.GetClientByName(clientName) if client == nil { @@ -90,10 +88,7 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem // Find the specific tool for i, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for comparison - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - if strings.ToLower(unprefixedToolName) == toolNameLower { + if matchesToolReference(toolName, clientName, tool.Function.Name) { matchedTool = &tools[i] break } @@ -125,9 +120,7 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem var availableTools []string for _, tool := range tools { if tool.Function != nil { - unprefixedToolName := stripClientPrefix(tool.Function.Name, matchedClientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableTools = append(availableTools, unprefixedToolName) + availableTools = append(availableTools, getCanonicalToolName(matchedClientName, tool.Function.Name)) } } errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. 
Available tools are:\n", toolName, matchedClientName) @@ -150,7 +143,7 @@ func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isTool // Write comprehensive header sb.WriteString("# ============================================================================\n") if isToolLevel && len(tools) == 1 && tools[0].Function != nil { - sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, tools[0].Function.Name)) + sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, getCanonicalToolName(clientName, tools[0].Function.Name))) } else { sb.WriteString(fmt.Sprintf("# Documentation for %s MCP server\n", clientName)) } @@ -187,9 +180,7 @@ } originalToolName := tool.Function.Name - unprefixedToolName := stripClientPrefix(originalToolName, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - toolName := parseToolName(unprefixedToolName) + toolName := getCanonicalToolName(clientName, originalToolName) description := "" if tool.Function.Description != nil { description = *tool.Function.Description diff --git a/core/mcp/codemode/starlark/listfiles.go b/core/mcp/codemode/starlark/listfiles.go index caff015194..4d6aa73add 100644 --- a/core/mcp/codemode/starlark/listfiles.go +++ b/core/mcp/codemode/starlark/listfiles.go @@ -21,7 +21,8 @@ func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool { if bindingLevel == schemas.CodeModeBindingLevelServer { description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers. " + "Each server has a corresponding file (e.g., servers/<serverName>.pyi) that contains compact Python signatures for all tools in that server. " + - "Use readToolFile to read a specific server file and see all available tools with their signatures. " + + "Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. 
" + + "Use readToolFile before executeToolCode to read a specific server file and confirm exact callable tool names and parameters. " + "Use getToolDocs if you need detailed documentation for a specific tool. " + "In code, access tools via: server_name.tool_name(param=value). " + "The server names used in code correspond to the human-readable names shown in this listing. " + @@ -30,7 +31,9 @@ func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool { } else { description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers, organized by individual tool. " + "Each tool has a corresponding file (e.g., servers//.pyi) that contains compact Python signatures for that specific tool. " + - "Use readToolFile to read a specific tool file and see its signature. " + + "The shown in each filename is the exact canonical identifier exposed in executeToolCode. " + + "Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "Use readToolFile before executeToolCode to confirm the exact signature and parameters for the tool you want to call. " + "Use getToolDocs if you need detailed documentation for a specific tool. " + "In code, access tools via: server_name.tool_name(param=value). " + "The server names used in code correspond to the human-readable names shown in this listing. 
" + @@ -88,12 +91,7 @@ func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall sch // Tool-level: one file per tool for _, tool := range tools { if tool.Function != nil && tool.Function.Name != "" { - // Strip the client prefix from tool name (format: "client-toolname" -> "toolname") - // But replace - with _ for valid Python identifiers - toolName := stripClientPrefix(tool.Function.Name, clientName) - // Replace any remaining hyphens with underscores for Python compatibility - toolName = strings.ReplaceAll(toolName, "-", "_") - // Validate normalized tool name to prevent path traversal + toolName := getCanonicalToolName(clientName, tool.Function.Name) if err := validateNormalizedToolName(toolName); err != nil { s.logger.Warn("%s Skipping tool '%s' from client '%s': %v", codemcp.CodeModeLogPrefix, tool.Function.Name, clientName, err) continue @@ -112,10 +110,32 @@ func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall sch } // Build tree structure from file list - responseText := buildVFSTree(files) + responseText := buildListToolFilesResponse(files, bindingLevel) return createToolResponseMessage(toolCall, responseText), nil } +func buildListToolFilesResponse(files []string, bindingLevel schemas.CodeModeBindingLevel) string { + tree := buildVFSTree(files) + if tree == "" { + return "" + } + + header := []string{ + "# Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode", + } + + if bindingLevel == schemas.CodeModeBindingLevelServer { + header = append(header, "# Read the server .pyi file before executeToolCode to confirm exact tool names and parameters.") + } else { + header = append(header, + "# Filenames below use the exact canonical tool identifiers available in executeToolCode.", + "# Still call readToolFile before executeToolCode to confirm parameters and return shape.", + ) + } + + return strings.Join(append(header, "", tree), "\n") +} + // VFS tree node structure for building 
hierarchical file structure type treeNode struct { isDirectory bool diff --git a/core/mcp/codemode/starlark/readfile.go b/core/mcp/codemode/starlark/readfile.go index 5940e0c9fc..41063ad065 100644 --- a/core/mcp/codemode/starlark/readfile.go +++ b/core/mcp/codemode/starlark/readfile.go @@ -21,21 +21,23 @@ func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool { var fileNameDescription, toolDescription string if bindingLevel == schemas.CodeModeBindingLevelServer { - fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>.pyi (e.g., 'calculator.pyi')" + fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>.pyi (e.g., 'servers/calculator.pyi')" toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " + "for all tools available on that server. The fileName should be in format servers/<serverName>.pyi as listed by listToolFiles. " + "The function performs case-insensitive matching and removes the .pyi extension. " + + "This is the authoritative source for the exact callable tool names and parameters to use in executeToolCode. " + "Each tool can be accessed in code via: serverName.tool_name(param=value). " + "If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + "do NOT call this tool again with startLine/endLine - you already have the complete file."
} else { - fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>/<toolName>.pyi (e.g., 'calculator/add.pyi')" + fileNameDescription = "The virtual filename from listToolFiles in format: servers/<serverName>/<toolName>.pyi (e.g., 'servers/calculator/add.pyi')" toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " + "The fileName should be in format servers/<serverName>/<toolName>.pyi as listed by listToolFiles. " + "The function performs case-insensitive matching and removes the .pyi extension. " + - "The tool can be accessed in code via: serverName.tool_name(param=value). " + + "This is the authoritative source for the exact callable tool name and arguments to use in executeToolCode. " + + "The tool can be accessed in code via: serverName.tool_name(param=value) using the def name shown in the file. " + "If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. 
" + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + @@ -126,13 +128,9 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche if isToolLevel { // Tool-level: filter to specific tool var foundTool *schemas.ChatTool - toolNameLower := strings.ToLower(toolName) for i, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for comparison - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - if strings.ToLower(unprefixedToolName) == toolNameLower { + if matchesToolReference(toolName, clientName, tool.Function.Name) { foundTool = &tools[i] break } @@ -143,15 +141,12 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche availableTools := make([]string, 0) for _, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for display - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableTools = append(availableTools, unprefixedToolName) + availableTools = append(availableTools, getCanonicalToolName(clientName, tool.Function.Name)) } } errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. 
Available tools in this server are:\n", toolName, clientName) for _, t := range availableTools { - errorMsg += fmt.Sprintf(" - %s/%s.pyi\n", clientName, t) + errorMsg += fmt.Sprintf(" - servers/%s/%s.pyi\n", clientName, t) } return createToolResponseMessage(toolCall, errorMsg), nil } @@ -171,17 +166,14 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche for name := range availableToolsPerClient { if bindingLevel == schemas.CodeModeBindingLevelServer { - availableFiles = append(availableFiles, fmt.Sprintf("%s.pyi", name)) + availableFiles = append(availableFiles, fmt.Sprintf("servers/%s.pyi", name)) } else { client := s.clientManager.GetClientByName(name) if client != nil && client.ExecutionConfig.IsCodeModeClient { if tools, ok := availableToolsPerClient[name]; ok { for _, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for display - unprefixedToolName := stripClientPrefix(tool.Function.Name, name) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableFiles = append(availableFiles, fmt.Sprintf("%s/%s.pyi", name, unprefixedToolName)) + availableFiles = append(availableFiles, fmt.Sprintf("servers/%s/%s.pyi", name, getCanonicalToolName(name, tool.Function.Name))) } } } @@ -295,12 +287,14 @@ func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isTo // Minimal header if isToolLevel && len(tools) == 1 && tools[0].Function != nil { - toolName := parseToolName(stripClientPrefix(tools[0].Function.Name, clientName)) + toolName := getCanonicalToolName(clientName, tools[0].Function.Name) sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName)) } else { sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName)) } sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName)) + sb.WriteString("# The def names below are the exact callable names to use in executeToolCode.\n") + sb.WriteString("# Read this file before 
executeToolCode to confirm parameters and return shape.\n") sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName)) sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n") @@ -309,10 +303,7 @@ func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isTo continue } - // Strip client prefix and replace - with _ for code mode compatibility - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - toolName := parseToolName(unprefixedToolName) + toolName := getCanonicalToolName(clientName, tool.Function.Name) // Format inline parameters in Python style params := formatPythonParams(tool.Function.Parameters) diff --git a/core/mcp/codemode/starlark/starlark.go b/core/mcp/codemode/starlark/starlark.go index 0da1d2ccd9..348655b983 100644 --- a/core/mcp/codemode/starlark/starlark.go +++ b/core/mcp/codemode/starlark/starlark.go @@ -6,7 +6,6 @@ package starlark import ( - "context" "fmt" "sync" "sync/atomic" @@ -111,7 +110,7 @@ func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool { // Returns: // - *schemas.ChatMessage: The tool response message // - error: Any error that occurred during execution -func (s *StarlarkCodeMode) ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { +func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { if toolCall.Function.Name == nil { return nil, fmt.Errorf("tool call missing function name") } diff --git a/core/mcp/codemode/starlark/starlark_test.go b/core/mcp/codemode/starlark/starlark_test.go index dba557f88a..a48e77a887 100644 --- a/core/mcp/codemode/starlark/starlark_test.go +++ b/core/mcp/codemode/starlark/starlark_test.go @@ -3,13 +3,42 @@ package starlark import ( + "context" + 
"strings" "testing" + "time" "github.com/bytedance/sonic" + codemcp "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/schemas" "go.starlark.net/starlark" + "go.starlark.net/syntax" ) +type testClientManager struct { + clients map[string]*schemas.MCPClientState + tools map[string][]schemas.ChatTool +} + +func (m *testClientManager) GetClientForTool(toolName string) *schemas.MCPClientState { + for clientName, tools := range m.tools { + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name == toolName { + return m.clients[clientName] + } + } + } + return nil +} + +func (m *testClientManager) GetClientByName(clientName string) *schemas.MCPClientState { + return m.clients[clientName] +} + +func (m *testClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + return m.tools +} + func TestStarlarkToGo(t *testing.T) { t.Run("Convert None", func(t *testing.T) { result := starlarkToGo(starlark.None) @@ -151,6 +180,83 @@ func TestGoToStarlark(t *testing.T) { }) } +func TestGetCanonicalToolName(t *testing.T) { + if got := getCanonicalToolName("github", "github-SEARCH_REPOS"); got != "search_repos" { + t.Fatalf("expected canonical tool name search_repos, got %q", got) + } + + if got := getCanonicalToolName("math", "math-123Add!"); got != "_123add" { + t.Fatalf("expected canonical tool name _123add, got %q", got) + } +} + +func TestMatchesToolReferenceSupportsCanonicalAndLegacyNames(t *testing.T) { + clientName := "github" + originalToolName := "github-SEARCH_REPOS" + + testCases := []string{ + "search_repos", + "SEARCH_REPOS", + } + + for _, toolRef := range testCases { + if !matchesToolReference(toolRef, clientName, originalToolName) { + t.Fatalf("expected %q to match %q", toolRef, originalToolName) + } + } +} + +func TestHandleListToolFilesUsesCanonicalToolIdentifiers(t *testing.T) { + mode := NewStarlarkCodeMode(&codemcp.CodeModeConfig{ + BindingLevel: schemas.CodeModeBindingLevelTool, + 
ToolExecutionTimeout: time.Second, + }, nil) + + clientName := "github" + mode.clientManager = &testClientManager{ + clients: map[string]*schemas.MCPClientState{ + clientName: { + Name: clientName, + ExecutionConfig: &schemas.MCPClientConfig{ + Name: clientName, + IsCodeModeClient: true, + }, + }, + }, + tools: map[string][]schemas.ChatTool{ + clientName: { + { + Function: &schemas.ChatToolFunction{ + Name: "github-SEARCH_REPOS", + }, + }, + }, + }, + } + + msg, err := mode.handleListToolFiles(context.Background(), schemas.ChatAssistantMessageToolCall{ + ID: schemas.Ptr("tool-call-1"), + }) + if err != nil { + t.Fatalf("handleListToolFiles returned error: %v", err) + } + + if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil { + t.Fatal("expected tool response content") + } + + content := *msg.Content.ContentStr + if !strings.Contains(content, "search_repos.pyi") { + t.Fatalf("expected canonical tool file path in response, got:\n%s", content) + } + if strings.Contains(content, "SEARCH_REPOS.pyi") { + t.Fatalf("did not expect raw uppercase tool file path in response, got:\n%s", content) + } + if !strings.Contains(content, "readToolFile before executeToolCode") { + t.Fatalf("expected workflow guidance in response, got:\n%s", content) + } +} + func TestGeneratePythonErrorHints(t *testing.T) { serverKeys := []string{"calculator", "weather"} @@ -161,13 +267,13 @@ func TestGeneratePythonErrorHints(t *testing.T) { } found := false for _, hint := range hints { - if containsAny(hint, "not defined", "undefined") { + if strings.Contains(hint, "Variable 'foo' is not defined.") { found = true break } } if !found { - t.Error("Expected hint about undefined variable") + t.Errorf("Expected exact undefined variable hint for foo, got: %v", hints) } }) @@ -489,3 +595,405 @@ func TestFormatResultForLog(t *testing.T) { } }) } + +// starlarkOpts returns the FileOptions used by the code mode executor. 
+// Kept in sync with executecode.go to test the same dialect configuration. +func starlarkOpts() *syntax.FileOptions { + return &syntax.FileOptions{ + TopLevelControl: true, + While: true, + Set: true, + GlobalReassign: true, + Recursion: true, + } +} + +// execStarlark is a test helper that executes Starlark code with our dialect options +// and returns the globals and any error. +func execStarlark(code string) (starlark.StringDict, error) { + thread := &starlark.Thread{Name: "test"} + return starlark.ExecFileOptions(starlarkOpts(), thread, "test.star", code, nil) +} + +func TestStarlarkDialectOptions(t *testing.T) { + t.Run("Top-level for loop", func(t *testing.T) { + code := ` +items = [] +for i in range(3): + items.append(i) +result = items +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level for loop should work, got error: %v", err) + } + resultVal := globals["result"] + list, ok := resultVal.(*starlark.List) + if !ok { + t.Fatalf("Expected list, got %T", resultVal) + } + if list.Len() != 3 { + t.Errorf("Expected 3 items, got %d", list.Len()) + } + }) + + t.Run("Top-level if statement", func(t *testing.T) { + code := ` +x = 10 +if x > 5: + result = "big" +else: + result = "small" +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level if should work, got error: %v", err) + } + if globals["result"] != starlark.String("big") { + t.Errorf("Expected 'big', got %v", globals["result"]) + } + }) + + t.Run("Top-level while loop", func(t *testing.T) { + code := ` +count = 0 +while count < 5: + count += 1 +result = count +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level while loop should work, got error: %v", err) + } + resultVal := globals["result"] + if resultVal.String() != "5" { + t.Errorf("Expected 5, got %v", resultVal) + } + }) + + t.Run("While loop inside function", func(t *testing.T) { + code := ` +def countdown(n): + items = [] + while n > 0: + items.append(n) + n -= 1 + return 
items +result = countdown(3) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("While in function should work, got error: %v", err) + } + list := globals["result"].(*starlark.List) + if list.Len() != 3 { + t.Errorf("Expected 3 items, got %d", list.Len()) + } + }) + + t.Run("set() builtin", func(t *testing.T) { + code := ` +s = set([1, 2, 3, 2, 1]) +result = len(s) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("set() should work, got error: %v", err) + } + if globals["result"].String() != "3" { + t.Errorf("Expected 3 unique items, got %v", globals["result"]) + } + }) + + t.Run("Global variable reassignment", func(t *testing.T) { + code := ` +x = 1 +x = x + 1 +x = x * 3 +result = x +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Global reassignment should work, got error: %v", err) + } + if globals["result"].String() != "6" { + t.Errorf("Expected 6, got %v", globals["result"]) + } + }) + + t.Run("Recursive function", func(t *testing.T) { + code := ` +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n - 1) +result = factorial(5) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Recursion should work, got error: %v", err) + } + if globals["result"].String() != "120" { + t.Errorf("Expected 120, got %v", globals["result"]) + } + }) +} + +func TestStarlarkStringEscapePreservation(t *testing.T) { + t.Run("Backslash-n in string literal preserved", func(t *testing.T) { + // Simulate what happens after JSON deserialization: + // Model writes: {"code": "msg = \"hello\\nworld\""} + // sonic.Unmarshal produces: msg = "hello\nworld" (where \n is two chars: \ + n) + // Starlark should interpret \n as newline escape inside the string + code := "msg = \"hello\\nworld\"\nresult = msg" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("String with \\n escape should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr 
!= "hello\nworld" { + t.Errorf("Expected 'helloworld', got %q", resultStr) + } + }) + + t.Run("Multiple escape sequences in strings", func(t *testing.T) { + code := "msg = \"col1\\tcol2\\nrow1\\trow2\"\nresult = msg" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("String with multiple escapes should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "col1\tcol2\nrow1\trow2" { + t.Errorf("Expected tab/newline escapes, got %q", resultStr) + } + }) + + t.Run("Newline join pattern", func(t *testing.T) { + // This is the exact pattern that failed 7 times in benchmarks + code := ` +def main(): + lines = ["line1", "line2", "line3"] + content = "\n".join(lines) + return content +result = main() +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Newline join pattern should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "line1\nline2\nline3" { + t.Errorf("Expected joined lines, got %q", resultStr) + } + }) + + t.Run("chr() for newline", func(t *testing.T) { + code := ` +nl = chr(10) +result = "hello" + nl + "world" +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("chr(10) should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "hello\nworld" { + t.Errorf("Expected 'helloworld', got %q", resultStr) + } + }) + + t.Run("Triple-quoted strings", func(t *testing.T) { + code := "result = \"\"\"line1\nline2\nline3\"\"\"" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Triple-quoted string should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "line1\nline2\nline3" { + t.Errorf("Expected multiline string, got %q", resultStr) + } + }) + + t.Run("Raw string preserves backslash", func(t *testing.T) { + code := "result = r\"hello\\nworld\"" + + globals, err := 
execStarlark(code) + if err != nil { + t.Fatalf("Raw string should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + // Raw string: \n stays as two characters \ and n + if resultStr != "hello\\nworld" { + t.Errorf("Expected literal backslash-n, got %q", resultStr) + } + }) + + t.Run("JSON deserialization then Starlark execution", func(t *testing.T) { + // End-to-end: simulate the exact flow from model JSON → sonic.Unmarshal → Starlark + jsonArgs := `{"code": "lines = [\"a\", \"b\", \"c\"]\nresult = \"\\n\".join(lines)"}` + + var arguments map[string]interface{} + err := sonic.Unmarshal([]byte(jsonArgs), &arguments) + if err != nil { + t.Fatalf("JSON unmarshal failed: %v", err) + } + + code := arguments["code"].(string) + + globals, starlarkErr := execStarlark(code) + if starlarkErr != nil { + t.Fatalf("Starlark execution failed: %v", starlarkErr) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "a\nb\nc" { + t.Errorf("Expected 'a\\nb\\nc', got %q", resultStr) + } + }) +} + +func TestStarlarkUnsupportedFeatures(t *testing.T) { + t.Run("try/except rejected", func(t *testing.T) { + code := ` +def main(): + try: + x = 1 + except: + x = 0 +result = main() +` + _, err := execStarlark(code) + if err == nil { + t.Fatal("try/except should be rejected by Starlark") + } + if !strings.Contains(err.Error(), "got try") { + t.Errorf("Expected 'got try' in error, got: %v", err) + } + }) + + t.Run("raise rejected", func(t *testing.T) { + code := `raise ValueError("test")` + + _, err := execStarlark(code) + if err == nil { + t.Fatal("raise should be rejected by Starlark") + } + }) + + t.Run("class rejected", func(t *testing.T) { + code := ` +class Foo: + pass +` + _, err := execStarlark(code) + if err == nil { + t.Fatal("class should be rejected by Starlark") + } + }) + + t.Run("import rejected", func(t *testing.T) { + code := `import json` + + _, err := execStarlark(code) + if err == nil { + t.Fatal("import should 
be rejected by Starlark") + } + }) +} + +func TestGeneratePythonErrorHintsNewCases(t *testing.T) { + serverKeys := []string{"Github", "SqLite"} + + t.Run("try/except hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:3:9: got try, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for try/except error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about try/except not being supported, got: %v", hints) + } + }) + + t.Run("except hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:5:9: got except, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for except error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("finally hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:7:9: got finally, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for finally error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("raise hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:2:1: got raise, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for raise error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("Undefined variable 
includes scope hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:3:17: undefined: commits_n8n", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for undefined variable") + } + foundVar := false + foundScope := false + for _, hint := range hints { + if strings.Contains(hint, "Variable 'commits_n8n' is not defined.") { + foundVar = true + } + if containsAny(hint, "fresh scope", "persist") { + foundScope = true + } + } + if !foundVar { + t.Errorf("Expected exact undefined variable hint for commits_n8n, got: %v", hints) + } + if !foundScope { + t.Errorf("Expected scope persistence hint, got: %v", hints) + } + }) +} diff --git a/core/mcp/codemode/starlark/utils.go b/core/mcp/codemode/starlark/utils.go index 5b7c9ab920..aea6e732d3 100644 --- a/core/mcp/codemode/starlark/utils.go +++ b/core/mcp/codemode/starlark/utils.go @@ -191,11 +191,25 @@ func formatResultForLog(result interface{}) string { func generatePythonErrorHints(errorMessage string, serverKeys []string) []string { hints := []string{} - if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") { - re := regexp.MustCompile(`(\w+).*(?:undefined|not defined)`) - if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { - undefinedVar := match[1] + if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") || + strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") { + hints = append(hints, "Starlark does NOT support try/except/finally/raise — there is no exception handling.") + hints = append(hints, "Instead, check return values for errors:") + hints = append(hints, " result = server.tool(param=\"value\")") + hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):") + hints = append(hints, " print(\"Error:\", result)") + } else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") 
{ + var undefinedVar string + if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } + if undefinedVar != "" { hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar)) + hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope — no variables persist between calls.") if len(serverKeys) > 0 { hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")") @@ -298,7 +312,7 @@ func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, re } } -// parseToolName parses the tool name to be JavaScript-compatible. +// parseToolName normalizes a raw tool name into a Starlark-compatible identifier. func parseToolName(toolName string) string { if toolName == "" { return "" @@ -349,6 +363,61 @@ func parseToolName(toolName string) string { return parsed } +// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark. +func getCanonicalToolName(clientName, originalToolName string) string { + return parseToolName(stripClientPrefix(originalToolName, clientName)) +} + +// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name. +// This is used as a compatibility alias when the raw name is still a valid Starlark identifier. 
+func getCompatibilityToolAlias(clientName, originalToolName string) string { + return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_") +} + +// matchesToolReference reports whether the requested tool name matches any supported identifier form. +// We accept the canonical callable name plus legacy display forms for backward compatibility. +func matchesToolReference(requestedToolName, clientName, originalToolName string) bool { + requested := strings.ToLower(requestedToolName) + if requested == "" { + return false + } + + candidates := []string{ + getCanonicalToolName(clientName, originalToolName), + getCompatibilityToolAlias(clientName, originalToolName), + stripClientPrefix(originalToolName, clientName), + } + + for _, candidate := range candidates { + if candidate != "" && requested == strings.ToLower(candidate) { + return true + } + } + + return false +} + +// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code. +func isValidStarlarkIdentifier(name string) bool { + if name == "" { + return false + } + + runes := []rune(name) + first := runes[0] + if !unicode.IsLetter(first) && first != '_' && first != '$' { + return false + } + + for _, r := range runes[1:] { + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' { + return false + } + } + + return true +} + // validateNormalizedToolName validates a normalized tool name to prevent path traversal. 
func validateNormalizedToolName(normalizedName string) error { if normalizedName == "" { diff --git a/core/mcp/healthmonitor.go b/core/mcp/healthmonitor.go index aa6595fe7a..85769afdcb 100644 --- a/core/mcp/healthmonitor.go +++ b/core/mcp/healthmonitor.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" "github.com/maximhq/bifrost/core/schemas" ) @@ -140,9 +141,14 @@ func (chm *ClientHealthMonitor) performHealthCheck() { } chm.mu.Unlock() - // Get the client connection + // Get the client connection — capture Conn while holding the lock so we + // don't race with removeClientUnsafe zeroing it under the write lock. chm.manager.mu.RLock() clientState, exists := chm.manager.clientMap[chm.clientID] + var conn *client.Client + if exists && clientState != nil { + conn = clientState.Conn + } chm.manager.mu.RUnlock() if !exists { @@ -151,7 +157,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { } var err error - if clientState.Conn == nil { + if conn == nil { // No active connection — treat as a health check failure err = fmt.Errorf("no active connection") } else { @@ -160,7 +166,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { defer cancel() if chm.isPingAvailable { - err = clientState.Conn.Ping(ctx) + err = conn.Ping(ctx) } else { listRequest := mcp.ListToolsRequest{ PaginatedRequest: mcp.PaginatedRequest{ @@ -169,7 +175,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { }, }, } - _, err = clientState.Conn.ListTools(ctx, listRequest) + _, err = conn.ListTools(ctx, listRequest) } } diff --git a/core/mcp/interface.go b/core/mcp/interface.go index 93617ce511..3069e2692a 100644 --- a/core/mcp/interface.go +++ b/core/mcp/interface.go @@ -14,15 +14,17 @@ import ( type MCPManagerInterface interface { // Tool Operations // AddToolsToRequest parses available MCP tools and adds them to the request - AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest 
+ AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest // GetAvailableTools returns all available MCP tools for the given context - GetAvailableTools(ctx context.Context) []schemas.ChatTool + GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool // ExecuteToolCall executes a single tool call and returns the result ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) - // UpdateToolManagerConfig updates the configuration for the tool manager + // UpdateToolManagerConfig updates the configuration for the tool manager. + // DisableAutoToolInject in the config controls auto injection — pass the + // current value whenever only other fields change so it is never silently reset. UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) // Agent Mode Operations @@ -60,6 +62,14 @@ type MCPManagerInterface interface { // ReconnectClient reconnects an MCP client by ID ReconnectClient(id string) error + // VerifyPerUserOAuthConnection creates a temporary MCP connection using a + // test access token to verify connectivity and discover tools. The connection + // is closed after verification. + VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) + + // SetClientTools updates the tool map and name mapping for an existing client. 
+ SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) + // Tool Registration // RegisterTool registers a local tool with the MCP server RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go index b86409ef1a..adcf3158ac 100644 --- a/core/mcp/mcp.go +++ b/core/mcp/mcp.go @@ -21,13 +21,6 @@ const ( BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment - - // Context keys for client filtering in requests - // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). - // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. - // Request context filtering takes priority over client config - context can override client exclusions. 
- MCPContextKeyIncludeClients schemas.BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering - MCPContextKeyIncludeTools schemas.BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) ) // ============================================================================ @@ -110,7 +103,7 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider } } - manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline, logger) + manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline, oauth2Provider, logger) // Set up CodeMode if provided - inject dependencies after manager is created if codeMode != nil { @@ -149,7 +142,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider manager.clientMap[clientConfig.ID].State = schemas.MCPConnectionStateDisconnected } manager.mu.Unlock() - monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, clientConfig.IsPingAvailable, manager.logger) + isPingAvailable := true + if clientConfig.IsPingAvailable != nil { + isPingAvailable = *clientConfig.IsPingAvailable + } + monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, isPingAvailable, manager.logger) manager.healthMonitorManager.StartMonitoring(monitor) } }(clientConfig) @@ -160,6 +157,13 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider return manager } +// SetPluginPipeline updates the plugin pipeline provider and release function on the manager's +// ToolsManager and CodeMode. 
Call this after attaching an externally-created MCPManager to a Bifrost +// instance so that nested tool calls in code mode can run through Bifrost's plugin hooks. +func (manager *MCPManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) { + manager.toolsManager.SetPluginPipeline(provider, release) +} + // AddToolsToRequest parses available MCP tools from the context and adds them to the request. // It respects context-based filtering for clients and tools, and returns the modified request // with tools attached. @@ -170,11 +174,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider // // Returns: // - *schemas.BifrostRequest: The request with tools added -func (m *MCPManager) AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { +func (m *MCPManager) AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest { return m.toolsManager.ParseAndAddToolsToRequest(ctx, req) } -func (m *MCPManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { +func (m *MCPManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool { return m.toolsManager.GetAvailableTools(ctx) } diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go index 8ba68e65ab..0527b1d46c 100644 --- a/core/mcp/toolmanager.go +++ b/core/mcp/toolmanager.go @@ -5,13 +5,17 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strings" "sync/atomic" "time" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/mcp/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -31,11 +35,15 @@ type PluginPipeline interface { // ToolsManager manages MCP tool execution and agent mode. 
type ToolsManager struct { - toolExecutionTimeout atomic.Value - maxAgentDepth atomic.Int32 - clientManager ClientManager - logger schemas.Logger - agentModeExecutor *AgentModeExecutor + toolExecutionTimeout atomic.Value + maxAgentDepth atomic.Int32 + disableAutoToolInject atomic.Bool + clientManager ClientManager + logger schemas.Logger + agentModeExecutor *AgentModeExecutor + + // OAuth2Provider for per-user OAuth token management + oauth2Provider schemas.OAuth2Provider // CodeMode implementation for code execution (Starlark by default) codeMode CodeMode @@ -73,6 +81,7 @@ func NewToolsManager( fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, pluginPipelineProvider func() PluginPipeline, releasePluginPipeline func(pipeline PluginPipeline), + oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, ) *ToolsManager { return NewToolsManagerWithCodeMode( @@ -82,6 +91,7 @@ func NewToolsManager( pluginPipelineProvider, releasePluginPipeline, nil, // Use default code mode (will be set later via SetCodeMode) + oauth2Provider, logger, ) } @@ -106,6 +116,7 @@ func NewToolsManagerWithCodeMode( pluginPipelineProvider func() PluginPipeline, releasePluginPipeline func(pipeline PluginPipeline), codeMode CodeMode, + oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, ) *ToolsManager { if config == nil { @@ -142,11 +153,13 @@ func NewToolsManagerWithCodeMode( codeMode: codeMode, logger: logger, agentModeExecutor: agentModeExecutor, + oauth2Provider: oauth2Provider, } // Initialize atomic values manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout) manager.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + manager.disableAutoToolInject.Store(config.DisableAutoToolInject) manager.logger.Info("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel) return manager @@ -174,8 +187,20 @@ func (m 
*ToolsManager) GetCodeModeDependencies() *CodeModeDependencies { } } +// SetPluginPipeline updates the plugin pipeline provider and release function +// on both the ToolsManager and its CodeMode implementation. +// This is used when an externally-created MCPManager is attached to a Bifrost instance +// via SetMCPManager, so the CodeMode can route nested tool calls through Bifrost's plugin hooks. +func (m *ToolsManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) { + m.pluginPipelineProvider = provider + m.releasePluginPipeline = release + if m.codeMode != nil { + m.codeMode.SetDependencies(m.GetCodeModeDependencies()) + } +} + // GetAvailableTools returns the available tools for the given context. -func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { +func (m *ToolsManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool { availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) // Flatten tools from all clients into a single slice, avoiding duplicates var availableTools []schemas.ChatTool @@ -191,14 +216,14 @@ func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool } if client.ExecutionConfig.IsCodeModeClient { includeCodeModeTools = true - } else { - // Add tools from this client, checking for duplicates - for _, tool := range clientTools { - if tool.Function != nil && tool.Function.Name != "" { - if !seenToolNames[tool.Function.Name] { - availableTools = append(availableTools, tool) - seenToolNames[tool.Function.Name] = true - } + } + // Add tools from this client, checking for duplicates + for _, tool := range clientTools { + if tool.Function != nil && tool.Function.Name != "" && !seenToolNames[tool.Function.Name] { + seenToolNames[tool.Function.Name] = true + schemas.AppendToContextList(ctx, schemas.BifrostContextKeyMCPAddedTools, tool.Function.Name) + if !client.ExecutionConfig.IsCodeModeClient { + availableTools = append(availableTools, tool) 
} } } @@ -288,12 +313,22 @@ func buildIntegrationDuplicateCheckMap(existingTools []schemas.ChatTool, integra // // Returns: // - *schemas.BifrostRequest: Bifrost request with MCP tools added -func (m *ToolsManager) ParseAndAddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { +func (m *ToolsManager) ParseAndAddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest { // MCP is only supported for chat and responses requests if req.ChatRequest == nil && req.ResponsesRequest == nil { return req } + // When auto tool injection is disabled, only inject tools if the request + // has explicit context filters set (e.g. via x-bf-mcp-include-tools header). + if m.disableAutoToolInject.Load() { + includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools) + includeClients := ctx.Value(schemas.MCPContextKeyIncludeClients) + if includeTools == nil && includeClients == nil { + return req + } + } + availableTools := m.GetAvailableTools(ctx) if len(availableTools) == 0 { @@ -541,9 +576,90 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall Name: originalMCPToolName, Arguments: arguments, }, + Header: utils.GetHeadersForToolExecution(ctx, client), } - if client.ExecutionConfig.Headers != nil { + // Handle per-user OAuth: inject user-specific Authorization header + if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth { + if m.oauth2Provider == nil { + return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured") + } + virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string) + userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string) + sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string) + + // Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key + if mcpUserID, _ := 
ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" { + userID = mcpUserID + } + + // Try identity-based token lookup first (works even without session token) + accessToken, err := m.oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID) + if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) { + // Token lookup failed with a real error (not merely "not found"), so surface it + return nil, "", "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err) + } + if err != nil { + // No token found — user hasn't authenticated with this MCP server yet. + // In LLM gateway mode with no identity, we can't track who this user is, + // so an OAuth flow would produce an orphaned token. Return a clear error instead. + isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool) + if !isMCPGateway && userID == "" && virtualKeyID == "" { + return nil, "", "", fmt.Errorf( + "per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you", + client.ExecutionConfig.Name, + ) + } + + // Initiate OAuth flow to get a proper authorize URL with session tracking.
+ if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" { + return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name) + } + redirectURI := buildRedirectURIFromContext(ctx) + if redirectURI == "" { + return nil, "", "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context") + } + flowInitiation, sessionID, flowErr := m.oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI) + if flowErr != nil { + return nil, "", "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr) + } + return nil, "", "", &schemas.MCPUserOAuthRequiredError{ + MCPClientID: client.ExecutionConfig.ID, + MCPClientName: client.ExecutionConfig.Name, + AuthorizeURL: flowInitiation.AuthorizeURL, + SessionID: sessionID, + Message: fmt.Sprintf("Authentication required for %s. 
Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name), + } + } + + if client.Conn == nil { + // No persistent connection — create temporary connection with user's token + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := executeToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger) + if callErr != nil { + if toolCtx.Err() == context.DeadlineExceeded { + return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) + } + m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) + return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr) + } + responseText := extractTextFromMCPResponse(toolResponse, toolName) + return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil + } + + // Persistent connection exists — use per-call headers + headers := make(http.Header) + if client.ExecutionConfig.Headers != nil { + for key, value := range client.ExecutionConfig.Headers { + headers.Add(key, value.GetValue()) + } + } + headers.Set("Authorization", "Bearer "+accessToken) + callRequest.Header = headers + } else if client.ExecutionConfig.Headers != nil { headers := make(http.Header) for key, value := range client.ExecutionConfig.Headers { headers.Add(key, value.GetValue()) @@ -660,17 +776,99 @@ func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) { m.maxAgentDepth.Store(int32(config.MaxAgentDepth)) } - // Update CodeMode configuration if present - if m.codeMode != nil && config.CodeModeBindingLevel != "" { + // Update CodeMode configuration — propagate whenever either field is set + if m.codeMode != nil && (config.CodeModeBindingLevel != "" || config.ToolExecutionTimeout > 
0) { m.codeMode.UpdateConfig(&CodeModeConfig{ BindingLevel: config.CodeModeBindingLevel, ToolExecutionTimeout: config.ToolExecutionTimeout, }) } + m.disableAutoToolInject.Store(config.DisableAutoToolInject) + m.logger.Info("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel) } +// executeToolWithUserToken creates a temporary MCP connection using the user's +// OAuth access token, calls the specified tool, and closes the connection. +// This is used for per_user_oauth clients which have no persistent connection — +// each tool call gets its own short-lived connection authenticated with the +// requesting user's token. +// +// Parameters: +// - ctx: context with timeout for the entire operation +// - config: MCP client configuration (connection URL, name) +// - toolName: original MCP tool name to call +// - arguments: tool call arguments +// - accessToken: user's OAuth access token +// - logger: logger instance +// +// Returns: +// - *mcp.CallToolResult: tool execution result +// - error: any error during connection or execution +func executeToolWithUserToken(ctx context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) { + if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" { + return nil, fmt.Errorf("connection URL is required for per-user OAuth tool execution") + } + + // Create HTTP transport with the user's Bearer token, preserving configured headers + headers := make(map[string]string) + if config.Headers != nil { + for key, value := range config.Headers { + headers[key] = value.GetValue() + } + } + headers["Authorization"] = "Bearer " + accessToken + httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), 
transport.WithHTTPHeaders(headers)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + // Create temporary MCP client + tempClient := client.NewClient(httpTransport) + if err := tempClient.Start(ctx); err != nil { + return nil, fmt.Errorf("failed to start temporary MCP connection: %w", err) + } + defer tempClient.Close() + + // Initialize MCP handshake + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s-user", config.Name), + Version: "1.0.0", + }, + }, + } + if _, err := tempClient.Initialize(ctx, initRequest); err != nil { + return nil, fmt.Errorf("failed to initialize temporary MCP connection: %w", err) + } + + // Call the tool + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + }, + } + return tempClient.CallTool(ctx, callRequest) +} + +// buildRedirectURIFromContext extracts the OAuth redirect URI from context. +// The URI is set by the HTTP middleware from the request's host. +func buildRedirectURIFromContext(ctx *schemas.BifrostContext) string { + if uri, ok := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string); ok && uri != "" { + return uri + } + // Fallback — should not happen if middleware is configured correctly + return "" +} + // GetCodeModeBindingLevel returns the current code mode binding level. // This method is safe to call concurrently from multiple goroutines. func (m *ToolsManager) GetCodeModeBindingLevel() schemas.CodeModeBindingLevel { diff --git a/core/mcp/utils.go b/core/mcp/utils.go index 1356bb38bb..479bc0bad8 100644 --- a/core/mcp/utils.go +++ b/core/mcp/utils.go @@ -65,7 +65,7 @@ func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas. 
var includeClients []string // Extract client filtering from request context - if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { + if existingIncludeClients, ok := ctx.Value(schemas.MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { includeClients = existingIncludeClients } @@ -381,12 +381,12 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b // If ToolsToExecute is specified (not nil), apply filtering if config.ToolsToExecute != nil { // Handle empty array [] - means no tools are allowed - if len(config.ToolsToExecute) == 0 { + if config.ToolsToExecute.IsEmpty() { return true // No tools allowed } // Handle wildcard "*" - if present, all tools are allowed - if slices.Contains(config.ToolsToExecute, "*") { + if config.ToolsToExecute.IsUnrestricted() { return false // All tools allowed } @@ -396,7 +396,7 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b unprefixedToolName := stripClientPrefix(toolName, config.Name) // Check if specific tool is in the allowed list - return !slices.Contains(config.ToolsToExecute, unprefixedToolName) // Tool not in allowed list + return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list } return true // Tool is skipped (nil is treated as [] - no tools) @@ -413,12 +413,12 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // If ToolsToAutoExecute is specified (not nil), apply filtering if config.ToolsToAutoExecute != nil { // Handle empty array [] - means no tools are auto-executed - if len(config.ToolsToAutoExecute) == 0 { + if config.ToolsToAutoExecute.IsEmpty() { return false // No tools auto-executed } // Handle wildcard "*" - if present, all tools are auto-executed - if slices.Contains(config.ToolsToAutoExecute, "*") { + if config.ToolsToAutoExecute.IsUnrestricted() { return true // All tools auto-executed 
} @@ -428,7 +428,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { unprefixedToolName := stripClientPrefix(toolName, config.Name) // Check if specific tool is in the auto-execute list - return slices.Contains(config.ToolsToAutoExecute, unprefixedToolName) + return config.ToolsToAutoExecute.Contains(unprefixedToolName) } return false // Tool is not auto-executed (nil is treated as [] - no tools) @@ -439,7 +439,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // Context filtering can only NARROW the tools available, NOT expand beyond client configuration. // This is checked AFTER client-level filtering (shouldSkipToolForConfig). func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool { - includeTools := ctx.Value(MCPContextKeyIncludeTools) + includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools) if includeTools != nil { // Try []string first (preferred type) @@ -777,6 +777,7 @@ func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool { if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" { return true } + // Check if message has tool calls regardless of finish_reason. // Some providers (e.g. Gemini) return finish_reason "stop" even when tool calls are present, // so we cannot rely solely on finish_reason to detect tool calls. diff --git a/core/mcp/utils/utils.go b/core/mcp/utils/utils.go new file mode 100644 index 0000000000..500792a09f --- /dev/null +++ b/core/mcp/utils/utils.go @@ -0,0 +1,49 @@ +package utils + +import ( + "net/http" + + "github.com/maximhq/bifrost/core/schemas" +) + +// GetHeadersForToolExecution builds the HTTP headers for a tool execution, merging the +// client's configured headers with allowed extra headers from the request context.
+func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header { + if ctx == nil || client == nil || client.ExecutionConfig == nil { + return make(http.Header) + } + headers := make(http.Header) + if client.ExecutionConfig.Headers != nil { + for key, value := range client.ExecutionConfig.Headers { + headers.Add(key, value.GetValue()) + } + } + // Extra headers from the context take priority over configured headers + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyMCPExtraHeaders).(map[string][]string); ok { + filteredHeaders := make(http.Header) + for key, values := range extraHeaders { + if client.ExecutionConfig.AllowedExtraHeaders.IsAllowed(key) { + for i, value := range values { + if i == 0 { + filteredHeaders.Set(key, value) + } else { + filteredHeaders.Add(key, value) + } + } + } + } + // Merge the filtered extra headers, overriding any configured values + if len(filteredHeaders) > 0 { + for k, values := range filteredHeaders { + for i, v := range values { + if i == 0 { + headers.Set(k, v) + } else { + headers.Add(k, v) + } + } + } + } + } + return headers +} diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index e012f3a13f..fabb785804 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -173,7 +173,7 @@ func extractAnthropicResponsesUsageFromPrefetch(data []byte) *schemas.ResponsesR // Returns the response body or an error if the request fails. // When large response streaming is activated (BifrostContextKeyLargeResponseMode set in ctx), // returns (nil, latency, nil) — callers must check the context flag.
-func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { +func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, requestType schemas.RequestType) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -208,7 +208,7 @@ func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, requestClient := provider.client responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64) - isCountTokens := meta != nil && meta.RequestType == schemas.CountTokensRequest + isCountTokens := requestType == schemas.CountTokensRequest // CountTokens responses are always tiny — skip streaming client so the response // is buffered normally (same approach as OpenAI and Gemini count_tokens handlers). if responseThreshold > 0 && !isCountTokens { @@ -233,20 +233,20 @@ func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())) - return nil, latency, providerResponseHeaders, parseAnthropicError(resp, meta) + return nil, latency, providerResponseHeaders, parseAnthropicError(resp) } // CountTokens uses buffered response (streaming skipped above) — decode directly. 
if isCountTokens { body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return body, latency, providerResponseHeaders, nil } // Delegate large response detection + normal buffered path to shared utility - body, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if respErr != nil { return nil, latency, providerResponseHeaders, respErr } @@ -290,10 +290,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseAnthropicError(resp, &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ListModelsRequest, - }) + return nil, parseAnthropicError(resp) } // Parse Anthropic's response @@ -304,7 +301,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, } // Create final response - response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, request.Unfiltered) + response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -355,18 +352,13 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToAnthropicTextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if err != nil 
{ return nil, err } // Use struct directly for JSON marshaling (no beta headers for text completion) - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/complete", schemas.TextCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.TextCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/complete", schemas.TextCompletionRequest), key.Value.GetValue(), schemas.TextCompletionRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -379,9 +371,6 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k return &schemas.BifrostTextCompletionResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -400,9 +389,6 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k bifrostResponse := response.ToBifrostTextCompletionResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -444,8 +430,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k } AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic) return anthropicReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != 
nil { return nil, bifrostErr } @@ -473,11 +458,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k } // Use struct directly for JSON marshaling - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionRequest), key.Value.GetValue(), schemas.ChatCompletionRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -490,9 +471,6 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -511,9 +489,6 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k bifrostResponse := response.ToBifrostChatResponse(ctx) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -550,8 +525,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont anthropicReq.Stream = schemas.Ptr(true) AddMissingBetaHeadersToContext(ctx, anthropicReq, 
schemas.Anthropic) return anthropicReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -604,11 +578,6 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont postHookRunner, nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }, ) } @@ -628,7 +597,6 @@ func HandleAnthropicChatCompletionStreaming( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse, logger schemas.Logger, - meta *providerUtils.RequestMetadata, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -675,9 +643,9 @@ func HandleAnthropicChatCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -686,7 +654,7 @@ func HandleAnthropicChatCompletionStreaming( // Check for HTTP errors if resp.StatusCode() != 
fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp, meta), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -703,14 +671,10 @@ func HandleAnthropicChatCompletionStreaming( go func() { defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { - model := "unknown" - if meta != nil { - model = meta.Model - } if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -720,7 +684,6 @@ func HandleAnthropicChatCompletionStreaming( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -774,7 +737,7 @@ func HandleAnthropicChatCompletionStreaming( if readErr != io.EOF { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, 
modelName, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -833,7 +796,6 @@ func HandleAnthropicChatCompletionStreaming( } } if event.Message != nil { - // Handle different event types modelName = event.Message.Model } @@ -882,11 +844,8 @@ func HandleAnthropicChatCompletionStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } lastChunkTime = time.Now() @@ -910,22 +869,14 @@ func HandleAnthropicChatCompletionStreaming( response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream(ctx, structuredOutputToolName, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) break } if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { response = postResponseConverter(response) @@ -952,7 +903,7 @@ func HandleAnthropicChatCompletionStreaming( usage.PromptTokens = usage.PromptTokens + usage.PromptTokensDetails.CachedReadTokens + usage.PromptTokensDetails.CachedWriteTokens usage.TotalTokens = usage.TotalTokens + usage.PromptTokensDetails.CachedReadTokens + usage.PromptTokensDetails.CachedWriteTokens } 
- response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName) + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, modelName, 0) if postResponseConverter != nil { response = postResponseConverter(response) if response == nil { @@ -981,16 +932,12 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err } - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false, nil) + jsonBody, err := getRequestBodyForResponses(ctx, request, false, nil) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value.GetValue(), schemas.ResponsesRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1008,9 +955,6 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc Model: request.Model, Usage: extractAnthropicResponsesUsageFromPrefetch([]byte(preview)), ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1030,9 +974,6 @@ 
func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc bifrostResponse := response.ToBifrostResponsesResponse(ctx) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1056,7 +997,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, } // Convert to Anthropic format using the centralized converter - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), true, nil) + jsonBody, err := getRequestBodyForResponses(ctx, request, true, nil) if err != nil { return nil, err } @@ -1089,11 +1030,6 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner, nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) } @@ -1113,7 +1049,6 @@ func HandleAnthropicResponsesStream( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, logger schemas.Logger, - meta *providerUtils.RequestMetadata, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -1162,9 +1097,9 @@ func HandleAnthropicResponsesStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, 
providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1173,7 +1108,7 @@ func HandleAnthropicResponsesStream( // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp, meta), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1190,14 +1125,10 @@ func HandleAnthropicResponsesStream( go func() { defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { - model := "" - if meta != nil { - model = meta.Model - } if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1207,7 +1138,6 @@ func HandleAnthropicResponsesStream( bifrostErr := 
providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -1260,7 +1190,7 @@ func HandleAnthropicResponsesStream( if readErr != io.EOF { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, modelName, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -1330,11 +1260,6 @@ func HandleAnthropicResponsesStream( ctx.SetValue(schemas.BifrostContextKeyHasEmittedMessageDelta, true) } if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: modelName, - } // If context was cancelled/timed out, let defer handle it if ctx.Err() != nil { return @@ -1351,12 +1276,9 @@ func HandleAnthropicResponsesStream( Type: schemas.ResponsesStreamResponseType(eventType), SequenceNumber: chunkIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), - RawResponse: eventData, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + RawResponse: eventData, }, } lastChunkTime = time.Now() @@ -1370,11 +1292,8 @@ func HandleAnthropicResponsesStream( for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: 
schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { response = postResponseConverter(response) @@ -1428,7 +1347,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if len(request.Requests) == 0 { - return nil, providerUtils.NewBifrostOperationError("requests array is required for Anthropic batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("requests array is required for Anthropic batch API", nil) } // Create request @@ -1466,7 +1385,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key jsonData, err := providerUtils.MarshalSorted(anthropicReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } usedLargePayloadBody := setAnthropicRequestBody(ctx, req, jsonData) @@ -1486,12 +1405,12 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.BatchCreateRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, 
sendBackRawRequest, sendBackRawResponse) } var anthropicResp AnthropicBatchResponse @@ -1500,7 +1419,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - return anthropicResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs using serial pagination across keys. @@ -1516,7 +1435,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ // Initialize serial pagination helper (Anthropic uses AfterID for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.AfterID, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -1527,10 +1446,6 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } @@ -1579,12 +1494,12 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.BatchListRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var anthropicResp AnthropicBatchListResponse @@ -1597,7 +1512,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(anthropicResp.Data)) var lastBatchID string for _, batch := range anthropicResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, false, false, nil, nil)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, false, false, nil, nil)) lastBatchID = batch.ID } @@ -1611,9 +1526,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -1631,7 +1544,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke // batch id is required if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1672,7 +1585,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1684,7 +1597,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx 
*schemas.BifrostContext, ke wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1702,8 +1615,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := anthropicResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := anthropicResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -1718,7 +1630,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys // batch id is required if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1755,7 +1667,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1767,7 +1679,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1790,9 +1702,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys Object: anthropicResp.Type, Status: ToBifrostBatchStatus(anthropicResp.ProcessingStatus), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1835,7 +1745,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1869,7 +1779,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchResultsRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1881,7 +1791,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1923,9 +1833,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, 
} @@ -2008,7 +1916,7 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart form data @@ -2022,14 +1930,14 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -2061,12 +1969,12 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.FileUploadRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var anthropicResp 
AnthropicFileResponse @@ -2077,7 +1985,7 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s return nil, bifrostErr } - return anthropicResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // FileList lists files from all provided keys and aggregates results. @@ -2095,7 +2003,7 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2106,10 +2014,6 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -2155,12 +2059,12 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.FileListRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var anthropicResp AnthropicFileListResponse @@ -2195,9 
+2099,7 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2216,7 +2118,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2257,7 +2159,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2269,7 +2171,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2287,7 +2189,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - return anthropicResp.ToBifrostFileRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostFileRetrieveResponse(latency, sendBackRawRequest, 
sendBackRawResponse, rawRequest, rawResponse), nil } return nil, lastErr @@ -2302,7 +2204,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2339,7 +2241,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2356,9 +2258,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2368,7 +2268,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2391,9 +2291,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: anthropicResp.Type == "file_deleted", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: 
providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2421,7 +2319,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } var lastErr *schemas.BifrostError @@ -2453,7 +2351,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileContentRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2465,7 +2363,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2485,9 +2383,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2500,16 +2396,12 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.CountTokensRequest); err != nil { return nil, err } - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false, 
[]string{"max_tokens", "temperature"}) + jsonBody, err := getRequestBodyForResponses(ctx, request, false, []string{"max_tokens", "temperature"}) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages/count_tokens", schemas.CountTokensRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }) + responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages/count_tokens", schemas.CountTokensRequest), key.Value.GetValue(), schemas.CountTokensRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -2533,9 +2425,6 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key response := anthropicResponse.ToBifrostCountTokensResponse(request.Model) response.Model = request.Model - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2670,7 +2559,7 @@ func (provider *AnthropicProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { @@ -2685,9 +2574,6 @@ func (provider *AnthropicProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - 
bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2751,9 +2637,9 @@ func (provider *AnthropicProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -2764,7 +2650,6 @@ func (provider *AnthropicProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } @@ -2776,11 +2661,7 @@ func (provider *AnthropicProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2792,9 +2673,9 @@ func (provider *AnthropicProvider) PassthroughStream( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -2844,7 +2725,7 @@ func (provider *AnthropicProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/anthropic/batch.go b/core/providers/anthropic/batch.go index 405738330c..ac4b0940c4 100644 --- a/core/providers/anthropic/batch.go +++ b/core/providers/anthropic/batch.go @@ -3,9 +3,7 @@ package anthropic import ( "time" - providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" ) // 
Anthropic Batch API Types @@ -129,7 +127,7 @@ func ToBifrostObjectType(anthropicType string) string { } // ToBifrostBatchCreateResponse converts Anthropic batch response to Bifrost batch create response. -func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { +func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { expiresAt := parseAnthropicTimestamp(r.ExpiresAt) resp := &schemas.BifrostBatchCreateResponse{ ID: r.ID, @@ -140,9 +138,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schem CreatedAt: parseAnthropicTimestamp(r.CreatedAt), ExpiresAt: &expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -170,7 +166,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schem } // ToBifrostBatchRetrieveResponse converts Anthropic batch response to Bifrost batch retrieve response. 
-func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { +func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { resp := &schemas.BifrostBatchRetrieveResponse{ ID: r.ID, Object: ToBifrostObjectType(r.Type), @@ -179,9 +175,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName sch ResultsURL: r.ResultsURL, CreatedAt: parseAnthropicTimestamp(r.CreatedAt), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -228,26 +222,6 @@ func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName sch return resp } -// ParseAnthropicError parses Anthropic error responses for batch operations. -func ParseAnthropicError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { - var errorResp AnthropicError - bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) - if errorResp.Error != nil { - if errorResp.Error.Type != "" { - bifrostErr.Error.Type = &errorResp.Error.Type - } - if errorResp.Error.Message != "" { - bifrostErr.Error.Message = errorResp.Error.Message - } - } - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - } - return bifrostErr -} - // ToAnthropicBatchCreateResponse converts a Bifrost batch create response to Anthropic format. 
func ToAnthropicBatchCreateResponse(resp *schemas.BifrostBatchCreateResponse) *AnthropicBatchResponse { result := &AnthropicBatchResponse{ diff --git a/core/providers/anthropic/chat.go b/core/providers/anthropic/chat.go index a72ffbe024..5cf221b84e 100644 --- a/core/providers/anthropic/chat.go +++ b/core/providers/anthropic/chat.go @@ -722,12 +722,8 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse(ctx *schemas.Bif // Initialize Bifrost response bifrostResponse := &schemas.BifrostChatResponse{ - ID: response.ID, - Model: response.Model, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Anthropic, - }, + ID: response.ID, + Model: response.Model, Created: int(time.Now().Unix()), } diff --git a/core/providers/anthropic/chat_test.go b/core/providers/anthropic/chat_test.go index 04df3bd02f..b73002009b 100644 --- a/core/providers/anthropic/chat_test.go +++ b/core/providers/anthropic/chat_test.go @@ -337,6 +337,115 @@ func TestToAnthropicChatRequest_ToolInputKeyOrderPreservation(t *testing.T) { } } +func TestToBifrostChatResponse_MultipleTextBlocksWithThinking(t *testing.T) { + thinkingText := "Let me reason step by step about this problem." + textBlock1 := "The answer is 42." + textBlock2 := "Here is why that is the case." 
+ signature := "sig_abc123" + + response := &AnthropicMessageResponse{ + ID: "msg_test123", + Type: "message", + Role: "assistant", + Model: "claude-opus-4-6-20250514", + Content: []AnthropicContentBlock{ + { + Type: AnthropicContentBlockTypeThinking, + Thinking: &thinkingText, + Signature: &signature, + }, + { + Type: AnthropicContentBlockTypeText, + Text: &textBlock1, + }, + { + Type: AnthropicContentBlockTypeText, + Text: &textBlock2, + }, + }, + StopReason: "end_turn", + Usage: &AnthropicUsage{ + InputTokens: 100, + OutputTokens: 50, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + result := response.ToBifrostChatResponse(ctx) + + if result == nil { + t.Fatal("expected non-nil result") + } + + // Content should be a combined string, not blocks + choice := result.Choices[0] + msg := choice.ChatNonStreamResponseChoice.Message + if msg.Content.ContentBlocks != nil { + t.Error("expected ContentBlocks to be nil (combined into string)") + } + if msg.Content.ContentStr == nil { + t.Fatal("expected ContentStr to be non-nil") + } + + // Combined string: thinking first, then text blocks + expected := thinkingText + "\n\n" + textBlock1 + "\n\n" + textBlock2 + if *msg.Content.ContentStr != expected { + t.Errorf("expected combined content:\n%s\ngot:\n%s", expected, *msg.Content.ContentStr) + } + + // Reasoning field should still have thinking text + if msg.ChatAssistantMessage == nil { + t.Fatal("expected ChatAssistantMessage to be non-nil") + } + if msg.ChatAssistantMessage.Reasoning == nil { + t.Fatal("expected Reasoning to be non-nil") + } + + // ReasoningDetails should have: signature-only thinking entry + content blocks boundary + rd := msg.ChatAssistantMessage.ReasoningDetails + if len(rd) < 2 { + t.Fatalf("expected at least 2 reasoning details entries, got %d", len(rd)) + } + + // First entry: thinking with signature, no text (text was cleared) + if rd[0].Type != schemas.BifrostReasoningDetailsTypeText { + t.Errorf("expected 
first reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeText, rd[0].Type) + } + if rd[0].Signature == nil || *rd[0].Signature != signature { + t.Error("expected signature to be preserved") + } + if rd[0].Text != nil { + t.Error("expected thinking text to be nil (cleared to avoid duplication)") + } + + // Last entry: content blocks boundary + lastRD := rd[len(rd)-1] + if lastRD.Type != schemas.BifrostReasoningDetailsTypeContentBlocks { + t.Errorf("expected last reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeContentBlocks, lastRD.Type) + } + if lastRD.Text == nil { + t.Fatal("expected content blocks metadata to be non-nil") + } + + // var meta []contentBlockMeta + // if err := json.Unmarshal([]byte(*lastRD.Text), &meta); err != nil { + // t.Fatalf("failed to unmarshal block metadata: %v", err) + // } + // if len(meta) != 3 { + // t.Fatalf("expected 3 block metadata entries, got %d", len(meta)) + // } + // if meta[0].T != "thinking" || meta[0].L != len(thinkingText) { + // t.Errorf("block 0: expected thinking/%d, got %s/%d", len(thinkingText), meta[0].T, meta[0].L) + // } + // if meta[1].T != "text" || meta[1].L != len(textBlock1) { + // t.Errorf("block 1: expected text/%d, got %s/%d", len(textBlock1), meta[1].T, meta[1].L) + // } + // if meta[2].T != "text" || meta[2].L != len(textBlock2) { + // t.Errorf("block 2: expected text/%d, got %s/%d", len(textBlock2), meta[2].T, meta[2].L) + // } +} + func TestToBifrostChatResponse_SingleTextBlockNoThinking(t *testing.T) { // Verify existing behavior: single text block without thinking collapses to string text := "Simple response" diff --git a/core/providers/anthropic/errors.go b/core/providers/anthropic/errors.go index dd1dfaf698..81bbd49d0c 100644 --- a/core/providers/anthropic/errors.go +++ b/core/providers/anthropic/errors.go @@ -54,7 +54,7 @@ func ToAnthropicResponsesStreamError(bifrostErr *schemas.BifrostError) string { return fmt.Sprintf("event: error\ndata: %s\n\n", 
jsonData) } -func parseAnthropicError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseAnthropicError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp AnthropicError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if errorResp.Error != nil { @@ -64,10 +64,5 @@ func parseAnthropicError(resp *fasthttp.Response, meta *providerUtils.RequestMet bifrostErr.Error.Type = &errorResp.Error.Type bifrostErr.Error.Message = errorResp.Error.Message } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/anthropic/models.go b/core/providers/anthropic/models.go index 3da2f6458b..3815a0244b 100644 --- a/core/providers/anthropic/models.go +++ b/core/providers/anthropic/models.go @@ -1,13 +1,14 @@ package anthropic import ( + "strings" "time" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -19,57 +20,51 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide HasMore: schemas.Ptr(response.HasMore), } - // Map Anthropic's cursor-based pagination to Bifrost's token-based pagination - // If there are more results, set next_page_token to last_id so it can be used in the next request + // Map Anthropic's cursor-based pagination to Bifrost's 
token-based pagination. + // If there are more results, set next_page_token to last_id for the next request. if response.HasMore && response.LastID != nil { bifrostResponse.NextPageToken = *response.LastID } - includedModels := make(map[string]bool) - for _, model := range response.Data { - modelID := model.ID - if !unfiltered && len(allowedModels) > 0 { - allowed := false - for _, allowedModel := range allowedModels { - if schemas.SameBaseModel(model.ID, allowedModel) { - modelID = allowedModel - allowed = true - break - } - } - if !allowed { - continue - } - } - if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, modelID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + modelID, - Name: schemas.Ptr(model.DisplayName), - Created: schemas.Ptr(model.CreatedAt.Unix()), - MaxInputTokens: model.MaxInputTokens, - MaxOutputTokens: model.MaxTokens, - ProviderExtra: model.Capabilities, - }) - includedModels[modelID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + + included := make(map[string]bool) - // Backfill allowed models that were not in the response (skip blacklisted; blacklist wins over allow list) - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + resolvedKey := strings.ToLower(result.ResolvedID) + if included[resolvedKey] { continue } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: 
schemas.Ptr(allowedModel), - }) + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Created: schemas.Ptr(model.CreatedAt.Unix()), + MaxInputTokens: model.MaxInputTokens, + MaxOutputTokens: model.MaxTokens, + ProviderExtra: model.Capabilities, } + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[resolvedKey] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index cb3528aaa0..ad05866e69 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -1445,13 +1445,13 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp if bifrostResp.Response.ID != nil { streamMessage.ID = *bifrostResp.Response.ID } - // Preserve model from Response if available, otherwise use ExtraFields - if bifrostResp.ExtraFields.ModelRequested != "" { - if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { - streamMessage.Model = bifrostResp.Response.Model - } else { - streamMessage.Model = bifrostResp.ExtraFields.ModelRequested - } + // Prefer Response.Model, then ResolvedModelUsed, then OriginalModelRequested + if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { + streamMessage.Model = bifrostResp.Response.Model + } else if bifrostResp.ExtraFields.ResolvedModelUsed != "" { + streamMessage.Model = bifrostResp.ExtraFields.ResolvedModelUsed + } else if bifrostResp.ExtraFields.OriginalModelRequested != "" { + streamMessage.Model = bifrostResp.ExtraFields.OriginalModelRequested } streamResp.Message = streamMessage } diff --git a/core/providers/anthropic/text.go b/core/providers/anthropic/text.go index 3228ad49f6..39a700499b 100644 --- 
a/core/providers/anthropic/text.go +++ b/core/providers/anthropic/text.go @@ -103,10 +103,6 @@ func (response *AnthropicTextResponse) ToBifrostTextCompletionResponse() *schema TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, }, Model: response.Model, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Anthropic, - }, } } diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index 3c11d076d9..d2a636de8d 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -1564,7 +1564,7 @@ type AnthropicFileDeleteResponse struct { } // ToBifrostFileUploadResponse converts an Anthropic file response to Bifrost file upload response. -func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { +func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { resp := &schemas.BifrostFileUploadResponse{ ID: r.ID, Object: r.Type, @@ -1575,9 +1575,7 @@ func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas Status: schemas.FileStatusProcessed, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1593,7 +1591,7 @@ func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas } // ToBifrostFileRetrieveResponse converts an Anthropic file response to Bifrost file retrieve response. 
-func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse { +func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse { resp := &schemas.BifrostFileRetrieveResponse{ ID: r.ID, Object: r.Type, @@ -1604,9 +1602,7 @@ func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(providerName schem Status: schemas.FileStatusProcessed, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1636,4 +1632,4 @@ func parseAnthropicFileTimestamp(timestamp string) int64 { // AnthropicCountTokensResponse models the payload returned by Anthropic's count tokens endpoint. 
type AnthropicCountTokensResponse struct { InputTokens int `json:"input_tokens"` -} +} \ No newline at end of file diff --git a/core/providers/anthropic/utils.go b/core/providers/anthropic/utils.go index dbde522d6f..6b1d43f68a 100644 --- a/core/providers/anthropic/utils.go +++ b/core/providers/anthropic/utils.go @@ -661,7 +661,7 @@ func setEffortOnOutputConfig(req *AnthropicMessageRequest, effort string) { req.OutputConfig.Effort = &effort } -func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, providerName schemas.ModelProvider, isStreaming bool, excludeFields []string) ([]byte, *schemas.BifrostError) { +func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, isStreaming bool, excludeFields []string) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader in completeRequest/ // setAnthropicRequestBody — skip all body building here (matches CheckContextAndGetRequestBody). 
if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -681,7 +681,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi _, model := schemas.ParseModelString(modelStr, schemas.Anthropic) jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", model) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } @@ -693,20 +693,20 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi } jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", defaultMaxTokens) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Strip auto-injectable server-side tools to prevent conflicts with API auto-injection jsonBody, err = StripAutoInjectableTools(jsonBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Sanitize raw-body fields the target provider does not support. // Behavioural parity with StripUnsupportedAnthropicFields on the typed path. 
@@ -732,17 +732,17 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { // Convert request to Anthropic format reqBody, convErr := ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Anthropic) if isStreaming { @@ -751,7 +751,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // Marshal struct to JSON bytes jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err)) } // Merge ExtraParams into the JSON if passthrough is enabled if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true { @@ -760,14 +760,14 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // Use MergeExtraParamsIntoJSON which preserves key order jsonBody, err = providerUtils.MergeExtraParamsIntoJSON(jsonBody, 
extraParams) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Remove excluded fields after merging (using sjson to preserve order) for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else if len(excludeFields) > 0 { @@ -775,7 +775,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 9d7c3063d7..323d13584e 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -100,7 +100,7 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && key.AzureKeyConfig.TenantID.GetValue() != "" { cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), key.AzureKeyConfig.ClientSecret.GetValue()) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err) 
} scopes := getAzureScopes(key.AzureKeyConfig.Scopes) @@ -109,11 +109,11 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, Scopes: scopes, }) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err) } if token.Token == "" { - return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty"), schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty")) } authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) @@ -138,16 +138,16 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, cred, err := provider.getOrCreateDefaultAzureCredential() if err != nil { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err) } token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) if err != nil { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err) } if token.Token == "" { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty"), schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty")) } authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) @@ -206,10 +206,8 @@ func 
(provider *AzureProvider) completeRequest( jsonData []byte, path string, key schemas.Key, - deployment string, model string, - requestType schemas.RequestType, -) ([]byte, string, time.Duration, map[string]string, *schemas.BifrostError) { +) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -222,7 +220,7 @@ func (provider *AzureProvider) completeRequest( }() var url string - isAnthropicModel := schemas.IsAnthropicModel(deployment) + isAnthropicModel := schemas.IsAnthropicModel(model) // Set any extra headers from network config. // For Anthropic models, exclude anthropic-beta — it is merged and filtered explicitly below. @@ -237,7 +235,7 @@ func (provider *AzureProvider) completeRequest( // Get authentication headers authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, isAnthropicModel) if bifrostErr != nil { - return nil, deployment, 0, nil, bifrostErr + return nil, 0, nil, bifrostErr } // Apply headers to request @@ -247,7 +245,7 @@ func (provider *AzureProvider) completeRequest( endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, deployment, 0, nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, 0, nil, providerUtils.NewConfigurationError("endpoint not set") } if isAnthropicModel { @@ -282,7 +280,7 @@ func (provider *AzureProvider) completeRequest( latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp) defer wait() if bifrostErr != nil { - return nil, deployment, latency, nil, bifrostErr + return nil, latency, nil, bifrostErr } // Extract provider response headers before body is copied — do this before status check @@ -292,33 +290,25 @@ func (provider *AzureProvider) completeRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, 
resp)
-		return nil, deployment, latency, providerResponseHeaders, openai.ParseOpenAIError(resp, requestType, provider.GetProviderKey(), model)
+		rawErrBody := append([]byte(nil), resp.Body()...)
+		return rawErrBody, latency, providerResponseHeaders, openai.ParseOpenAIError(resp)
 	}
 
-	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger)
+	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
-		return nil, deployment, latency, providerResponseHeaders, decodeErr
+		return nil, latency, providerResponseHeaders, decodeErr
 	}
 	if isLargeResp {
 		respOwned = false
-		return nil, deployment, latency, providerResponseHeaders, nil
+		return nil, latency, providerResponseHeaders, nil
 	}
 
-	return body, deployment, latency, providerResponseHeaders, nil
+	return body, latency, providerResponseHeaders, nil
 }
 
 // listModelsByKey performs a list models request for a single key.
 // Returns the response and latency, or an error if the request fails.
 func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
-	// Validate Azure key configuration
-	if key.AzureKeyConfig == nil {
-		return nil, providerUtils.NewConfigurationError("azure key config not set", schemas.Azure)
-	}
-
-	if key.AzureKeyConfig.Endpoint.GetValue() == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", schemas.Azure)
-	}
-
 	// Get API version
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil {
@@ -359,12 +349,12 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key
 
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "")
+		return nil, openai.ParseOpenAIError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey())
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	// Read the response body and copy it before releasing the response
@@ -379,9 +369,9 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key
 	}
 
 	// Convert to Bifrost response
-	response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, key.BlacklistedModels, request.Unfiltered)
+	response := azureResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
 	if response == nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure)
+		return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil)
 	}
 
 	response.ExtraFields.Latency = latency.Milliseconds()
@@ -415,35 +405,23 @@ func (provider *AzureProvider) ListModels(ctx *schemas.BifrostContext, keys []sc
 // It formats the request, sends it to Azure, and processes the response.
 // Returns a BifrostResponse containing the completion results or an error if the request fails.
 func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	// Use centralized OpenAI text converter (Azure is OpenAI-compatible)
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return openai.ToOpenAITextCompletionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest(
+	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(
 		ctx,
 		jsonData,
-		fmt.Sprintf("openai/deployments/%s/completions", deployment),
+		fmt.Sprintf("openai/deployments/%s/completions", request.Model),
 		key,
-		deployment,
 		request.Model,
-		schemas.TextCompletionRequest,
 	)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
@@ -457,10 +435,6 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s
 		return &schemas.BifrostTextCompletionResponse{
 			Model: request.Model,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				ModelDeployment:         deployment,
-				RequestType:             schemas.TextCompletionRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -474,10 +448,6 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-	response.ExtraFields.RequestType = schemas.TextCompletionRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
@@ -498,21 +468,12 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s
 // It formats the request, sends it to Azure, and processes the response.
 // Returns a channel of BifrostStreamChunk objects or an error if the request fails.
 func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
 	}
 
-	url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue())
 
 	// Get Azure authentication headers
 	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
@@ -520,11 +481,6 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext,
 		return nil, err
 	}
 
-	customPostResponseConverter := func(response *schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse {
-		response.ExtraFields.ModelDeployment = deployment
-		return response
-	}
-
 	return openai.HandleOpenAITextCompletionStreaming(
 		ctx,
 		provider.client,
@@ -538,7 +494,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext,
 		nil,
 		postHookRunner,
 		nil,
-		customPostResponseConverter,
+		nil,
 		provider.logger,
 	)
 }
@@ -547,26 +503,16 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext,
 // It formats the request, sends it to Azure, and processes the response.
 // Returns a BifrostResponse containing the completion results or an error if the request fails.
 func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
-			if schemas.IsAnthropicModel(deployment) {
+			if schemas.IsAnthropicModel(request.Model) {
 				reqBody, err := anthropic.ToAnthropicChatRequest(ctx, request)
 				if err != nil {
 					return nil, err
 				}
 				if reqBody != nil {
-					reqBody.Model = deployment
 					// Add provider-aware beta headers for Azure
 					anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure)
 				}
@@ -574,27 +520,24 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s
 			} else {
 				return openai.ToOpenAIChatRequest(ctx, request), nil
 			}
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	var path string
-	if schemas.IsAnthropicModel(deployment) {
+	if schemas.IsAnthropicModel(request.Model) {
 		path = "anthropic/v1/messages"
 	} else {
-		path = fmt.Sprintf("openai/deployments/%s/chat/completions", deployment)
+		path = fmt.Sprintf("openai/deployments/%s/chat/completions", request.Model)
 	}
 
-	responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest(
+	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(
 		ctx,
 		jsonData,
 		path,
 		key,
-		deployment,
 		request.Model,
-		schemas.ChatCompletionRequest,
 	)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
@@ -608,10 +551,6 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s
 		return &schemas.BifrostChatResponse{
 			Model: request.Model,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				ModelDeployment:         deployment,
-				RequestType:             schemas.ChatCompletionRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -622,7 +561,7 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s
 	var rawRequest interface{}
 	var rawResponse interface{}
 
-	if schemas.IsAnthropicModel(deployment) {
+	if schemas.IsAnthropicModel(request.Model) {
 		anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
 		defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
 		rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
@@ -637,12 +576,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s
 		}
 	}
 
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
-	response.ExtraFields.RequestType = schemas.ChatCompletionRequest
 
 	// Set raw request if enabled
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
@@ -662,22 +597,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s
 // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication.
 // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
 func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
-	postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse {
-		response.ExtraFields.ModelDeployment = deployment
-		return response
-	}
-
 	var url string
-	if schemas.IsAnthropicModel(deployment) {
+	if schemas.IsAnthropicModel(request.Model) {
 		authHeader, err := provider.getAzureAuthHeaders(ctx, key, true)
 		if err != nil {
 			return nil, err
@@ -694,14 +615,12 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext,
 				return nil, err
 			}
 			if reqBody != nil {
-				reqBody.Model = deployment
 				reqBody.Stream = schemas.Ptr(true)
 				// Add provider-aware beta headers for Azure
 				anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure)
 			}
 			return reqBody, nil
-		},
-			provider.GetProviderKey())
+		})
 		if err != nil {
 			return nil, err
 		}
@@ -719,13 +638,8 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext,
 			providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
 			provider.GetProviderKey(),
 			postHookRunner,
-			postResponseConverter,
+			nil,
 			provider.logger,
-			&providerUtils.RequestMetadata{
-				Provider:    provider.GetProviderKey(),
-				Model:       request.Model,
-				RequestType: schemas.ChatCompletionStreamRequest,
-			},
 		)
 	} else {
 		authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
@@ -736,7 +650,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext,
 		if apiVersion == nil {
 			apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
 		}
-		url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue())
+		url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue())
 
 		// Use shared streaming logic from OpenAI
 		return openai.HandleOpenAIChatCompletionStreaming(
@@ -754,7 +668,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext,
 			nil,
 			nil,
 			nil,
-			postResponseConverter,
+			nil,
 			provider.logger,
 		)
 	}
@@ -764,51 +678,36 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext,
 // It formats the request, sends it to Azure, and processes the response.
 // Returns a BifrostResponse containing the completion results or an error if the request fails.
 func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	var jsonData []byte
 	var bifrostErr *schemas.BifrostError
-	if schemas.IsAnthropicModel(deployment) {
-		jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), false)
+	if schemas.IsAnthropicModel(request.Model) {
+		jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false)
 	} else {
 		jsonData, bifrostErr = providerUtils.CheckContextAndGetRequestBody(
 			ctx,
 			request,
 			func() (providerUtils.RequestBodyWithExtraParams, error) {
 				reqBody := openai.ToOpenAIResponsesRequest(request)
-				if reqBody != nil {
-					reqBody.Model = deployment
-				}
 				return reqBody, nil
-			},
-			provider.GetProviderKey())
+			})
 	}
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	var path string
-	if schemas.IsAnthropicModel(deployment) {
+	if schemas.IsAnthropicModel(request.Model) {
 		path = "anthropic/v1/messages"
 	} else {
 		path = "openai/v1/responses"
 	}
 
-	responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest(
+	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(
 		ctx,
 		jsonData,
 		path,
 		key,
-		deployment,
 		request.Model,
-		schemas.ResponsesRequest,
 	)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
@@ -822,10 +721,6 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema
 		return &schemas.BifrostResponsesResponse{
 			Model: request.Model,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				ModelDeployment:         deployment,
-				RequestType:             schemas.ResponsesRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -836,7 +731,7 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema
 	var rawRequest interface{}
 	var rawResponse interface{}
 
-	if schemas.IsAnthropicModel(deployment) {
+	if schemas.IsAnthropicModel(request.Model) {
 		anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
 		defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
 		rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
@@ -851,12 +746,8 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema
 		}
 	}
 
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
-	response.ExtraFields.RequestType = schemas.ResponsesRequest
 
 	// Set raw request if enabled
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
@@ -873,22 +764,8 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema
 // ResponsesStream performs a streaming responses request to Azure's API.
 func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
-	postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse {
-		response.ExtraFields.ModelDeployment = deployment
-		return response
-	}
-
 	var url string
-	if schemas.IsAnthropicModel(deployment) {
+	if schemas.IsAnthropicModel(request.Model) {
 		authHeader, err := provider.getAzureAuthHeaders(ctx, key, true)
 		if err != nil {
 			return nil, err
@@ -896,7 +773,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post
 		authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault
 		url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint.GetValue())
 
-		jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), true)
+		jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true)
 		if bifrostErr != nil {
 			return nil, bifrostErr
 		}
@@ -914,13 +791,8 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post
 			providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
 			provider.GetProviderKey(),
 			postHookRunner,
-			postResponseConverter,
+			nil,
 			provider.logger,
-			&providerUtils.RequestMetadata{
-				Provider:    provider.GetProviderKey(),
-				Model:       request.Model,
-				RequestType: schemas.ResponsesStreamRequest,
-			},
 		)
 	} else {
 		authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
@@ -929,11 +801,6 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post
 		}
 
 		url = fmt.Sprintf("%s/openai/v1/responses?api-version=preview", key.AzureKeyConfig.Endpoint.GetValue())
 
-		postRequestConverter := func(req *openai.OpenAIResponsesRequest) *openai.OpenAIResponsesRequest {
-			req.Model = deployment
-			return req
-		}
-
 		// Use shared streaming logic from OpenAI
 		return openai.HandleOpenAIResponsesStreaming(
 			ctx,
@@ -948,8 +815,8 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post
 			postHookRunner,
 			nil,
 			nil,
-			postRequestConverter,
-			postResponseConverter,
+			nil,
+			nil,
 			provider.logger,
 		)
 	}
@@ -959,35 +826,23 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post
 // The input can be either a single string or a slice of strings for batch embedding.
 // Returns a BifrostResponse containing the embedding(s) and any error that occurred.
 func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	// Use centralized converter
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return openai.ToOpenAIEmbeddingRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest(
+	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(
 		ctx,
 		jsonData,
-		fmt.Sprintf("openai/deployments/%s/embeddings", deployment),
+		fmt.Sprintf("openai/deployments/%s/embeddings", request.Model),
 		key,
-		deployment,
 		request.Model,
-		schemas.EmbeddingRequest,
 	)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
@@ -1001,10 +856,6 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema
 		return &schemas.BifrostEmbeddingResponse{
 			Model: request.Model,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				ModelDeployment:         deployment,
-				RequestType:             schemas.EmbeddingRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -1019,12 +870,8 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
-	response.ExtraFields.Provider = provider.GetProviderKey()
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-	response.ExtraFields.RequestType = schemas.EmbeddingRequest
 
 	// Set raw request if enabled
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
@@ -1041,15 +888,6 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema
 // Speech is not supported by the Azure provider.
 func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
@@ -1057,10 +895,10 @@ func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.K
 
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}
 
-	url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, request.Model, apiVersion.GetValue())
 
 	response, err := openai.HandleOpenAISpeechRequest(
 		ctx,
@@ -1080,9 +918,6 @@ func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.K
 		return nil, err
 	}
 
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-
 	return response, err
 }
 
@@ -1099,15 +934,6 @@ func (provider *AzureProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key,
 // SpeechStream handles streaming for speech synthesis with Azure.
 // Azure sends raw binary audio bytes in SSE format, unlike OpenAI which sends JSON.
 func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	// Get Azure authentication headers
 	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
 	if err != nil {
@@ -1118,7 +944,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 	if apiVersion == nil {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
 	}
-	url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue())
 
 	// Create HTTP request for streaming
 	req := fasthttp.AcquireRequest()
@@ -1158,11 +984,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 			reqBody := openai.ToOpenAISpeechRequest(request)
 			if reqBody != nil {
 				reqBody.StreamFormat = schemas.Ptr("sse")
-				reqBody.Model = deployment // Replace model with deployment
 			}
 			return reqBody, nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1186,9 +1010,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 		}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	if errors.Is(requestErr, fasthttp.ErrTimeout) || errors.Is(requestErr, context.DeadlineExceeded) {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr, provider.GetProviderKey()), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
-	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr, provider.GetProviderKey()), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	// Extract provider response headers before status check so error responses also forward them
@@ -1197,7 +1021,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp, schemas.SpeechStreamRequest, provider.GetProviderKey(), request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	// Create response channel
@@ -1210,9 +1034,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1313,11 +1137,6 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 			var bifrostErr schemas.BifrostError
 			if errParseErr := sonic.Unmarshal(audioData, &bifrostErr); errParseErr == nil {
 				if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
-					bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-						Provider:       provider.GetProviderKey(),
-						ModelRequested: request.Model,
-						RequestType:    schemas.SpeechStreamRequest,
-					}
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger)
 					return
@@ -1339,12 +1158,8 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 
 			// Set extra fields for the response
 			response.ExtraFields = schemas.BifrostResponseExtraFields{
-				RequestType:     schemas.SpeechStreamRequest,
-				Provider:        provider.GetProviderKey(),
-				ModelRequested:  request.Model,
-				ModelDeployment: deployment,
-				ChunkIndex:      chunkIndex,
-				Latency:         time.Since(lastChunkTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency:    time.Since(lastChunkTime).Milliseconds(),
 			}
 			lastChunkTime = time.Now()
 
@@ -1373,7 +1188,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 			// a fake "done" response with truncated audio.
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 			provider.logger.Warn("Error reading stream: %v", readErr)
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, provider.GetProviderKey(), request.Model, provider.logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger)
 			return
 		}
 		break
@@ -1386,12 +1201,8 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 		finalResponse := schemas.BifrostSpeechStreamResponse{
 			Type: schemas.SpeechStreamResponseTypeDone,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:     schemas.SpeechStreamRequest,
-				Provider:        provider.GetProviderKey(),
-				ModelRequested:  request.Model,
-				ModelDeployment: deployment,
-				ChunkIndex:      chunkIndex + 1,
-				Latency:         time.Since(startTime).Milliseconds(),
+				ChunkIndex: chunkIndex + 1,
+				Latency:    time.Since(startTime).Milliseconds(),
 			},
 		}
 
@@ -1414,21 +1225,12 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 // Transcription is not supported by the Azure provider.
 func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
 	}
 
-	url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue())
 
 	response, err := openai.HandleOpenAITranscriptionRequest(
 		ctx,
@@ -1447,9 +1249,6 @@ func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key sc
 		return nil, err
 	}
 
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-
 	return response, err
 }
 
@@ -1463,16 +1262,6 @@ func (provider *AzureProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 // Returns a BifrostResponse containing the bifrost response or an error if the request fails.
 func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-	// Validate api key configs
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil || apiVersion.GetValue() == "" {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
@@ -1480,13 +1269,13 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}
 
 	response, err := openai.HandleOpenAIImageGenerationRequest(
 		ctx,
 		provider.client,
-		fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, deployment, apiVersion.GetValue()),
+		fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, request.Model, apiVersion.GetValue()),
 		request,
 		key,
 		provider.networkConfig.ExtraHeaders,
@@ -1499,9 +1288,6 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 		return nil, err
 	}
 
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-
 	return response, err
 }
 
@@ -1514,18 +1300,6 @@ func (provider *AzureProvider) ImageGenerationStream(
 	key schemas.Key,
 	request *schemas.BifrostImageGenerationRequest,
 ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-
-	// Validate api key configs
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	//
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil || apiVersion.GetValue() == "" {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
@@ -1533,17 +1307,10 @@ func (provider *AzureProvider) ImageGenerationStream(
 
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}
 
-	postResponseConverter := func(resp *schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse {
-		if resp != nil {
-			resp.ExtraFields.ModelDeployment = deployment
-		}
-		return resp
-	}
-
-	url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, request.Model, apiVersion.GetValue())
 
 	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
 	if err != nil {
@@ -1564,7 +1331,7 @@ func (provider *AzureProvider) ImageGenerationStream(
 		postHookRunner,
 		nil,
 		nil,
-		postResponseConverter,
+		nil,
 		provider.logger,
 	)
 
@@ -1572,16 +1339,6 @@ func (provider *AzureProvider) ImageGenerationStream(
 // ImageEdit performs an image edit request to Azure's API.
 func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-	// Validate api key configs
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil || apiVersion.GetValue() == "" {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault)
@@ -1589,10 +1346,10 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema
 
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}
 
-	url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, request.Model, apiVersion.GetValue())
 
 	response, err := openai.HandleOpenAIImageEditRequest(
 		ctx,
 		provider.client,
@@ -1609,24 +1366,11 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema
 		return nil, err
 	}
 
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-
 	return response, err
 }
 
 // ImageEditStream performs a streaming image edit request to Azure's API.
 func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-	// Validate api key configs
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil || apiVersion.GetValue() == "" {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault)
@@ -1634,17 +1378,10 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post
 
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}
 
-	postResponseConverter := func(resp *schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse {
-		if resp != nil {
-			resp.ExtraFields.ModelDeployment = deployment
-		}
-		return resp
-	}
-
-	url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, request.Model, apiVersion.GetValue())
 
 	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
 	if err != nil {
@@ -1665,7 +1402,7 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post
 		postHookRunner,
 		nil,
 		nil,
-		postResponseConverter,
+		nil,
 		provider.logger,
 	)
 
@@ -1679,30 +1416,19 @@ func (provider *AzureProvider) ImageVariation(ctx *schemas.BifrostContext, key s
 // VideoGeneration creates a video using Azure's OpenAI-compatible Sora API.
// This delegates to the OpenAI handler with Azure-specific URL and authentication. func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, bifrostErr := provider.getModelDeployment(key, request.Model) - if bifrostErr != nil { - return nil, bifrostErr - } - endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL for OpenAI-compatible video generation endpoint url := fmt.Sprintf("%s/openai/v1/videos", endpoint) - requestCopy := *request - requestCopy.Model = deployment response, bifrostErr := openai.HandleOpenAIVideoGenerationRequest( ctx, provider.client, url, - &requestCopy, + request, key, provider.networkConfig.ExtraHeaders, provider.GetProviderKey(), @@ -1714,27 +1440,20 @@ func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key return nil, bifrostErr } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, nil } // VideoRetrieve retrieves the status of a video from Azure's OpenAI-compatible API. 
func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, false) @@ -1760,20 +1479,16 @@ func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key sc // VideoDownload downloads video content from Azure's OpenAI-compatible API. 
func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Create request @@ -1809,13 +1524,12 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.VideoDownloadRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Get content type from response @@ -1831,9 +1545,7 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc Content: append([]byte(nil), body...), ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoDownloadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1842,20 +1554,16 @@ func (provider *AzureProvider) VideoDownload(ctx 
*schemas.BifrostContext, key sc // VideoDelete deletes a video from Azure's OpenAI-compatible API. func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL @@ -1882,13 +1590,9 @@ func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key sche // VideoList lists videos from Azure's OpenAI-compatible API. func (provider *AzureProvider) VideoList(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL @@ -1918,64 +1622,14 @@ func (provider *AzureProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.K return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey()) } -// validateKeyConfig validates the key configuration. 
-// It checks if the key config is set, the endpoint is set, and the deployments are set. -// Returns an error if any of the checks fail. -func (provider *AzureProvider) validateKeyConfig(key schemas.Key) *schemas.BifrostError { - if key.AzureKeyConfig == nil { - return providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Deployments == nil { - return providerUtils.NewConfigurationError("deployments not set", provider.GetProviderKey()) - } - - return nil -} - -// validateKeyConfigForFiles validates key config for file/batch APIs, which only -// require a configured Azure endpoint (no per-model deployments needed). -func (provider *AzureProvider) validateKeyConfigForFiles(key schemas.Key) *schemas.BifrostError { - if key.AzureKeyConfig == nil { - return providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - return nil -} - -func (provider *AzureProvider) getModelDeployment(key schemas.Key, model string) (string, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return "", providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Deployments != nil { - if deployment, ok := key.AzureKeyConfig.Deployments[model]; ok { - return deployment, nil - } - } - return "", providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", model), provider.GetProviderKey()) -} - // FileUpload uploads a file to Azure OpenAI. 
func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } if request.Purpose == "" { - return nil, providerUtils.NewBifrostOperationError("purpose is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("purpose is required", nil) } // Get API version @@ -1990,7 +1644,7 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem // Add purpose field if err := writer.WriteField("purpose", string(request.Purpose)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err) } // Add file field @@ -2000,14 +1654,14 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, 
providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -2044,13 +1698,12 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.FileUploadRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var openAIResp openai.OpenAIFileResponse @@ -2061,17 +1714,15 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem return nil, bifrostErr } - return openAIResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // FileList lists files from all provided Azure keys and aggregates results. // FileList lists files using serial pagination across keys. // Exhausts all pages from one key before moving to the next. 
func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2080,7 +1731,7 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2091,18 +1742,9 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } - // Validate key config - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2150,13 +1792,12 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.FileListRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != 
nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp openai.OpenAIFileListResponse @@ -2191,9 +1832,7 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2208,7 +1847,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2216,11 +1855,6 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2263,8 +1897,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2276,7 +1909,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] wait() 
fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2302,14 +1935,12 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] // FileDelete deletes a file from Azure OpenAI by trying each key until successful. func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2317,11 +1948,6 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2364,8 +1990,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = 
openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2381,9 +2006,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2393,7 +2016,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2416,9 +2039,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc Object: openAIResp.Object, Deleted: openAIResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2438,24 +2059,17 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc // FileContent downloads file content from Azure OpenAI by trying each key until found. 
func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation") } var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2497,8 +2111,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileContentRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2510,7 +2123,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2530,9 +2143,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: 
schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2543,12 +2154,6 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s // BatchCreate creates a new batch job on Azure OpenAI. // Azure Batch API uses the same format as OpenAI but with Azure-specific URL patterns. func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - - providerName := provider.GetProviderKey() - inputFileID := request.InputFileID // If no file_id provided but inline requests are available, upload them first @@ -2556,12 +2161,11 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Convert inline requests to JSONL format jsonlData, err := openai.ConvertRequestsToJSONL(request.Requests) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Upload the file with purpose "batch" uploadResp, bifrostErr := provider.FileUpload(ctx, key, &schemas.BifrostFileUploadRequest{ - Provider: schemas.Azure, File: jsonlData, Filename: "batch_requests.jsonl", Purpose: "batch", @@ -2575,7 +2179,7 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Validate that we have a file ID (either provided or uploaded) if inputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil) } // Get API version @@ -2622,7 
+2226,7 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche jsonData, err := providerUtils.MarshalSorted(openAIReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } req.SetBody(jsonData) @@ -2635,13 +2239,12 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.BatchCreateRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } var openAIResp openai.OpenAIBatchResponse @@ -2652,25 +2255,24 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return openAIResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs from all provided Azure keys and aggregates results. 
// BatchList lists batch jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation") } // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2681,18 +2283,9 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } - // Validate key config - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2738,13 +2331,12 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, 
schemas.BatchListRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp openai.OpenAIBatchListResponse @@ -2757,7 +2349,7 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data)) var lastBatchID string for _, batch := range openAIResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) lastBatchID = batch.ID } @@ -2770,9 +2362,7 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2784,14 +2374,12 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch // BatchRetrieve retrieves a specific batch job from Azure OpenAI by trying each key until found. 
func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2799,11 +2387,6 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2846,8 +2429,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2859,7 +2441,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2877,8 
+2459,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := openAIResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -2887,14 +2468,12 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ // BatchCancel cancels a batch job on Azure OpenAI by trying each key until successful. func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2902,11 +2481,6 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2949,8 +2523,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != 
fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2962,7 +2535,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2987,9 +2560,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s CancellingAt: openAIResp.CancellingAt, CancelledAt: openAIResp.CancelledAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3023,8 +2594,6 @@ func (provider *AzureProvider) BatchDelete(ctx *schemas.BifrostContext, keys []s // BatchResults retrieves batch results from Azure OpenAI by trying each key until successful. // For Azure (like OpenAI), batch results are obtained by downloading the output_file_id. 
func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output_file_id (using all keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -3035,7 +2604,7 @@ func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys [] } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil) } // Download the output file content (using all keys) @@ -3064,9 +2633,7 @@ func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys [] BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: fileContentResp.ExtraFields.Latency, + Latency: fileContentResp.ExtraFields.Latency, }, } @@ -3133,14 +2700,6 @@ func (provider *AzureProvider) Passthrough( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - url := provider.buildPassthroughURL(key, req.Path, req.RawQuery) fasthttpReq := fasthttp.AcquireRequest() @@ -3177,7 +2736,7 @@ func (provider *AzureProvider) Passthrough( 
body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } // Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body. @@ -3193,9 +2752,6 @@ func (provider *AzureProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -3213,14 +2769,6 @@ func (provider *AzureProvider) PassthroughStream( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - url := provider.buildPassthroughURL(key, req.Path, req.RawQuery) fasthttpReq := fasthttp.AcquireRequest() @@ -3267,9 +2815,9 @@ func (provider *AzureProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := 
providerUtils.ExtractProviderResponseHeaders(resp) @@ -3277,21 +2825,13 @@ func (provider *AzureProvider) PassthroughStream( rawBodyStream := resp.BodyStream() if rawBodyStream == nil { providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError( - "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + return nil, providerUtils.NewBifrostOperationError("provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body")) } bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(rawBodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx)) stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields) } @@ -3302,9 +2842,9 @@ func (provider *AzureProvider) PassthroughStream( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -3353,7 +2893,7 @@ func (provider *AzureProvider) PassthroughStream( } 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/azure/azure_test.go b/core/providers/azure/azure_test.go index 5727cb343d..32ec72040b 100644 --- a/core/providers/azure/azure_test.go +++ b/core/providers/azure/azure_test.go @@ -26,12 +26,12 @@ func TestAzure(t *testing.T) { testConfig := llmtests.ComprehensiveTestConfig{ Provider: schemas.Azure, - ChatModel: "gpt-4o-backup", - PromptCachingModel: "gpt-4o-backup", + ChatModel: "gpt-4o", + PromptCachingModel: "gpt-4o", VisionModel: "gpt-4o", ChatAudioModel: "gpt-4o-mini-audio-preview", Fallbacks: []schemas.Fallback{ - {Provider: schemas.Azure, Model: "gpt-4o-backup"}, + {Provider: schemas.Azure, Model: "gpt-4o"}, }, TextModel: "", // Azure doesn't support text completion in newer models EmbeddingModel: "text-embedding-ada-002", @@ -60,7 +60,7 @@ func TestAzure(t *testing.T) { Embedding: true, ListModels: true, Reasoning: true, - ChatAudio: true, + ChatAudio: false, Transcription: false, // Disabled for azure because of 3 calls/minute quota TranscriptionStream: false, // Not properly supported yet by Azure SpeechSynthesis: false, // Disabled for azure because of 3 calls/minute quota diff --git a/core/providers/azure/files.go b/core/providers/azure/files.go index 4c7ce174f8..d008b146de 100644 --- a/core/providers/azure/files.go +++ b/core/providers/azure/files.go @@ -24,7 +24,7 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && 
key.AzureKeyConfig.TenantID.GetValue() != "" { cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), key.AzureKeyConfig.ClientSecret.GetValue()) if err != nil { - return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err) } scopes := getAzureScopes(key.AzureKeyConfig.Scopes) @@ -33,11 +33,11 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R Scopes: scopes, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("failed to get Azure access token", err) } if token.Token == "" { - return providerUtils.NewBifrostOperationError("Azure access token is empty", fmt.Errorf("token is empty"), schemas.Azure) + return providerUtils.NewBifrostOperationError("azure access token is empty", fmt.Errorf("token is empty")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) @@ -68,16 +68,16 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R cred, err := provider.getOrCreateDefaultAzureCredential() if err != nil { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err) } token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) if err != nil { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err) } if token.Token == "" { - return 
providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty"), schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) @@ -110,9 +110,7 @@ func (r *AzureFileResponse) ToBifrostFileUploadResponse(providerName schemas.Mod StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/azure/models.go b/core/providers/azure/models.go index d5ff81229a..5daca3836d 100644 --- a/core/providers/azure/models.go +++ b/core/providers/azure/models.go @@ -1,65 +1,13 @@ package azure import ( - "slices" + "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// findMatchingAllowedModel finds a matching item in a slice, considering both -// exact match and base model matches (ignoring version suffixes). -// Returns the matched item from the slice if found, empty string otherwise. -// If matched via base model, returns the item from slice (not the value parameter). 
-func findMatchingAllowedModel(slice []string, value string) string { - // First check exact match - if slices.Contains(slice, value) { - return value - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Return the item from slice (not value) to use the actual name from allowedModels - for _, item := range slice { - if schemas.SameBaseModel(item, value) { - return item - } - } - return "" -} - -// findDeploymentMatch finds a matching deployment value in the deployments map, -// considering both exact match and base model matches (ignoring version suffixes). -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, modelID string) (deploymentValue, alias string) { - // Check exact match first (by alias/key) - if deployment, ok := deployments[modelID]; ok { - return deployment, modelID - } - - // Check exact match by deployment value - for aliasKey, depValue := range deployments { - if depValue == modelID { - return depValue, aliasKey - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - for aliasKey, deploymentValue := range deployments { - // Check if modelID's base matches deploymentValue's base - if schemas.SameBaseModel(deploymentValue, modelID) { - return deploymentValue, aliasKey - } - // Also check if modelID's base matches alias's base (for cases where alias is used as deployment) - if schemas.SameBaseModel(aliasKey, modelID) { - return deploymentValue, aliasKey - } - } - return "", "" -} - -func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AzureListModelsResponse) 
ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -68,111 +16,36 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode Data: make([]schemas.Model, 0, len(response.Data)), } - includedModels := make(map[string]bool) - for _, model := range response.Data { - modelID := model.ID - matchedAllowedModel := "" - deploymentValue := "" - deploymentAlias := "" - - // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean "allow all" for that dimension - // Check considering base model matches (ignoring version suffixes) - shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID) - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) - } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { - // Only allowedModels is present: filter if model is not in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID) - shouldFilter = matchedAllowedModel == "" - } else if !unfiltered && len(deployments) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, 
model.ID) - shouldFilter = deploymentValue == "" - } - // If both are empty, shouldFilter remains false (allow all) - - if shouldFilter { - continue - } - - // Use the matched name from allowedModels or deployments (like Anthropic) - // Priority: deployment value > matched allowedModel > original model.ID - if deploymentValue != "" { - modelID = deploymentValue - } else if matchedAllowedModel != "" { - modelID = matchedAllowedModel - } - - if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ID, modelID, deploymentAlias, matchedAllowedModel) { - continue - } - - modelEntry := schemas.Model{ - ID: string(schemas.Azure) + "/" + modelID, - Created: schemas.Ptr(model.CreatedAt), - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(schemas.Azure) + "/" + deploymentAlias - modelEntry.Deployment = schemas.Ptr(deploymentValue) - includedModels[deploymentAlias] = true - } else { - includedModels[modelID] = true - } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Azure, + MatchFns: providerUtils.DefaultMatchFns(), } - - // Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - if includedModels[alias] { - continue - } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { - continue - } - if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Azure) + "/" + alias, - Name: schemas.Ptr(alias), - Deployment: schemas.Ptr(deploymentValue), - }) - 
includedModels[alias] = true - } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(schemas.Azure) + "/" + result.ResolvedID, + Created: schemas.Ptr(model.CreatedAt), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Azure) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/azure/utils.go b/core/providers/azure/utils.go index 49d1db8de3..20f216c19e 100644 --- a/core/providers/azure/utils.go +++ b/core/providers/azure/utils.go @@ -9,7 +9,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { +func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader — skip all body building // (matches CheckContextAndGetRequestBody guard). 
if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -27,24 +27,24 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens)) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Replace model with deployment jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Delete fallbacks field jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { @@ -52,10 +52,10 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s request.Model = deployment reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, 
convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } if isStreaming { @@ -68,7 +68,7 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s // Marshal struct to JSON bytes, preserving field order jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err)) } } diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 2a181068a2..6b7cf700f0 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -26,7 +26,6 @@ import ( "github.com/bytedance/sonic" "github.com/google/uuid" "github.com/maximhq/bifrost/core/providers/anthropic" - "github.com/maximhq/bifrost/core/providers/cohere" providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -222,7 +221,7 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue())) } else { // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil { return nil, 0, nil, err } } @@ 
@@ -245,10 +244,10 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js
 	// Check for timeout first using net.Error before checking net.OpError
 	var netErr net.Error
 	if errors.As(err, &netErr) && netErr.Timeout() {
-		return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
+		return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
 	if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
+		return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
 	// Check for DNS lookup and network errors after timeout checks
 	var opErr *net.OpError
@@ -349,7 +348,7 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros
 	if key.Value.GetValue() != "" {
 		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue()))
 	} else {
-		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil {
+		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil {
			return nil, 0, nil, err
		}
	}
@@ -370,10 +369,10 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros
 	}
 	var netErr net.Error
 	if errors.As(err, &netErr) && netErr.Timeout() {
-		return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
+		return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
 	if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
+		return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
 	var opErr *net.OpError
 	var dnsErr *net.DNSError
@@ -420,15 +419,9 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros
 // makeStreamingRequest creates a streaming request to Bedrock's API.
 // It formats the request, sends it to Bedrock, and returns the response.
 // Returns the response body and an error if the request fails.
-func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, string, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, "", providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
+func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, *schemas.BifrostError) {
 	// Format the path with proper model identifier for streaming
-	path, deployment := provider.getModelPath(action, model, key)
+	path := provider.getModelPath(action, model, key)
 
 	region := DefaultBedrockRegion
 	if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" {
@@ -438,7 +431,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 	// Create HTTP request for streaming
 	req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonData))
 	if reqErr != nil {
-		return nil, deployment, providerUtils.NewBifrostOperationError("error creating request", reqErr, providerName)
+		return nil, providerUtils.NewBifrostOperationError("error creating request", reqErr)
 	}
 
 	// Set any extra headers from network config
@@ -457,8 +450,8 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 	} else {
 		req.Header.Set("Accept", "application/vnd.amazon.eventstream")
 		// Sign the request using either explicit credentials or IAM role authentication
-		if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
-			return nil, deployment, err
+		if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil {
+			return nil, err
 		}
 	}
 
@@ -466,7 +459,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 	resp, respErr := provider.client.Do(req)
 	if respErr != nil {
 		if errors.Is(respErr, context.Canceled) {
-			return nil, deployment, &schemas.BifrostError{
+			return nil, &schemas.BifrostError{
 				IsBifrostError: false,
 				Error: &schemas.ErrorField{
 					Type: schemas.Ptr(schemas.RequestCancelled),
@@ -478,35 +471,29 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 		// Check for timeout first using net.Error before checking net.OpError
 		var netErr net.Error
 		if errors.As(respErr, &netErr) && netErr.Timeout() {
-			return nil, deployment, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr, providerName)
+			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr)
 		}
 		if errors.Is(respErr, http.ErrHandlerTimeout) || errors.Is(respErr, context.DeadlineExceeded) {
-			return nil, deployment, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr, providerName)
+			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr)
 		}
 		// Check for DNS lookup and network errors after timeout checks
 		var opErr *net.OpError
 		var dnsErr *net.DNSError
 		if errors.As(respErr, &opErr) || errors.As(respErr, &dnsErr) {
-			return nil, deployment, &schemas.BifrostError{
+			return nil, &schemas.BifrostError{
 				IsBifrostError: false,
 				Error: &schemas.ErrorField{
 					Message: schemas.ErrProviderNetworkError,
 					Error:   respErr,
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					Provider: providerName,
-				},
 			}
 		}
-		return nil, deployment, &schemas.BifrostError{
+		return nil, &schemas.BifrostError{
 			IsBifrostError: false,
 			Error: &schemas.ErrorField{
 				Message: schemas.ErrProviderDoRequest,
 				Error:   respErr,
 			},
-			ExtraFields: schemas.BifrostErrorExtraFields{
-				Provider: providerName,
-			},
 		}
 	}
@@ -517,10 +504,10 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 	if resp.StatusCode != http.StatusOK {
 		body, _ := io.ReadAll(resp.Body)
 		resp.Body.Close()
-		return nil, deployment, parseBedrockHTTPError(resp.StatusCode, resp.Header, body)
+		return nil, parseBedrockHTTPError(resp.StatusCode, resp.Header, body)
 	}
 
-	return resp, deployment, nil
+	return resp, nil
 }
 
 // signAWSRequest signs an HTTP request using AWS Signature Version 4.
@@ -537,7 +524,6 @@ func signAWSRequest(
 	externalID *schemas.EnvVar,
 	sessionName *schemas.EnvVar,
 	region, service string,
-	providerName schemas.ModelProvider,
 ) *schemas.BifrostError {
 	// Set required headers before signing (only if not already set)
 	if req.Header.Get("Content-Type") == "" {
@@ -552,7 +538,7 @@ func signAWSRequest(
 	if req.Body != nil {
 		bodyBytes, err := io.ReadAll(req.Body)
 		if err != nil {
-			return providerUtils.NewBifrostOperationError("error reading request body", err, providerName)
+			return providerUtils.NewBifrostOperationError("error reading request body", err)
 		}
 		// Restore the body for subsequent reads
 		req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -594,11 +580,10 @@ func signAWSRequest(
 		)
 	}
 	if err != nil {
-		return providerUtils.NewBifrostOperationError("failed to load aws config", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to load aws config", err)
 	}
 
 	if roleARN != nil && roleARN.GetValue() != "" {
-
 		extID := ""
 		if externalID != nil {
 			extID = externalID.GetValue()
 		}
@@ -653,12 +638,12 @@ func signAWSRequest(
 	// Get credentials
 	creds, err := cfg.Credentials.Retrieve(ctx)
 	if err != nil {
-		return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err)
 	}
 
 	// Sign the request with AWS Signature V4
 	if err := signer.SignHTTP(ctx, creds, req, bodyHash, service, region, time.Now()); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to sign request", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to sign request", err)
 	}
 
 	return nil
@@ -668,13 +653,7 @@ func signAWSRequest(
 // It retrieves all foundation models available in Amazon Bedrock for a specific key.
 func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
 	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	config := key.BedrockKeyConfig
-
 	region := DefaultBedrockRegion
 	if config.Region != nil && config.Region.GetValue() != "" {
 		region = config.Region.GetValue()
@@ -721,7 +700,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
 	} else {
 		// Sign the request using either explicit credentials or IAM role authentication
-		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
+		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil {
 			return nil, err
 		}
 	}
@@ -744,10 +723,10 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
 	// Check for timeout first using net.Error before checking net.OpError
 	var netErr net.Error
 	if errors.As(err, &netErr) && netErr.Timeout() {
-		return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName)
+		return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
 	if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName)
+		return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
 	// Check for DNS lookup and network errors after timeout checks
 	var opErr *net.OpError
@@ -795,9 +774,9 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
 	}
 
 	// Convert to Bifrost response
-	response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, key.BlacklistedModels, request.Unfiltered)
+	response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
 	if response == nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil)
 	}
 
 	response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
@@ -838,24 +817,17 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockTextCompletionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	path, deployment := provider.getModelPath("invoke", request.Model, key)
+	path := provider.getModelPath("invoke", request.Model, key)
 
 	body, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, path, key)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
@@ -867,29 +839,25 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key
 	// Handle model-specific response conversion
 	var bifrostResponse *schemas.BifrostTextCompletionResponse
 	switch {
-	case schemas.IsAnthropicModel(deployment):
+	case schemas.IsAnthropicModel(request.Model):
 		var response BedrockAnthropicTextResponse
 		if err := sonic.Unmarshal(body, &response); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("error parsing anthropic response", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("error parsing anthropic response", err)
 		}
 		bifrostResponse = response.ToBifrostTextCompletionResponse()
-	case schemas.IsMistralModel(deployment):
+	case schemas.IsMistralModel(request.Model):
 		var response BedrockMistralTextResponse
 		if err := sonic.Unmarshal(body, &response); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("error parsing mistral response", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("error parsing mistral response", err)
 		}
 		bifrostResponse = response.ToBifrostTextCompletionResponse()
 	default:
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model), providerName)
+		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model))
 	}
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.ModelDeployment = deployment
-	bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -902,7 +870,7 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		var rawResponse interface{}
 		if err := sonic.Unmarshal(body, &rawResponse); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("error parsing raw response", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("error parsing raw response", err)
 		}
 		bifrostResponse.ExtraFields.RawResponse = rawResponse
 	}
@@ -920,22 +888,17 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 	providerName := provider.GetProviderKey()
 
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockTextCompletionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "invoke-with-response-stream")
+	resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "invoke-with-response-stream")
 	if bifrostErr != nil {
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -952,9 +915,9 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1000,14 +963,9 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 					Message: schemas.ErrProviderNetworkError,
 					Error:   err,
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					RequestType:    schemas.TextCompletionStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-				},
 			}, responseChan, provider.logger)
 		} else {
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 		}
 		return
 	}
@@ -1038,15 +996,10 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 				Error: &schemas.ErrorField{
 					Message: fmt.Sprintf("%s stream %s: %s", providerName, excType, errMsg),
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					RequestType:    schemas.TextCompletionStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-				},
 			}, responseChan, provider.logger)
 		} else {
 			err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg)
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 		}
 		return
 	}
@@ -1058,18 +1011,14 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 			}
 			if err := sonic.Unmarshal(message.Payload, &chunkPayload); err != nil {
 				provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload))
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 				return
 			}
 
 			// Create BifrostStreamChunk response containing the raw model-specific JSON chunk
 			textResponse := &schemas.BifrostTextCompletionResponse{
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:     schemas.TextCompletionStreamRequest,
-					Provider:        providerName,
-					ModelRequested:  request.Model,
-					ModelDeployment: deployment,
-					Latency:         time.Since(startTime).Milliseconds(),
+					Latency: time.Since(startTime).Milliseconds(),
 					// Pass the raw JSON string from the chunk bytes
 					RawResponse: string(chunkPayload.Bytes),
 				},
@@ -1091,26 +1040,19 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	// Use centralized Bedrock converter
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockChatCompletionRequest(ctx, request)
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Format the path with proper model identifier
-	path, deployment := provider.getModelPath("converse", request.Model, key)
+	path := provider.getModelPath("converse", request.Model, key)
 
 	// Create the signed request
 	responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key)
@@ -1127,13 +1069,13 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 	// Parse the response using the new Bedrock type
 	if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Convert using the new response converter
 	bifrostResponse, err := bedrockResponse.ToBifrostChatResponse(ctx, request.Model)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Override finish reason for structured output
@@ -1147,10 +1089,6 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 	}
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.ModelDeployment = deployment
-	bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -1177,21 +1115,17 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 	if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
 		return nil, err
 	}
-
-	providerName := provider.GetProviderKey()
-
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockChatCompletionRequest(ctx, request)
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream")
+	resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream")
 	if bifrostErr != nil {
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -1202,15 +1136,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 	responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
 	providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
-
 	// Start streaming in a goroutine
 	go func() {
 		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1266,7 +1199,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 				break
 			}
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-			provider.logger.Warn("Error decoding %s EventStream message: %v", providerName, err)
+			provider.logger.Warn("Error decoding EventStream message: %v", err)
 			// Transport-level errors (stale/closed connection, unexpected EOF) are retryable.
 			// Use IsBifrostError:false so the retry gate in executeRequestWithRetries can retry.
 			if isStreamTransportError(err) {
@@ -1276,14 +1209,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 					Message: schemas.ErrProviderNetworkError,
 					Error:   err,
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					RequestType:    schemas.ChatCompletionStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-				},
 			}, responseChan, provider.logger)
 		} else {
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 		}
 		return
 	}
@@ -1299,7 +1227,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 				}
 			}
 			errMsg := string(message.Payload)
-			err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg)
+			err := fmt.Errorf("stream %s: %s", excType, errMsg)
 			// Retryable AWS exceptions must not set IsBifrostError:true — that would
 			// bypass the retry gate in executeRequestWithRetries. Instead emit
 			// IsBifrostError:false with the equivalent HTTP status code so the existing
@@ -1311,14 +1239,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 				Error: &schemas.ErrorField{
 					Message: err.Error(),
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					RequestType:    schemas.ChatCompletionStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-				},
 			}, responseChan, provider.logger)
 		} else {
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 		}
 		return
 	}
@@ -1328,7 +1251,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 			var streamEvent BedrockStreamEvent
 			if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil {
 				provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload))
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 				return
 			}
@@ -1407,12 +1330,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 						},
 					},
 					ExtraFields: schemas.BifrostResponseExtraFields{
-						RequestType:     schemas.ChatCompletionStreamRequest,
-						Provider:        providerName,
-						ModelRequested:  request.Model,
-						ModelDeployment: deployment,
-						ChunkIndex:      chunkIndex,
-						Latency:         time.Since(lastChunkTime).Milliseconds(),
+						ChunkIndex: chunkIndex,
+						Latency:    time.Since(lastChunkTime).Milliseconds(),
 					},
 				}
 				chunkIndex++
@@ -1429,11 +1348,6 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 			response, bifrostErr, _ := streamEvent.ToBifrostChatCompletionStream(streamState)
 			if bifrostErr != nil {
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					RequestType:    schemas.ChatCompletionStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 				return
@@ -1442,12 +1356,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 			response.ID = id
 			response.Model = request.Model
 			response.ExtraFields = schemas.BifrostResponseExtraFields{
-				RequestType:     schemas.ChatCompletionStreamRequest,
-				Provider:        providerName,
-				ModelRequested:  request.Model,
-				ModelDeployment: deployment,
-				ChunkIndex:      chunkIndex,
-				Latency:         time.Since(lastChunkTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency:    time.Since(lastChunkTime).Milliseconds(),
 			}
 			chunkIndex++
 			lastChunkTime = time.Now()
@@ -1466,8 +1376,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 		}
 
 		// Send final response
-		response := providerUtils.CreateBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, request.Model)
-		response.ExtraFields.ModelDeployment = deployment
+		response := providerUtils.CreateBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, request.Model, 0)
 		// Set raw request if enabled
 		if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 			providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonData)
@@ -1488,26 +1397,19 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	// Use centralized Bedrock converter
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockResponsesRequest(ctx, request)
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Format the path with proper model identifier
-	path, deployment := provider.getModelPath("converse", request.Model, key)
+	path := provider.getModelPath("converse", request.Model, key)
 
 	// Create the signed request
 	responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key)
@@ -1524,22 +1426,18 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche
 	// Parse the response using the new Bedrock type
 	if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Convert using the new response converter
 	bifrostResponse, err := bedrockResponse.ToBifrostResponsesResponse(ctx)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
-	bifrostResponse.Model = deployment
+	bifrostResponse.Model = request.Model
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.ModelDeployment = deployment
-	bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -1567,20 +1465,17 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockResponsesRequest(ctx, request)
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream")
+	resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream")
 	if bifrostErr != nil {
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -1597,9 +1492,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1620,7 +1515,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 		// Create stream state for stateful conversions
 		streamState := acquireBedrockResponsesStreamState()
-		streamState.Model = &deployment
+		streamState.Model = &request.Model
 		defer releaseBedrockResponsesStreamState(streamState)
 
 		// Check for structured output mode - if set, we need to intercept tool calls
@@ -1636,7 +1531,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 		lastChunkTime := startTime
 		decoder := eventstream.NewDecoder()
 		payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer
-
 		for {
 			// If context was cancelled/timed out, let defer handle it
 			if ctx.Err() != nil {
@@ -1654,12 +1548,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 				finalResponses := FinalizeBedrockStream(streamState, chunkIndex, usage)
 				for i, finalResponse := range finalResponses {
 					finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{
-						RequestType:     schemas.ResponsesStreamRequest,
-						Provider:        providerName,
-						ModelRequested:  request.Model,
-						ModelDeployment: deployment,
-						ChunkIndex:      chunkIndex,
-						Latency:         time.Since(lastChunkTime).Milliseconds(),
+						ChunkIndex: chunkIndex,
+						Latency:    time.Since(lastChunkTime).Milliseconds(),
 					}
 					chunkIndex++
 					lastChunkTime = time.Now()
@@ -1682,7 +1572,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 				break
 			}
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-			provider.logger.Warn("Error decoding %s EventStream message: %v", providerName, err)
+			provider.logger.Warn("Error decoding EventStream message: %v", err)
 			// Transport-level errors (stale/closed connection, unexpected EOF) are retryable.
 			// Use IsBifrostError:false so the retry gate in executeRequestWithRetries can retry.
 			if isStreamTransportError(err) {
@@ -1692,14 +1582,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 					Message: schemas.ErrProviderNetworkError,
 					Error:   err,
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-				},
 			}, responseChan, provider.logger)
 		} else {
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 		}
 		return
 	}
@@ -1715,7 +1600,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 				}
 			}
 			errMsg := string(message.Payload)
-			err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg)
+			err := fmt.Errorf("stream %s: %s", excType, errMsg)
 			// Retryable AWS exceptions must not set IsBifrostError:true — that would
 			// bypass the retry gate in executeRequestWithRetries. Instead emit
 			// IsBifrostError:false with the equivalent HTTP status code so the existing
@@ -1727,14 +1612,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 				Error: &schemas.ErrorField{
 					Message: err.Error(),
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-				},
 			}, responseChan, provider.logger)
 		} else {
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 		}
 		return
 	}
@@ -1744,7 +1624,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 			var streamEvent BedrockStreamEvent
 			if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil {
 				provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload))
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 				return
 			}
@@ -1800,12 +1680,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 					SequenceNumber: chunkIndex,
 					Delta:          &content,
 					ExtraFields: schemas.BifrostResponseExtraFields{
-						RequestType:     schemas.ResponsesStreamRequest,
-						Provider:        providerName,
-						ModelRequested:  request.Model,
-						ModelDeployment: deployment,
-						ChunkIndex:      chunkIndex,
-						Latency:         time.Since(lastChunkTime).Milliseconds(),
+						ChunkIndex: chunkIndex,
+						Latency:    time.Since(lastChunkTime).Milliseconds(),
 					},
 				}
 				chunkIndex++
@@ -1822,11 +1698,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
 			responses, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex, streamState)
 			if bifrostErr
!= nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1834,12 +1705,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po for _, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1865,15 +1732,10 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Determine model type modelType, err := DetermineEmbeddingModelType(request.Model) if err != nil { - return nil, providerUtils.NewConfigurationError(err.Error(), providerName) + return nil, providerUtils.NewConfigurationError(err.Error()) } // Convert request and execute based on model type @@ -1882,7 +1744,6 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche var latency time.Duration var providerResponseHeaders map[string]string var path string - var deployment string var jsonData []byte switch modelType { @@ -1892,12 +1753,11 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockTitanEmbeddingRequest(request) - }, - 
provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) case "cohere": @@ -1906,16 +1766,15 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockCohereEmbeddingRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) default: - return nil, providerUtils.NewConfigurationError("unsupported embedding model type", providerName) + return nil, providerUtils.NewConfigurationError("unsupported embedding model type") } if providerResponseHeaders != nil { @@ -1924,32 +1783,40 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche if bifrostError != nil { return nil, providerUtils.EnrichError(ctx, bifrostError, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - // Parse response based on model type var bifrostResponse *schemas.BifrostEmbeddingResponse switch modelType { case "titan": var titanResp BedrockTitanEmbeddingResponse if err := sonic.Unmarshal(rawResponse, &titanResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Titan embedding response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Titan embedding response", err), 
jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = titanResp.ToBifrostEmbeddingResponse() bifrostResponse.Model = request.Model case "cohere": - var cohereResp cohere.CohereEmbeddingResponse + var cohereResp BedrockCohereEmbeddingResponse if err := sonic.Unmarshal(rawResponse, &cohereResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + } + converted, convErr := cohereResp.ToBifrostEmbeddingResponse() + if convErr != nil { + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", convErr), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse = cohereResp.ToBifrostEmbeddingResponse() + bifrostResponse = converted bifrostResponse.Model = request.Model + // For embeddings_by_type responses preserve the raw Bedrock payload so the + // invoke-endpoint converter can return all encoding variants verbatim, since + // the internal BifrostEmbeddingResponse only has float32 and string fields. 
+ if cohereResp.ResponseType == "embeddings_by_type" { + var rawResponseData interface{} + if err := sonic.Unmarshal(rawResponse, &rawResponseData); err == nil { + bifrostResponse.ExtraFields.RawResponse = rawResponseData + } + } } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1975,26 +1842,16 @@ func (provider *BedrockProvider) Rerank(ctx *schemas.BifrostContext, key schemas return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - - deployment := strings.TrimSpace(resolveBedrockDeployment(request.Model, key)) - if deployment == "" { - return nil, providerUtils.NewConfigurationError("bedrock rerank model is empty", providerName) - } - if !strings.HasPrefix(deployment, "arn:") { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("bedrock rerank requires an ARN model identifier; got %q", deployment), providerName) + if !strings.HasPrefix(request.Model, "arn:") { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("bedrock rerank requires an ARN model identifier; got %q", request.Model)) } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToBedrockRerankRequest(request, deployment) + return ToBedrockRerankRequest(request, request.Model) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2018,10 +1875,6 @@ func (provider *BedrockProvider) Rerank(ctx *schemas.BifrostContext, key schemas bifrostResponse := 
response.ToBifrostRerankResponse(request.Documents, returnDocuments) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2056,37 +1909,34 @@ func (provider *BedrockProvider) TranscriptionStream(ctx *schemas.BifrostContext } // ImageGeneration generates images using Amazon Bedrock. -// Supports Titan Image Generator v1, Nova Canvas v1, and Titan Image Generator v2. +// Supports Titan Image Generator v1, Nova Canvas v1, Titan Image Generator v2, and Stability AI models. // Returns a BifrostImageGenerationResponse containing the generated images and any error that occurred. func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var rawResponse []byte var jsonData []byte var bifrostError *schemas.BifrostError var latency time.Duration var providerResponseHeaders map[string]string var path string - var deployment string + + path = provider.getModelPath("invoke", request.Model, key) jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { + if isStabilityAIModel(request.Model) { + return ToStabilityAIImageGenerationRequest(request) 
+ } return ToBedrockImageGenerationRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2099,19 +1949,15 @@ func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, ke var bifrostResponse *schemas.BifrostImageGenerationResponse var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image generation response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image generation response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = 
latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2136,33 +1982,36 @@ func (provider *BedrockProvider) ImageGenerationStream(ctx *schemas.BifrostConte } // ImageEdit performs image editing using Amazon Bedrock. -// Supports Titan Image Generator v1, Nova Canvas v1, and Titan Image Generator v2. -// Supports three edit types: INPAINTING, OUTPAINTING, and BACKGROUND_REMOVAL. +// Supports Titan Image Generator v1, Nova Canvas v1, Titan Image Generator v2 (three edit types: +// INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL), and Stability AI edit models (inpaint, outpaint, +// recolor, search-replace, erase-object, remove-bg, control-sketch, control-structure, style-guide, +// style-transfer, upscale-creative, upscale-conservative, upscale-fast). // Returns a BifrostImageGenerationResponse containing the edited images and any error that occurred. func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ImageEditRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var jsonData []byte var bifrostError *schemas.BifrostError + // Stability AI routing and task-type inference use the actual model ID. 
+ path := provider.getModelPath("invoke", request.Model, key) + jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockImageEditRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { + if isStabilityAIModel(request.Model) { + return ToStabilityAIImageEditRequest(request, request.Model) + } + return ToBedrockImageEditRequest(request) + }) if bifrostError != nil { return nil, bifrostError } // Make API request (same URL as image generation) - path, deployment := provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError := provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2174,20 +2023,16 @@ func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key sche // Parse response (reuse BedrockImageGenerationResponse) var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image edit response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image edit response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, 
provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert response and set metadata bifrostResponse := ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2219,11 +2064,6 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var jsonData []byte var bifrostError *schemas.BifrostError @@ -2232,14 +2072,13 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockImageVariationRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } // Make API request (same URL as image generation) - path, deployment := provider.getModelPath("invoke", request.Model, key) + path := provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError := provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2251,20 +2090,16 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key // Parse response (reuse BedrockImageGenerationResponse and ToBifrostImageGenerationResponse) var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - 
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image variation response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image variation response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert response and set metadata bifrostResponse := ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageVariationRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2323,13 +2158,6 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - provider.logger.Error("bedrock key config is is missing in file upload request") - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Get S3 bucket from storage config or extra params s3Bucket := "" s3Prefix := "" @@ -2351,7 +2179,7 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch if s3Bucket == "" { 
provider.logger.Error("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)") - return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil) } // Parse bucket name and optional prefix from s3Bucket (could be "bucket-name" or "s3://bucket-name/prefix/") @@ -2385,14 +2213,14 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch httpReq, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(request.File)) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } httpReq.Header.Set("Content-Type", "application/octet-stream") httpReq.ContentLength = int64(len(request.File)) // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { provider.logger.Error("error signing request: %s", err.Error.Message) return nil, err } @@ -2412,14 +2240,14 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { body, _ := io.ReadAll(resp.Body) provider.logger.Error("s3 upload failed: %d", resp.StatusCode) - return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 upload failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 upload failed: %s", string(body)), nil, resp.StatusCode, nil, nil) } // Return S3 URI as the file ID @@ -2436,9 +2264,7 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch StorageBackend: schemas.FileStorageS3, StorageURI: s3URI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2451,8 +2277,6 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc return nil, err } - providerName := provider.GetProviderKey() - // Get S3 bucket from storage config or extra params s3Bucket := "" s3Prefix := "" @@ -2474,7 +2298,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc } if s3Bucket == "" { - return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil) } bucketName, bucketPrefix := parseS3URI(s3Bucket) @@ -2485,7 +2309,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2496,10 +2320,6 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -2526,14 +2346,11 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } // Sign request for S3 - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); bifrostErr != nil { return nil, bifrostErr } @@ -2552,23 +2369,23 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } body, err := io.ReadAll(resp.Body) 
resp.Body.Close() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error reading response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error reading response", err) } if resp.StatusCode != http.StatusOK { - return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 list failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 list failed: %s", string(body)), nil, resp.StatusCode, nil, nil) } // Parse S3 ListObjectsV2 XML response var listResp S3ListObjectsResponse if err := parseS3ListResponse(body, &listResp); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing S3 response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing S3 response", err) } // Convert files to Bifrost format @@ -2600,9 +2417,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2618,25 +2433,18 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for 
_, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2648,12 +2456,12 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys httpReq, err := http.NewRequestWithContext(ctx, http.MethodHead, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2673,13 +2481,13 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } if resp.StatusCode != http.StatusOK { resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 HEAD failed with status %d", resp.StatusCode), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 HEAD failed with status %d", resp.StatusCode), nil, resp.StatusCode, nil, nil) continue 
} @@ -2709,9 +2517,7 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys StorageBackend: schemas.FileStorageS3, StorageURI: request.FileID, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2725,25 +2531,18 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2755,12 +2554,12 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] httpReq, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, 
key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2780,7 +2579,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } @@ -2788,7 +2587,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 DELETE failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 DELETE failed: %s", string(body)), nil, resp.StatusCode, nil, nil) continue } @@ -2799,9 +2598,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2815,25 +2612,18 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) 
is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2845,12 +2635,12 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2870,21 +2660,21 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 GET failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 GET failed: %s", string(body)), nil, resp.StatusCode, nil, nil) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading S3 object content", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading S3 object content", err) continue } @@ -2898,9 +2688,7 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ Content: body, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2915,13 +2703,6 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - provider.logger.Error("bedrock key config is not provided") - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Require RoleArn in extra params roleArn := "" // First we will honor the role_arn coming from the client side if present @@ -2932,14 +2713,14 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // If its empty then we will honor the role_arn from the key config if roleArn == "" { - if key.BedrockKeyConfig.ARN != nil { - roleArn = key.BedrockKeyConfig.ARN.GetValue() + if key.BedrockKeyConfig.RoleARN != nil { + roleArn = key.BedrockKeyConfig.RoleARN.GetValue() } } // And if still we don't get role ARN if roleArn 
== "" { provider.logger.Error("role_arn is required for Bedrock batch API (provide in extra_params)") - return nil, providerUtils.NewBifrostOperationError("role_arn is required for Bedrock batch API (provide in extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("role_arn is required for Bedrock batch API (provide in extra_params)", nil) } // Get output S3 URI from extra params outputS3Uri := "" @@ -2950,24 +2731,12 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } if outputS3Uri == "" { provider.logger.Error("output_s3_uri is required for Bedrock batch API (provide in extra_params)") - return nil, providerUtils.NewBifrostOperationError("output_s3_uri is required for Bedrock batch API (provide in extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("output_s3_uri is required for Bedrock batch API (provide in extra_params)", nil) } if request.Model == nil { provider.logger.Error("model is required for Bedrock batch API") - return nil, providerUtils.NewBifrostOperationError("model is required for Bedrock batch API", nil, providerName) - } - - // Get model ID - - var modelID *string - if key.BedrockKeyConfig.Deployments != nil && request.Model != nil { - if deployment, ok := key.BedrockKeyConfig.Deployments[*request.Model]; ok { - modelID = schemas.Ptr(deployment) - } - } - if modelID == nil { - modelID = request.Model + return nil, providerUtils.NewBifrostOperationError("model is required for Bedrock batch API", nil) } // Generate job name @@ -2995,9 +2764,9 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // Convert inline requests to Bedrock JSONL format - jsonlData, err := ConvertBedrockRequestsToJSONL(request.Requests, modelID) + jsonlData, err := ConvertBedrockRequestsToJSONL(request.Requests, request.Model) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, 
providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Generate S3 key for the input file @@ -3017,7 +2786,6 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc bucket, s3Key, jsonlData, - providerName, ); bifrostErr != nil { return nil, bifrostErr } @@ -3028,13 +2796,13 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc // Validate that we have an input file ID (either provided or uploaded) if inputFileID == "" { provider.logger.Error("either input_file_id (S3 URI) or requests array is required for Bedrock batch API") - return nil, providerUtils.NewBifrostOperationError("either input_file_id (S3 URI) or requests array is required for Bedrock batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id (S3 URI) or requests array is required for Bedrock batch API", nil) } // Build request bedrockReq := &BedrockBatchJobRequest{ JobName: jobName, - ModelID: modelID, + ModelID: request.Model, RoleArn: roleArn, InputDataConfig: BedrockInputDataConfig{ S3InputDataConfig: BedrockS3InputDataConfig{ @@ -3059,7 +2827,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc jsonData, err := providerUtils.MarshalSorted(bedrockReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } sendBackRawRequest := provider.sendBackRawRequest @@ -3074,11 +2842,11 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc reqURL := fmt.Sprintf("https://bedrock.%s.amazonaws.com/model-invocation-job", region) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewBuffer(jsonData)) if err != nil { - return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError("error creating request", err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error creating request", err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } @@ -3097,13 +2865,13 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc }, }, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error reading response", err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error reading response", err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } if 
resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { @@ -3112,7 +2880,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc var bedrockResp BedrockBatchJobResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName), jsonData, body, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonData, body, sendBackRawRequest, sendBackRawResponse) } // AWS CreateModelInvocationJob only returns jobArn, not status or other details. @@ -3129,9 +2897,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc InputFileID: inputFileID, Status: schemas.BatchStatusValidating, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3144,9 +2910,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc Status: retrieveResp.Status, CreatedAt: retrieveResp.CreatedAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3164,12 +2928,10 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s return nil, err } - providerName := provider.GetProviderKey() - // Initialize serial pagination helper (Bedrock uses PageToken for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.PageToken, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid 
pagination cursor", err) } // Get current key to query @@ -3180,17 +2942,9 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3213,11 +2967,11 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } // Sign request - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); bifrostErr != nil { return nil, bifrostErr } @@ -3236,13 +2990,13 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error reading response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error reading response", err) } if resp.StatusCode != http.StatusOK { @@ -3251,7 +3005,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s var bedrockResp BedrockBatchJobListResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert batches to Bifrost format @@ -3296,9 +3050,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -3338,7 +3090,7 @@ func (provider *BedrockProvider) fetchBatchManifest(ctx *schemas.BifrostContext, } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { provider.logger.Error("failed to sign manifest request: %v", err) return nil } @@ -3376,19 +3128,12 @@ func (provider *BedrockProvider) 
BatchRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3400,12 +3145,12 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { lastErr = err continue } @@ -3425,14 +3170,14 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading response", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading response", err) continue } @@ -3443,7 +3188,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys var bedrockResp BedrockBatchJobResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) continue } @@ -3462,9 +3207,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys Status: ToBifrostBatchStatus(bedrockResp.Status), Metadata: metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3520,19 +3263,12 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3544,12 +3280,12 @@ func (provider *BedrockProvider) BatchCancel(ctx 
*schemas.BifrostContext, keys [ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { lastErr = err continue } @@ -3569,14 +3305,14 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading response", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading response", err) continue } @@ -3599,9 +3335,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ Object: "batch", Status: schemas.BatchStatusCancelling, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: totalLatency.Milliseconds(), + Latency: totalLatency.Milliseconds(), }, }, nil } @@ -3611,9 +3345,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ Object: "batch", 
Status: retrieveResp.Status, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3634,8 +3366,6 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output S3 URI prefix (using all keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -3646,7 +3376,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output S3 URI is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output S3 URI is empty (batch may not be completed)", nil) } outputS3URI := *batchResp.OutputFileID @@ -3688,7 +3418,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys if directErr != nil { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to access batch results at %s: listing failed and direct access failed", outputS3URI), - nil, providerName) + nil) } // Direct download succeeded, parse the content @@ -3697,9 +3427,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: fileContentResp.ExtraFields.Latency, + Latency: fileContentResp.ExtraFields.Latency, }, } if len(parseErrors) > 0 { @@ -3732,9 +3460,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys BatchID: request.BatchID, Results: 
allResults, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: totalLatency, + Latency: totalLatency, }, } @@ -3745,26 +3471,14 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys return batchResultsResp, nil } -func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) (string, string) { - deployment := resolveBedrockDeployment(model, key) - // Default: use model/deployment directly - path := fmt.Sprintf("%s/%s", deployment, basePath) +func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) string { + path := fmt.Sprintf("%s/%s", model, basePath) // If ARN is present, Bedrock expects the ARN-scoped identifier if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.ARN != nil && key.BedrockKeyConfig.ARN.GetValue() != "" { - encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", key.BedrockKeyConfig.ARN.GetValue(), deployment)) + encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", key.BedrockKeyConfig.ARN.GetValue(), model)) path = fmt.Sprintf("%s/%s", encodedModelIdentifier, basePath) } - return path, deployment -} - -func resolveBedrockDeployment(model string, key schemas.Key) string { - deployment := model - if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.Deployments != nil { - if mapped, ok := key.BedrockKeyConfig.Deployments[model]; ok && mapped != "" { - deployment = mapped - } - } - return deployment + return path } func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { @@ -3772,16 +3486,10 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, 
providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Convert to Bedrock Converse format using the existing responses converter converseReq, convErr := ToBedrockResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, convErr) } // Wrap in the CountTokens request envelope @@ -3790,11 +3498,11 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc jsonData, err := providerUtils.MarshalSorted(countTokensReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Format the path with proper model identifier - path, deployment := provider.getModelPath("count-tokens", request.Model, key) + path := provider.getModelPath("count-tokens", request.Model, key) // Send the request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -3805,15 +3513,11 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc if isCountTokensUnsupported(bifrostErr) { estimated := estimateTokenCount(jsonData) return &schemas.BifrostCountTokensResponse{ - Model: deployment, + Model: request.Model, InputTokens: estimated, TotalTokens: &estimated, Object: "response.input_tokens", ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.CountTokensRequest, - ModelRequested: request.Model, - ModelDeployment: deployment, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3836,15 +3540,10 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc } // Convert to Bifrost format - response := 
bedrockResponse.ToBifrostCountTokensResponse(deployment) + response := bedrockResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequest } diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index dc4b4d6d69..4e1c659cd5 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -181,22 +181,22 @@ func TestBedrock(t *testing.T) { {Provider: schemas.Bedrock, Model: "claude-4-sonnet"}, {Provider: schemas.Bedrock, Model: "claude-4.5-sonnet"}, }, - EmbeddingModel: "cohere.embed-v4:0", - RerankModel: rerankModelARN, - ReasoningModel: "claude-4.5-sonnet", - PromptCachingModel: "claude-4.5-sonnet", - ImageEditModel: "amazon.nova-canvas-v1:0", - ImageVariationModel: "amazon.nova-canvas-v1:0", + EmbeddingModel: "cohere.embed-v4:0", + RerankModel: rerankModelARN, + ReasoningModel: "claude-4.5-sonnet", + PromptCachingModel: "claude-4.5-sonnet", + ImageEditModel: "amazon.nova-canvas-v1:0", + ImageVariationModel: "amazon.nova-canvas-v1:0", InterleavedThinkingModel: "global.anthropic.claude-opus-4-5-20251101-v1:0", - BatchExtraParams: batchExtraParams, - FileExtraParams: fileExtraParams, + BatchExtraParams: batchExtraParams, + FileExtraParams: fileExtraParams, Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + 
MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, End2EndToolCalling: true, @@ -4175,17 +4175,17 @@ func TestToBedrockInvokeMessagesStreamResponse_NoDuplicateContentBlockStop(t *te { Type: schemas.ResponsesStreamResponseTypeOutputTextDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, { Type: schemas.ResponsesStreamResponseTypeContentPartDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, { Type: schemas.ResponsesStreamResponseTypeOutputItemDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, } diff --git a/core/providers/bedrock/chat.go b/core/providers/bedrock/chat.go index 71e7890935..6459df377b 100644 --- a/core/providers/bedrock/chat.go +++ b/core/providers/bedrock/chat.go @@ -247,8 +247,6 @@ func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Conte Usage: usage, Created: int(time.Now().Unix()), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Bedrock, }, } diff --git a/core/providers/bedrock/embedding.go b/core/providers/bedrock/embedding.go index d9981a1e4d..2e2875e9cf 100644 --- a/core/providers/bedrock/embedding.go +++ b/core/providers/bedrock/embedding.go @@ -1,10 +1,10 @@ package bedrock import ( + "encoding/json" "fmt" "strings" - "github.com/maximhq/bifrost/core/providers/cohere" "github.com/maximhq/bifrost/core/schemas" ) @@ -19,11 +19,6 @@ func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) return nil, fmt.Errorf("no input text provided for 
embedding") } - // Validate dimensions parameter - Titan models do not support it - if bifrostReq.Params != nil && bifrostReq.Params.Dimensions != nil { - return nil, fmt.Errorf("amazon Titan embedding models do not support custom dimensions parameter") - } - titanReq := &BedrockTitanEmbeddingRequest{} // Set input text @@ -36,8 +31,26 @@ func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) } titanReq.InputText = embeddingText } + if bifrostReq.Params != nil { - titanReq.ExtraParams = bifrostReq.Params.ExtraParams + titanReq.Dimensions = bifrostReq.Params.Dimensions + if normalize, ok := bifrostReq.Params.ExtraParams["normalize"]; ok { + if b, ok := normalize.(bool); ok { + titanReq.Normalize = &b + } + } + // Forward remaining extra params (excluding normalize which is now a first-class field) + if len(bifrostReq.Params.ExtraParams) > 0 { + extra := make(map[string]interface{}) + for k, v := range bifrostReq.Params.ExtraParams { + if k != "normalize" { + extra[k] = v + } + } + if len(extra) > 0 { + titanReq.ExtraParams = extra + } + } } return titanReq, nil @@ -69,20 +82,81 @@ func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *sch return bifrostResponse } -// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format -// Reuses the Cohere converter since the format is identical -func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*cohere.CohereEmbeddingRequest, error) { +// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format. +// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the request body. 
+func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost embedding request is nil") } + if bifrostReq.Input == nil { + return nil, fmt.Errorf("no input provided for embedding") + } - // Reuse Cohere's converter - the format is identical for Bedrock - cohereReq := cohere.ToCohereEmbeddingRequest(bifrostReq) - if cohereReq == nil { - return nil, fmt.Errorf("failed to convert to Cohere embedding request") + req := &BedrockCohereEmbeddingRequest{} + + // Map texts + if bifrostReq.Input.Text != nil { + req.Texts = []string{*bifrostReq.Input.Text} + } else if len(bifrostReq.Input.Texts) > 0 { + req.Texts = bifrostReq.Input.Texts } - return cohereReq, nil + if bifrostReq.Params != nil { + extra := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams)) + for k, v := range bifrostReq.Params.ExtraParams { + extra[k] = v + } + + if v, ok := extra["input_type"]; ok { + if s, ok := v.(string); ok { + req.InputType = s + delete(extra, "input_type") + } + } + if v, ok := extra["truncate"]; ok { + if s, ok := v.(string); ok { + req.Truncate = &s + delete(extra, "truncate") + } + } + if v, ok := extra["embedding_types"]; ok { + if ss, ok := v.([]string); ok { + req.EmbeddingTypes = ss + delete(extra, "embedding_types") + } + } + if v, ok := extra["images"]; ok { + if ss, ok := v.([]string); ok { + req.Images = ss + delete(extra, "images") + } + } + if v, ok := extra["inputs"]; ok { + if inputs, ok := v.([]BedrockCohereEmbeddingInput); ok { + req.Inputs = inputs + delete(extra, "inputs") + } + } + if v, ok := extra["max_tokens"]; ok { + switch n := v.(type) { + case int: + req.MaxTokens = &n + delete(extra, "max_tokens") + case float64: + i := int(n) + req.MaxTokens = &i + delete(extra, "max_tokens") + } + } + if bifrostReq.Params.Dimensions != nil { + req.OutputDimension = bifrostReq.Params.Dimensions + } + if len(extra) > 0 { + req.ExtraParams = 
extra + } + } + + return req, nil } // DetermineEmbeddingModelType determines the embedding model type from the model name @@ -96,3 +170,102 @@ func DetermineEmbeddingModelType(model string) (string, error) { return "", fmt.Errorf("unsupported embedding model: %s", model) } } + +// ToBifrostEmbeddingResponse converts a BedrockCohereEmbeddingResponse to Bifrost format. +// Bedrock returns embeddings as a raw [][]float32 when response_type is "embeddings_floats" +// (the default, when no embedding_types are requested), and as a typed object when +// response_type is "embeddings_by_type". +func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas.BifrostEmbeddingResponse, error) { + if r == nil { + return nil, fmt.Errorf("nil Bedrock Cohere embedding response") + } + + bifrostResponse := &schemas.BifrostEmbeddingResponse{Object: "list"} + + switch r.ResponseType { + case "embeddings_by_type": + // Object form: {"float": [[...]], "int8": [[...]], "uint8": [[...]], "binary": [[...]], "ubinary": [[...]], "base64": [...]} + var typed struct { + Float [][]float32 `json:"float"` + Base64 []string `json:"base64"` + Int8 [][]int8 `json:"int8"` + Uint8 [][]int32 `json:"uint8"` // int32 avoids []byte→base64 JSON issue + Binary [][]int8 `json:"binary"` + Ubinary [][]int32 `json:"ubinary"` // int32 avoids []byte→base64 JSON issue + } + if err := json.Unmarshal(r.Embeddings, &typed); err != nil { + return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err) + } + if typed.Float != nil { + for i, emb := range typed.Float { + float64Emb := make([]float64, len(emb)) + for j, v := range emb { + float64Emb[j] = float64(v) + } + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb}, + }) + } + } + if typed.Base64 != nil { + for i, emb := range typed.Base64 { + e := emb + bifrostResponse.Data = append(bifrostResponse.Data, 
schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingStr: &e}, + }) + } + } + for i, emb := range typed.Int8 { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb}, + }) + } + for i, emb := range typed.Binary { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb}, + }) + } + for i, emb := range typed.Uint8 { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb}, + }) + } + for i, emb := range typed.Ubinary { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb}, + }) + } + + default: + // Default / "embeddings_floats": raw array form [[...], [...]] + var floats [][]float32 + if err := json.Unmarshal(r.Embeddings, &floats); err != nil { + return nil, fmt.Errorf("error parsing embeddings_floats: %w", err) + } + for i, emb := range floats { + float64Emb := make([]float64, len(emb)) + for j, v := range emb { + float64Emb[j] = float64(v) + } + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb}, + }) + } + } + + return bifrostResponse, nil +} diff --git a/core/providers/bedrock/images.go b/core/providers/bedrock/images.go index 8c9ba9569c..dc3c76edd4 100644 --- a/core/providers/bedrock/images.go +++ b/core/providers/bedrock/images.go @@ -34,6 +34,61 @@ func mapQualityToBedrock(quality *string) *string { } } +// isStabilityAIModel returns true if the model is a Stability AI model (contains "stability.") +func 
isStabilityAIModel(model string) bool { + return strings.Contains(strings.ToLower(model), "stability.") +} + +// isPromptOnlyImageGenerationModel returns true for image generation models that use a flat +// {"prompt": "..."} payload (no taskType field). Covers Vertex Imagen and similar models. +// Stability AI is excluded here — it's handled separately because it also supports image edit. +func isPromptOnlyImageGenerationModel(model string) bool { + m := strings.ToLower(model) + return strings.Contains(m, "image") +} + +// ToStabilityAIImageGenerationRequest converts a Bifrost image generation request to the Stability AI +// flat request format used by Bedrock (stability.stable-image-* models). +func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) { + if request == nil { + return nil, fmt.Errorf("request is nil") + } + if request.Input == nil { + return nil, fmt.Errorf("request input is required") + } + + req := &StabilityAIImageGenerationRequest{ + Prompt: request.Input.Prompt, + } + + if request.Params != nil { + if request.Params.AspectRatio != nil { + req.AspectRatio = request.Params.AspectRatio + } + if request.Params.OutputFormat != nil { + req.OutputFormat = request.Params.OutputFormat + } + if request.Params.Seed != nil { + req.Seed = request.Params.Seed + } + if request.Params.NegativePrompt != nil { + req.NegativePrompt = request.Params.NegativePrompt + } + if request.Params.ExtraParams != nil { + // aspect_ratio may also arrive via ExtraParams if not in knownFields; skip if already set + if req.AspectRatio == nil { + if ar, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["aspect_ratio"]); ok { + delete(request.Params.ExtraParams, "aspect_ratio") + req.AspectRatio = ar + } + } + req.ExtraParams = request.Params.ExtraParams + } + } + + return req, nil +} + // ToBedrockImageGenerationRequest converts a Bifrost image generation request to a Bedrock image 
generation request func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) { if request == nil { @@ -41,7 +96,7 @@ func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequ } if request.Input == nil { - return nil, fmt.Errorf("request.Input is required") + return nil, fmt.Errorf("request input is required") } bedrockReq := &BedrockImageGenerationRequest{ @@ -98,7 +153,24 @@ func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequ } return bedrockReq, nil +} +// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to +// the native Bedrock invoke API response format used by Stability AI models. +// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas. +func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) { + if response == nil { + return nil, fmt.Errorf("response is nil") + } + result := &BedrockImageGenerationResponse{} + for _, d := range response.Data { + result.Images = append(result.Images, d.B64JSON) + } + if response.ImageGenerationResponseParameters != nil { + result.FinishReasons = response.ImageGenerationResponseParameters.FinishReasons + result.Seeds = response.ImageGenerationResponseParameters.Seeds + } + return result, nil } // ToBedrockImageVariationRequest converts a Bifrost image variation request to a Bedrock image variation request @@ -358,6 +430,292 @@ func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGener return config } +// getStabilityAITaskTypeFromParams maps the generic BifrostImageEditParameters.Type value +// to a Stability AI task type string. Returns "" if the value is not a recognized Stability AI task type. 
+func getStabilityAITaskTypeFromParams(t string) string { + switch strings.ToLower(t) { + case "inpainting", "inpaint": + return "inpaint" + case "outpainting", "outpaint": + return "outpaint" + case "background_removal", "remove_background": + return "remove-bg" + case "erase_object": + return "erase-object" + case "upscale_fast": + return "upscale-fast" + case "upscale_creative": + return "upscale-creative" + case "upscale_conservative": + return "upscale-conservative" + case "recolor": + return "recolor" + case "search_replace": + return "search-replace" + case "control_sketch": + return "control-sketch" + case "control_structure": + return "control-structure" + case "style_guide": + return "style-guide" + case "style_transfer": + return "style-transfer" + default: + return "" + } +} + +// getStabilityAIEditTaskType infers the Stability AI edit task from the model name. +// Returns an error if the model name does not match any known pattern. +func getStabilityAIEditTaskType(model string) (string, error) { + m := strings.ToLower(model) + switch { + case strings.Contains(m, "stable-creative-upscale"): + return "upscale-creative", nil + case strings.Contains(m, "stable-conservative-upscale"): + return "upscale-conservative", nil + case strings.Contains(m, "stable-fast-upscale"): + return "upscale-fast", nil + case strings.Contains(m, "stable-image-inpaint"): + return "inpaint", nil + case strings.Contains(m, "stable-outpaint"): + return "outpaint", nil + case strings.Contains(m, "stable-image-search-recolor"): + return "recolor", nil + case strings.Contains(m, "stable-image-search-replace"): + return "search-replace", nil + case strings.Contains(m, "stable-image-erase-object"): + return "erase-object", nil + case strings.Contains(m, "stable-image-remove-background"): + return "remove-bg", nil + case strings.Contains(m, "stable-image-control-sketch"): + return "control-sketch", nil + case strings.Contains(m, "stable-image-control-structure"): + return 
"control-structure", nil + case strings.Contains(m, "stable-image-style-guide"): + return "style-guide", nil + case strings.Contains(m, "stable-style-transfer"): + return "style-transfer", nil + default: + return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model) + } +} + +// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request +// format used by Bedrock edit models. Only fields valid for the detected task type are populated. +// deployment is the resolved model identifier (after applying any deployment alias mapping); it is +// used for task-type inference so that alias-mapped models route correctly. +func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) { + if request == nil || request.Input == nil { + return nil, fmt.Errorf("request or input is nil") + } + + var taskType string + if request.Params != nil && request.Params.Type != nil { + taskType = getStabilityAITaskTypeFromParams(*request.Params.Type) + } + if taskType == "" { + var err error + taskType, err = getStabilityAIEditTaskType(deployment) + if err != nil { + return nil, err + } + } + + req := &StabilityAIImageEditRequest{} + + // Image sourcing + if taskType == "style-transfer" { + if len(request.Input.Images) != 2 { + return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image") + } + if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 { + return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image") + } + initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image) + styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image) + req.InitImage = &initB64 + req.StyleImage = &styleB64 + } else { + if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 { + return nil, fmt.Errorf("at least one image is 
required") + } + imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image) + req.Image = &imageB64 + } + + // Common fields populated based on task allowlist + prompt := request.Input.Prompt + switch taskType { + case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure", + "style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer": + req.Prompt = &prompt + } + + // Negative prompt + if request.Params != nil && request.Params.NegativePrompt != nil { + switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch", + "control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer": + req.NegativePrompt = request.Params.NegativePrompt + } + } + + // Seed + if request.Params != nil && request.Params.Seed != nil { + switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch", + "control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer": + req.Seed = request.Params.Seed + } + } + + // Mask (from Params.Mask bytes) + if request.Params != nil && len(request.Params.Mask) > 0 { + switch taskType { + case "inpaint", "erase-object": + maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask) + req.Mask = &maskB64 + } + } + + // ExtraParams + if request.Params != nil { + // Typed OutputFormat takes priority over ExtraParams + if request.Params.OutputFormat != nil { + req.OutputFormat = request.Params.OutputFormat + } + + if request.Params.ExtraParams != nil { + ep := make(map[string]interface{}, len(request.Params.ExtraParams)) + for k, v := range request.Params.ExtraParams { + ep[k] = v + } + + // output_format — all tasks (fallback if not already set by typed field) + if req.OutputFormat == nil { + if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok { + delete(ep, "output_format") + req.OutputFormat = v + } + } + + // style_preset 
+ switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch", + "control-structure", "style-guide", "upscale-creative": + if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok { + delete(ep, "style_preset") + req.StylePreset = v + } + } + + // grow_mask + switch taskType { + case "inpaint", "recolor", "search-replace", "erase-object": + if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok { + delete(ep, "grow_mask") + req.GrowMask = v + } + } + + // outpaint directional fields + if taskType == "outpaint" { + if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok { + delete(ep, "left") + req.Left = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok { + delete(ep, "right") + req.Right = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok { + delete(ep, "up") + req.Up = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok { + delete(ep, "down") + req.Down = v + } + } + + // creativity + switch taskType { + case "upscale-creative", "upscale-conservative", "outpaint": + if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok { + delete(ep, "creativity") + req.Creativity = v + } + } + + // select_prompt (recolor) + if taskType == "recolor" { + if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok { + delete(ep, "select_prompt") + req.SelectPrompt = v + } + } + + // search_prompt (search-replace) + if taskType == "search-replace" { + if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok { + delete(ep, "search_prompt") + req.SearchPrompt = v + } + } + + // control_strength + switch taskType { + case "control-sketch", "control-structure": + if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok { + delete(ep, "control_strength") + req.ControlStrength = v + } + } + + // style-guide fields + if taskType == "style-guide" { + if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok { + delete(ep, 
"aspect_ratio") + req.AspectRatio = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok { + delete(ep, "fidelity") + req.Fidelity = v + } + } + + // style-transfer fields + if taskType == "style-transfer" { + if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok { + delete(ep, "style_strength") + req.StyleStrength = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok { + delete(ep, "composition_fidelity") + req.CompositionFidelity = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok { + delete(ep, "change_strength") + req.ChangeStrength = v + } + } + + req.ExtraParams = ep + } + } + + // Validate required per-task fields + if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") { + return nil, fmt.Errorf("select_prompt is required for stability ai recolor task") + } + if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") { + return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task") + } + + return req, nil +} + // ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse { if response == nil { @@ -366,6 +724,13 @@ func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) bifrostResponse := &schemas.BifrostImageGenerationResponse{} + if len(response.FinishReasons) > 0 || len(response.Seeds) > 0 { + bifrostResponse.ImageGenerationResponseParameters = &schemas.ImageGenerationResponseParameters{ + FinishReasons: append([]*string(nil), response.FinishReasons...), + Seeds: append([]int(nil), response.Seeds...), + } + } + for index, image := range response.Images { bifrostResponse.Data = append(bifrostResponse.Data, schemas.ImageData{ B64JSON: image, diff --git 
a/core/providers/bedrock/invoke.go b/core/providers/bedrock/invoke.go index bbd090a214..8227e8639a 100644 --- a/core/providers/bedrock/invoke.go +++ b/core/providers/bedrock/invoke.go @@ -2,8 +2,10 @@ package bedrock import ( "bytes" + "encoding/base64" "encoding/json" "fmt" + "net/url" "strings" "github.com/bytedance/sonic" @@ -44,6 +46,17 @@ var bedrockInvokeRequestKnownFields = map[string]bool{ "message": true, "chat_history": true, // AI21 "n": true, "frequency_penalty": true, "presence_penalty": true, + // Bedrock image gen / edit / variation (Titan/Nova Canvas) + "taskType": true, "textToImageParams": true, "imageVariationParams": true, + "inPaintingParams": true, "outPaintingParams": true, "backgroundRemovalParams": true, + "imageGenerationConfig": true, + // Stability AI image + "image": true, "mask": true, "negative_prompt": true, + "aspect_ratio": true, "output_format": true, "seed": true, + // Embeddings + "inputText": true, "texts": true, "input_type": true, + "normalize": true, "dimensions": true, + "embedding_types": true, "output_dimension": true, "inputs": true, // Internal "stream": true, "extra_params": true, } @@ -125,17 +138,74 @@ func (r *BedrockInvokeRequest) UnmarshalJSON(data []byte) error { return nil } -// DetectInvokeRequestType determines the request type from raw JSON body -// without full deserialization, keeping detection logic colocated with IsMessagesRequest. -func DetectInvokeRequestType(body []byte) schemas.RequestType { - node, _ := sonic.Get(body, "messages") - if node.Exists() { - raw, err := node.Raw() - if err == nil && raw != "null" && raw != "[]" { +// DetectInvokeRequestType determines the request type from raw JSON body and model ID +// without full deserialization, keeping detection logic colocated with conversion methods. 
+func DetectInvokeRequestType(body []byte, modelID string) schemas.RequestType { + // Messages → chat/responses path + if node, _ := sonic.Get(body, "messages"); node.Exists() { + if raw, err := node.Raw(); err == nil && raw != "null" && raw != "[]" { return schemas.ResponsesRequest } + } + + // Titan uses "inputText" for both embeddings and text generation. + // Use the model ID to disambiguate: embedding models contain "embed". + if node, _ := sonic.Get(body, "inputText"); node.Exists() { + if strings.Contains(strings.ToLower(modelID), "embed") { + return schemas.EmbeddingRequest + } return schemas.TextCompletionRequest } + + // Cohere embedding: text-only (texts), image-only (images), or mixed (inputs). + // Use model ID to identify embed models, then check for any non-empty payload field. + if strings.Contains(strings.ToLower(modelID), "embed") { + for _, field := range []string{"texts", "images", "inputs"} { + if node, _ := sonic.Get(body, field); node.Exists() { + if raw, err := node.Raw(); err == nil && raw != "null" && raw != "[]" { + return schemas.EmbeddingRequest + } + } + } + } + + // taskType-based image routing + if taskNode, _ := sonic.Get(body, "taskType"); taskNode.Exists() { + taskType, _ := taskNode.String() + switch taskType { + case TaskTypeTextImage: + return schemas.ImageGenerationRequest + case TaskTypeImageVariation: + return schemas.ImageVariationRequest + case TaskTypeInpainting, TaskTypeOutpainting, TaskTypeBackgroundRemoval: + return schemas.ImageEditRequest + } + } + + // URL-decode the model ID once for all model-name checks below + decodedModelID := modelID + if unescaped, err := url.PathUnescape(modelID); err == nil { + decodedModelID = unescaped + } + + // Stability AI: supports both generation (prompt-only) and edit (image+prompt) + if isStabilityAIModel(decodedModelID) { + if node, _ := sonic.Get(body, "image"); node.Exists() { + return schemas.ImageEditRequest + } + return schemas.ImageGenerationRequest + } + + // explicit image 
field -> edit request + if node, _ := sonic.Get(body, "image"); node.Exists() { + return schemas.ImageEditRequest + } + + // Checked after all body-field and model-specific signals so it doesn't shadow known models. + if isPromptOnlyImageGenerationModel(decodedModelID) { + return schemas.ImageGenerationRequest + } + return schemas.TextCompletionRequest } @@ -310,6 +380,382 @@ func (r *BedrockInvokeRequest) ToBifrostTextCompletionRequest(ctx *schemas.Bifro return textReq.ToBifrostTextCompletionRequest(ctx) } +// ToBifrostEmbeddingRequest converts the invoke request to a BifrostEmbeddingRequest. +// Handles both Titan (inputText) and Cohere (texts) embedding formats. +func (r *BedrockInvokeRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostEmbeddingRequest{ + Provider: provider, + Model: model, + } + + if r.InputText != "" { + req.Input = &schemas.EmbeddingInput{Text: &r.InputText} + } else if len(r.Texts) > 0 { + req.Input = &schemas.EmbeddingInput{Texts: r.Texts} + } + // image-only (r.Images) or mixed (r.Inputs): req.Input stays nil; data flows via ExtraParams + + extraParams := make(map[string]interface{}) + // Forward known embedding-only params into ExtraParams so the provider can pick them up + if r.InputType != nil { + extraParams["input_type"] = *r.InputType + } + if r.Normalize != nil { + extraParams["normalize"] = *r.Normalize + } + if len(r.EmbeddingTypes) > 0 { + extraParams["embedding_types"] = r.EmbeddingTypes + } + if r.Truncate != nil { + extraParams["truncate"] = *r.Truncate + } + if len(r.Images) > 0 { + extraParams["images"] = r.Images + } + if len(r.Inputs) > 0 { + extraParams["inputs"] = r.Inputs + } + if r.MaxTokens != nil { + 
extraParams["max_tokens"] = *r.MaxTokens + } + // Merge any remaining extra params from the request + for k, v := range r.ExtraParams { + extraParams[k] = v + } + + // output_dimension maps to Dimensions; prefer OutputDimension over Dimensions + dimensions := r.Dimensions + if r.OutputDimension != nil { + dimensions = r.OutputDimension + } + params := &schemas.EmbeddingParameters{ + Dimensions: dimensions, + } + if len(extraParams) > 0 { + params.ExtraParams = extraParams + } + req.Params = params + + return req +} + +// ToBifrostImageGenerationRequest converts the invoke request to a BifrostImageGenerationRequest. +// Handles Titan/Nova Canvas (taskType=TEXT_IMAGE with textToImageParams) and Stability AI (flat prompt fields). +func (r *BedrockInvokeRequest) ToBifrostImageGenerationRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageGenerationRequest { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageGenerationRequest{ + Provider: provider, + Model: model, + } + + params := &schemas.ImageGenerationParameters{ + NegativePrompt: r.NegativePrompt, + AspectRatio: r.AspectRatio, + N: r.N, + OutputFormat: r.OutputFormat, + Seed: r.Seed, + } + + if r.TextToImageParams != nil { + // Titan / Nova Canvas path + req.Input = &schemas.ImageGenerationInput{Prompt: r.TextToImageParams.Text} + if r.TextToImageParams.NegativeText != nil { + params.NegativePrompt = r.TextToImageParams.NegativeText + } + if r.TextToImageParams.Style != nil { + params.Style = r.TextToImageParams.Style + } + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + params.Seed = cfg.Seed + params.Quality = cfg.Quality + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.CfgScale != 
nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + params.ExtraParams["cfgScale"] = *cfg.CfgScale + } + } + } else { + // Stability AI path — prompt comes from the top-level "prompt" field + req.Input = &schemas.ImageGenerationInput{Prompt: r.Prompt} + } + + // Forward any remaining ExtraParams + if len(r.ExtraParams) > 0 { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + for k, v := range r.ExtraParams { + params.ExtraParams[k] = v + } + } + + req.Params = params + return req +} + +// ToBifrostImageEditRequest converts the invoke request to a BifrostImageEditRequest. +// Handles Titan/Nova Canvas (taskType in INPAINTING/OUTPAINTING/BACKGROUND_REMOVAL) and Stability AI (flat image/mask fields). +func (r *BedrockInvokeRequest) ToBifrostImageEditRequest(ctx *schemas.BifrostContext) (*schemas.BifrostImageEditRequest, error) { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageEditRequest{ + Provider: provider, + Model: model, + } + params := &schemas.ImageEditParameters{ + NegativePrompt: r.NegativePrompt, + Seed: r.Seed, + } + + if r.TaskType != nil { + // Titan / Nova Canvas path + switch *r.TaskType { + case TaskTypeInpainting: + if r.InPaintingParams == nil { + return nil, fmt.Errorf("inPaintingParams required for INPAINTING task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.InPaintingParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode inpainting image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.InPaintingParams.Text, + } + params.Type = schemas.Ptr("inpainting") + if r.InPaintingParams.NegativeText != nil { + params.NegativePrompt = 
r.InPaintingParams.NegativeText + } + if r.InPaintingParams.MaskImage != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.InPaintingParams.MaskImage) + if err != nil { + return nil, fmt.Errorf("failed to decode inpainting mask: %w", err) + } + params.Mask = maskBytes + } + if r.InPaintingParams.MaskPrompt != nil || r.InPaintingParams.ReturnMask != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + if r.InPaintingParams.MaskPrompt != nil { + params.ExtraParams["mask_prompt"] = *r.InPaintingParams.MaskPrompt + } + if r.InPaintingParams.ReturnMask != nil { + params.ExtraParams["return_mask"] = *r.InPaintingParams.ReturnMask + } + } + + case TaskTypeOutpainting: + if r.OutPaintingParams == nil { + return nil, fmt.Errorf("outPaintingParams required for OUTPAINTING task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.OutPaintingParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode outpainting image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.OutPaintingParams.Text, + } + params.Type = schemas.Ptr("outpainting") + if r.OutPaintingParams.NegativeText != nil { + params.NegativePrompt = r.OutPaintingParams.NegativeText + } + if r.OutPaintingParams.MaskImage != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.OutPaintingParams.MaskImage) + if err != nil { + return nil, fmt.Errorf("failed to decode outpainting mask: %w", err) + } + params.Mask = maskBytes + } + if r.OutPaintingParams.MaskPrompt != nil || r.OutPaintingParams.ReturnMask != nil || r.OutPaintingParams.OutPaintingMode != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + if r.OutPaintingParams.MaskPrompt != nil { + params.ExtraParams["mask_prompt"] = *r.OutPaintingParams.MaskPrompt + } + if r.OutPaintingParams.ReturnMask != nil { + params.ExtraParams["return_mask"] = *r.OutPaintingParams.ReturnMask 
+ } + if r.OutPaintingParams.OutPaintingMode != nil { + params.ExtraParams["outpainting_mode"] = *r.OutPaintingParams.OutPaintingMode + } + } + + case TaskTypeBackgroundRemoval: + if r.BackgroundRemovalParams == nil { + return nil, fmt.Errorf("backgroundRemovalParams required for BACKGROUND_REMOVAL task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.BackgroundRemovalParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode background removal image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + } + params.Type = schemas.Ptr("background_removal") + + default: + return nil, fmt.Errorf("unsupported taskType for image edit: %s", *r.TaskType) + } + + // Map imageGenerationConfig fields into edit params (Titan/Nova Canvas only) + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + params.Seed = cfg.Seed + params.Quality = cfg.Quality + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.CfgScale != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + params.ExtraParams["cfgScale"] = *cfg.CfgScale + } + } + } else { + // Stability AI path + if r.Image == nil { + return nil, fmt.Errorf("image field is required for Stability AI image edit") + } + imgBytes, err := base64.StdEncoding.DecodeString(*r.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode stability AI image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.Prompt, + } + // Infer task type from model name + taskType, err := getStabilityAIEditTaskType(r.ModelID) + if err != nil { + return nil, fmt.Errorf("cannot determine Stability AI edit task: %w", err) + } + params.Type = &taskType + if r.Mask != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.Mask) + if err != nil { + return 
nil, fmt.Errorf("failed to decode stability AI mask: %w", err) + } + params.Mask = maskBytes + } + } + + if len(r.ExtraParams) > 0 { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}, len(r.ExtraParams)) + } + for k, v := range r.ExtraParams { + params.ExtraParams[k] = v + } + } + req.Params = params + return req, nil +} + +// ToBifrostImageVariationRequest converts the invoke request to a BifrostImageVariationRequest. +// Reads from imageVariationParams (Titan/Nova Canvas format). +func (r *BedrockInvokeRequest) ToBifrostImageVariationRequest(ctx *schemas.BifrostContext) (*schemas.BifrostImageVariationRequest, error) { + if r.ImageVariationParams == nil || len(r.ImageVariationParams.Images) == 0 { + return nil, fmt.Errorf("imageVariationParams.images is required for IMAGE_VARIATION") + } + + primaryBytes, err := base64.StdEncoding.DecodeString(r.ImageVariationParams.Images[0]) + if err != nil { + return nil, fmt.Errorf("failed to decode primary variation image: %w", err) + } + + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageVariationRequest{ + Provider: provider, + Model: model, + Input: &schemas.ImageVariationInput{ + Image: schemas.ImageInput{Image: primaryBytes}, + }, + } + + params := &schemas.ImageVariationParameters{} + extraParams := make(map[string]interface{}) + + // Additional images (index 1+) stored under "images" key for the provider + if len(r.ImageVariationParams.Images) > 1 { + additionalImages := make([][]byte, 0, len(r.ImageVariationParams.Images)-1) + for _, imgB64 := range r.ImageVariationParams.Images[1:] { + imgBytes, err := base64.StdEncoding.DecodeString(imgB64) + if err != nil { + return nil, fmt.Errorf("failed to decode additional variation image: %w", err) + } + additionalImages = 
append(additionalImages, imgBytes) + } + extraParams["images"] = additionalImages + } + + // Text / negative text / similarity strength go to ExtraParams (provider reads them from there) + if r.ImageVariationParams.Text != nil { + extraParams["prompt"] = *r.ImageVariationParams.Text + } + if r.ImageVariationParams.NegativeText != nil { + extraParams["negativeText"] = *r.ImageVariationParams.NegativeText + } + if r.ImageVariationParams.SimilarityStrength != nil { + extraParams["similarityStrength"] = *r.ImageVariationParams.SimilarityStrength + } + + // ImageGenerationConfig → N, Size, Seed, Quality, CfgScale + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.Seed != nil { + extraParams["seed"] = *cfg.Seed + } + if cfg.Quality != nil { + extraParams["quality"] = *cfg.Quality + } + if cfg.CfgScale != nil { + extraParams["cfgScale"] = *cfg.CfgScale + } + } + + // Forward any remaining ExtraParams from the request body + for k, v := range r.ExtraParams { + extraParams[k] = v + } + if len(extraParams) > 0 { + params.ExtraParams = extraParams + } + + req.Params = params + return req, nil +} + // buildCohereCommandRPrompt converts Cohere Command R's message + chat_history into a text prompt. func (r *BedrockInvokeRequest) buildCohereCommandRPrompt() string { var sb strings.Builder @@ -448,9 +894,16 @@ func ToBedrockInvokeMessagesResponse(ctx *schemas.BifrostContext, resp *schemas. 
return nil, fmt.Errorf("bifrost response is nil") } - model := resp.Model - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested + model := "" + if resp.Model != "" { + model = resp.Model + } else { + extraFields := resp.ExtraFields + if extraFields.ResolvedModelUsed != "" { + model = extraFields.ResolvedModelUsed + } else if extraFields.OriginalModelRequested != "" { + model = extraFields.OriginalModelRequested + } } // Nova models: delegate to existing ToBedrockConverseResponse (Nova InvokeModel matches Converse format) @@ -467,6 +920,101 @@ func ToBedrockInvokeMessagesResponse(ctx *schemas.BifrostContext, resp *schemas. return toBedrockInvokeAnthropicResponse(resp, model), nil } +func ToBedrockInvokeImagesResponse(ctx *schemas.BifrostContext, resp *schemas.BifrostImageGenerationResponse) (interface{}, error) { + if resp == nil { + return nil, fmt.Errorf("bifrost response is nil") + } + + // If the provider stored the raw Bedrock response, return it verbatim (preserves seeds, finish_reasons, etc.) + if resp.ExtraFields.RawResponse != nil { + return resp.ExtraFields.RawResponse, nil + } + + model := resp.Model + if model == "" { + if resp.ExtraFields.ResolvedModelUsed != "" { + model = resp.ExtraFields.ResolvedModelUsed + } else if resp.ExtraFields.OriginalModelRequested != "" { + model = resp.ExtraFields.OriginalModelRequested + } + } + + // Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas + if isStabilityAIModel(model) { + return ToStabilityAIImageGenerationResponse(resp) + } + + // Default: Titan Image Generator v1/v2, Nova Canvas — reconstruct from Bifrost data + result := &BedrockImageGenerationResponse{} + for _, d := range resp.Data { + result.Images = append(result.Images, d.B64JSON) + } + return result, nil +} + +// ToBedrockEmbeddingInvokeResponse converts a BifrostEmbeddingResponse back to the native +// Bedrock invoke API response format. 
+// Single-embedding (Titan) responses use: {"embedding": [...], "inputTextTokenCount": N} +// Multi-embedding (Cohere) responses use: {"embeddings": [[...],[...]], "response_type": "embeddings_floats"} +func ToBedrockEmbeddingInvokeResponse(resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { + if resp == nil { + return nil, fmt.Errorf("bifrost embedding response is nil") + } + + // If the provider stored the raw Bedrock response, return it verbatim + if resp.ExtraFields.RawResponse != nil { + return resp.ExtraFields.RawResponse, nil + } + + tokenCount := 0 + if resp.Usage != nil { + tokenCount = resp.Usage.PromptTokens + } + + if len(resp.Data) == 0 { + return &BedrockInvokeEmbeddingResp{InputTextTokenCount: tokenCount}, nil + } + + // Use model name to distinguish Cohere from Titan — not batch size. + // A single-input Cohere request must still return the Cohere envelope format. + model := resp.Model + if model == "" { + if resp.ExtraFields.ResolvedModelUsed != "" { + model = resp.ExtraFields.ResolvedModelUsed + } else if resp.ExtraFields.OriginalModelRequested != "" { + model = resp.ExtraFields.OriginalModelRequested + } + } + + if strings.Contains(strings.ToLower(model), "cohere") { + floats := make([][]float32, 0, len(resp.Data)) + for _, d := range resp.Data { + float32Emb := make([]float32, len(d.Embedding.EmbeddingArray)) + for i, v := range d.Embedding.EmbeddingArray { + float32Emb[i] = float32(v) + } + floats = append(floats, float32Emb) + } + return &BedrockInvokeCohereEmbeddingResp{ + Embeddings: floats, + ResponseType: "embeddings_floats", + }, nil + } + + // Titan format + if resp.Data[0].Embedding.EmbeddingArray == nil { + return &BedrockInvokeEmbeddingResp{InputTextTokenCount: tokenCount}, nil + } + float32Emb := make([]float32, len(resp.Data[0].Embedding.EmbeddingArray)) + for i, v := range resp.Data[0].Embedding.EmbeddingArray { + float32Emb[i] = float32(v) + } + return &BedrockInvokeEmbeddingResp{ + Embedding: float32Emb, + 
InputTextTokenCount: tokenCount, + }, nil +} + // toBedrockInvokeAnthropicResponse converts BifrostResponsesResponse to Anthropic Messages API format. func toBedrockInvokeAnthropicResponse(resp *schemas.BifrostResponsesResponse, model string) *BedrockInvokeMessagesResponse { result := &BedrockInvokeMessagesResponse{ @@ -623,12 +1171,17 @@ func ToBedrockInvokeMessagesStreamResponse(ctx *schemas.BifrostContext, resp *sc // final Completed event). Without checking resp.ExtraFields, early chunks would // have model="" and Nova streams would be mis-routed through the Anthropic path. model := "" - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested - } else if resp.Response != nil && resp.Response.ExtraFields.ModelRequested != "" { - model = resp.Response.ExtraFields.ModelRequested - } else if resp.Response != nil && resp.Response.Model != "" { - model = resp.Response.Model + if resp.Response != nil { + if resp.Response.Model != "" { + model = resp.Response.Model + } else { + extraFields := resp.Response.ExtraFields + if extraFields.ResolvedModelUsed != "" { + model = extraFields.ResolvedModelUsed + } else if extraFields.OriginalModelRequested != "" { + model = extraFields.OriginalModelRequested + } + } } // Nova models: delegate to existing converse stream response (same format) @@ -657,6 +1210,7 @@ func ToBedrockInvokeMessagesStreamResponse(ctx *schemas.BifrostContext, resp *sc bedrockEvent := &BedrockStreamEvent{ InvokeModelRawChunks: rawChunks, } + return "", bedrockEvent, nil } @@ -669,8 +1223,11 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) switch resp.Type { case schemas.ResponsesStreamResponseTypeCreated: - // message_start — use ExtraFields.ModelRequested as fallback for early chunks - model := resp.ExtraFields.ModelRequested + // message_start — prefer resolved model for accurate family detection on early chunks + model := resp.ExtraFields.ResolvedModelUsed + if model == "" { + model = 
resp.ExtraFields.OriginalModelRequested + } msgStart := map[string]interface{}{ "type": "message_start", "message": map[string]interface{}{ @@ -780,7 +1337,7 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) "type": "content_block_delta", "index": idx, "delta": map[string]interface{}{ - "type": "input_json_delta", + "type": "input_json_delta", "partial_json": *resp.Delta, }, } diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index e4e96a8017..549db2e3bd 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -1,7 +1,6 @@ package bedrock import ( - "slices" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" @@ -82,147 +81,7 @@ type BedrockRerankResponseDocument struct { TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"` } -// regionPrefixes is a list of region prefixes used in Bedrock deployments -// Based on AWS region naming patterns and Bedrock deployment configurations -var regionPrefixes = []string{ - "us.", // US regions (us-east-1, us-west-2, etc.) - "eu.", // Europe regions (eu-west-1, eu-central-1, etc.) - "ap.", // Asia Pacific regions (ap-southeast-1, ap-northeast-1, etc.) - "ca.", // Canada regions (ca-central-1, etc.) - "sa.", // South America regions (sa-east-1, etc.) - "af.", // Africa regions (af-south-1, etc.) - "global.", // Global deployment prefix -} - -// extractPrefix extracts the region prefix ending with '.' from a string -// Only recognizes common region prefixes like "us.", "global.", "eu.", etc. -// Returns the prefix (including the dot) if found, empty string otherwise -func extractPrefix(s string) string { - for _, prefix := range regionPrefixes { - if strings.HasPrefix(s, prefix) { - return prefix - } - } - return "" -} - -// removePrefix removes any region prefix ending with '.' from a string -// Only removes common region prefixes like "us.", "global.", "eu.", etc. 
-// Returns the string without the prefix -func removePrefix(s string) string { - for _, prefix := range regionPrefixes { - if strings.HasPrefix(s, prefix) { - return s[len(prefix):] - } - } - return s -} - -// findMatchingAllowedModel finds a matching item in a slice, considering both -// exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), -// and also checks base model matches (ignoring version suffixes). -// Returns the matched item from the slice if found, empty string otherwise. -// If matched via base model, returns the item from slice (not the value parameter). -func findMatchingAllowedModel(slice []string, value string) string { - // First check exact matches - if slices.Contains(slice, value) { - return value - } - - // Check with region prefix added/removed - valuePrefix := extractPrefix(value) - if valuePrefix != "" { - // value has a prefix, check if slice contains version without prefix - withoutPrefix := removePrefix(value) - if slices.Contains(slice, withoutPrefix) { - return withoutPrefix - } - } - - // Check if any item in slice has a prefix that matches value without prefix - for _, item := range slice { - itemPrefix := extractPrefix(item) - if itemPrefix != "" { - // item has prefix, check if value matches without the prefix - itemWithoutPrefix := removePrefix(item) - if itemWithoutPrefix == value { - return item - } - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Normalize value by removing any region prefix for base model comparison - valueNormalized := removePrefix(value) - - for _, item := range slice { - // Normalize item by removing any region prefix for base model comparison - itemNormalized := removePrefix(item) - - // Check base model match with normalized values (prefix removed from both) - // Return the item from slice (not value) to use the actual name from allowedModels - if 
schemas.SameBaseModel(itemNormalized, valueNormalized) { - return item - } - } - return "" -} - -// findDeploymentMatch finds a matching deployment value in the deployments map, -// considering both exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), -// and also checks base model matches (ignoring version suffixes). -// The modelID from the API response should match a deployment value (not the alias/key). -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, modelID string) (deploymentValue, alias string) { - // Check if any deployment value matches the modelID (with or without prefix) - for aliasKey, deploymentValue := range deployments { - // Exact match - if deploymentValue == modelID || aliasKey == modelID { - return deploymentValue, aliasKey - } - - // Check prefix variations - deploymentPrefix := extractPrefix(deploymentValue) - modelIDPrefix := extractPrefix(modelID) - aliasKeyPrefix := extractPrefix(aliasKey) - - // Case 1: deploymentValue or aliasKey has prefix, modelID doesn't - if (deploymentPrefix != "" && modelIDPrefix == "") || (aliasKeyPrefix != "" && modelIDPrefix == "") { - if removePrefix(deploymentValue) == modelID || removePrefix(aliasKey) == modelID { - return deploymentValue, aliasKey - } - } - - // Case 2: modelID or aliasKey has prefix, deploymentValue doesn't - if (modelIDPrefix != "" && deploymentPrefix == "") || (aliasKeyPrefix != "" && deploymentPrefix == "") { - if removePrefix(modelID) == deploymentValue || removePrefix(modelID) == aliasKey { - return deploymentValue, aliasKey - } - } - - // Case 3: Both have prefixes but different prefixes - if (deploymentPrefix != "" && modelIDPrefix != "" && deploymentPrefix != modelIDPrefix) || (aliasKeyPrefix != "" && modelIDPrefix != "" && aliasKeyPrefix != modelIDPrefix) { - if removePrefix(deploymentValue) == removePrefix(modelID) || removePrefix(aliasKey) == removePrefix(modelID) { 
- return deploymentValue, aliasKey - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Normalize both values by removing any region prefix for base model comparison - deploymentNormalized := removePrefix(deploymentValue) - modelIDNormalized := removePrefix(modelID) - - // Check base model match with normalized values (prefix removed from both) - if schemas.SameBaseModel(deploymentNormalized, modelIDNormalized) { - return deploymentValue, aliasKey - } - } - return "", "" -} - -func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -231,121 +90,41 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK Data: make([]schemas.Model, 0, len(response.ModelSummaries)), } - deploymentValues := make([]string, 0, len(deployments)) - for _, deployment := range deployments { - deploymentValues = append(deploymentValues, deployment) + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), } - - includedModels := make(map[string]bool) - for _, model := range response.ModelSummaries { - modelID := model.ModelID - matchedAllowedModel := "" - deploymentValue := "" - deploymentAlias := "" - - // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean 
"allow all" for that dimension - // Check considering global prefix variations - shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deploymentValues) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) - } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { - // Only allowedModels is present: filter if model is not in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) - shouldFilter = matchedAllowedModel == "" - } else if !unfiltered && len(deploymentValues) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) - shouldFilter = deploymentValue == "" - } - // If both are empty, shouldFilter remains false (allow all) - - if shouldFilter { - continue - } - - // Use the matched name from allowedModels or deployments (like Anthropic) - // Priority: deployment value > matched allowedModel > original model.ModelID - if deploymentValue != "" { - modelID = deploymentValue - } else if matchedAllowedModel != "" { - modelID = matchedAllowedModel - } - - if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ModelID, modelID, deploymentAlias, matchedAllowedModel) { - continue - } - - modelEntry := schemas.Model{ - ID: 
string(providerKey) + "/" + modelID, - Name: schemas.Ptr(model.ModelName), - OwnedBy: schemas.Ptr(model.ProviderName), - Architecture: &schemas.Architecture{ - InputModalities: model.InputModalities, - OutputModalities: model.OutputModalities, - }, - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(providerKey) + "/" + deploymentAlias - // Use the actual deployment value (which might have global prefix) - modelEntry.Deployment = schemas.Ptr(deploymentValue) - includedModels[deploymentAlias] = true - } else { - includedModels[modelID] = true - } - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - if includedModels[alias] { - continue - } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { - continue - } - if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + alias, - Name: schemas.Ptr(alias), - Deployment: schemas.Ptr(deploymentValue), - }) - includedModels[alias] = true - } - } + included := make(map[string]bool) - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { - continue + for _, model := range response.ModelSummaries { + for _, result := range pipeline.FilterModel(model.ModelID) { + modelEntry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.ModelName), + OwnedBy: schemas.Ptr(model.ProviderName), + 
Architecture: &schemas.Architecture{ + InputModalities: model.InputModalities, + OutputModalities: model.OutputModalities, + }, } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse -} +} \ No newline at end of file diff --git a/core/providers/bedrock/rerank_test.go b/core/providers/bedrock/rerank_test.go index 0dff5c3ee2..c1b7bb5480 100644 --- a/core/providers/bedrock/rerank_test.go +++ b/core/providers/bedrock/rerank_test.go @@ -195,27 +195,23 @@ func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) { func TestResolveBedrockDeployment(t *testing.T) { key := schemas.Key{ - BedrockKeyConfig: &schemas.BedrockKeyConfig{ - Deployments: map[string]string{ - "cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", - }, + Aliases: schemas.KeyAliases{ + "cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", }, } - deployment := resolveBedrockDeployment("cohere-rerank", key) + deployment := key.Aliases.Resolve("cohere-rerank") assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment) - assert.Equal(t, "cohere.rerank-v3-5:0", resolveBedrockDeployment("cohere.rerank-v3-5:0", key)) - assert.Equal(t, "", resolveBedrockDeployment("", key)) + assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0")) + assert.Equal(t, "", key.Aliases.Resolve("")) } func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) { provider := &BedrockProvider{} ctx := 
schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) key := schemas.Key{ - BedrockKeyConfig: &schemas.BedrockKeyConfig{ - Deployments: map[string]string{ - "cohere-rerank": "cohere.rerank-v3-5:0", - }, + Aliases: schemas.KeyAliases{ + "cohere-rerank": "cohere.rerank-v3-5:0", }, } diff --git a/core/providers/bedrock/s3.go b/core/providers/bedrock/s3.go index da06e5e820..be2d0afb32 100644 --- a/core/providers/bedrock/s3.go +++ b/core/providers/bedrock/s3.go @@ -22,7 +22,6 @@ func uploadToS3( region string, bucket, key string, content []byte, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Create AWS config with credentials var cfg aws.Config @@ -47,7 +46,7 @@ func uploadToS3( } if err != nil { - return providerUtils.NewBifrostOperationError("failed to load AWS config for S3", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config for s3", err) } // Create S3 client @@ -62,7 +61,7 @@ func uploadToS3( }) if err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to S3: %s/%s", bucket, key), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to s3: %s/%s", bucket, key), err) } return nil diff --git a/core/providers/bedrock/signer.go b/core/providers/bedrock/signer.go index 9f12e3bbaf..b7e87ae8d2 100644 --- a/core/providers/bedrock/signer.go +++ b/core/providers/bedrock/signer.go @@ -280,17 +280,16 @@ func signAWSRequestFastHTTP( accessKey, secretKey string, sessionToken *string, region, service string, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Get AWS credentials if not provided if accessKey == "" && secretKey == "" { cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) if err != nil { - return providerUtils.NewBifrostOperationError("failed to load aws config", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config", err) } creds, err := 
cfg.Credentials.Retrieve(ctx) if err != nil { - return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err, providerName) + return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err) } accessKey = creds.AccessKeyID secretKey = creds.SecretAccessKey diff --git a/core/providers/bedrock/text.go b/core/providers/bedrock/text.go index 6ad24ee1c8..d31d716ded 100644 --- a/core/providers/bedrock/text.go +++ b/core/providers/bedrock/text.go @@ -127,8 +127,6 @@ func (response *BedrockAnthropicTextResponse) ToBifrostTextCompletionResponse() }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Bedrock, }, } } @@ -154,8 +152,6 @@ func (response *BedrockMistralTextResponse) ToBifrostTextCompletionResponse() *s Object: "text_completion", Choices: choices, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Bedrock, }, } } @@ -167,11 +163,14 @@ func ToBedrockTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionR return nil } - // Determine response format based on model - // Use ModelRequested from ExtraFields if available, otherwise use Model + // Determine response format based on resolved model identity. + // Use ResolvedModelUsed (actual provider ID) for accurate family detection, + // falling back to bifrostResp.Model, then OriginalModelRequested as a last resort. 
model := bifrostResp.Model - if bifrostResp.ExtraFields.ModelRequested != "" { - model = bifrostResp.ExtraFields.ModelRequested + if bifrostResp.ExtraFields.ResolvedModelUsed != "" { + model = bifrostResp.ExtraFields.ResolvedModelUsed + } else if model == "" && bifrostResp.ExtraFields.OriginalModelRequested != "" { + model = bifrostResp.ExtraFields.OriginalModelRequested } if strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") { diff --git a/core/providers/bedrock/transport_test.go b/core/providers/bedrock/transport_test.go index 6751527b5b..1e2a447e9d 100644 --- a/core/providers/bedrock/transport_test.go +++ b/core/providers/bedrock/transport_test.go @@ -138,7 +138,7 @@ func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) { ctx := testBedrockCtx() key := testBedrockKey() - _, _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream") + _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream") require.NotNil(t, bifrostErr, "expected error when server closes connection") assert.False(t, bifrostErr.IsBifrostError, diff --git a/core/providers/bedrock/types.go b/core/providers/bedrock/types.go index 98c46ae96f..afbfead01f 100644 --- a/core/providers/bedrock/types.go +++ b/core/providers/bedrock/types.go @@ -654,10 +654,10 @@ type BedrockMetadataEvent struct { // BedrockTitanEmbeddingRequest represents a Bedrock Titan embedding request type BedrockTitanEmbeddingRequest struct { - InputText string `json:"inputText"` // Required: Text to embed + InputText string `json:"inputText"` // Required: Text to embed + Dimensions *int `json:"dimensions,omitempty"` // Optional: 256, 512, or 1024 (titan-embed-text-v2 only) + Normalize *bool `json:"normalize,omitempty"` // Optional: normalize the embedding ExtraParams map[string]interface{} `json:"-"` - // Note: Titan models have fixed dimensions and don't support the 
dimensions parameter - // ExtraParams can be used for any additional model-specific parameters } // GetExtraParams implements the RequestBodyWithExtraParams interface @@ -671,6 +671,53 @@ type BedrockTitanEmbeddingResponse struct { InputTextTokenCount int `json:"inputTextTokenCount"` // Number of tokens in input } +// BedrockCohereEmbeddingContentBlock represents a single content block in a mixed input +type BedrockCohereEmbeddingContentBlock struct { + Type string `json:"type"` // "text" or "image_url" + Text *string `json:"text,omitempty"` // for type=text + ImageURL *BedrockCohereEmbeddingImageURL `json:"image_url,omitempty"` // for type=image_url +} + +// BedrockCohereEmbeddingImageURL holds the URL for an image content block +type BedrockCohereEmbeddingImageURL struct { + URL string `json:"url"` +} + +// BedrockCohereEmbeddingInput represents a mixed text+image input +type BedrockCohereEmbeddingInput struct { + Content []BedrockCohereEmbeddingContentBlock `json:"content"` +} + +// BedrockCohereEmbeddingRequest represents a Bedrock Cohere embedding request. +// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the body. +type BedrockCohereEmbeddingRequest struct { + InputType string `json:"input_type"` // Required + Texts []string `json:"texts,omitempty"` // text-only inputs + Images []string `json:"images,omitempty"` // image-only inputs (data URIs) + Inputs []BedrockCohereEmbeddingInput `json:"inputs,omitempty"` // mixed text+image inputs + EmbeddingTypes []string `json:"embedding_types,omitempty"` // e.g. 
["float"] + OutputDimension *int `json:"output_dimension,omitempty"` // 256, 512, 1024, or 1536 + MaxTokens *int `json:"max_tokens,omitempty"` // max 128000 + Truncate *string `json:"truncate,omitempty"` // NONE, LEFT, or RIGHT + ExtraParams map[string]interface{} `json:"-"` +} + +// GetExtraParams implements the RequestBodyWithExtraParams interface +func (req *BedrockCohereEmbeddingRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// BedrockCohereEmbeddingResponse handles both Bedrock Cohere embedding response shapes. +// When embedding_types is not set, Bedrock returns embeddings as a raw [][]float32 +// ("embeddings_floats"). When embedding_types is set, it returns an object with typed +// arrays ("embeddings_by_type"). Using json.RawMessage defers parsing until we know the shape. +type BedrockCohereEmbeddingResponse struct { + ID string `json:"id"` + Embeddings json.RawMessage `json:"embeddings"` + ResponseType string `json:"response_type"` + Texts []string `json:"texts,omitempty"` +} + const TaskTypeTextImage = "TEXT_IMAGE" const TaskTypeImageVariation = "IMAGE_VARIATION" const TaskTypeInpainting = "INPAINTING" @@ -763,11 +810,79 @@ type BedrockBackgroundRemovalParams struct { Image string `json:"image"` // Base64-encoded image } -// BedrockImageGenerationResponse represents a Bedrock image generation response +// StabilityAIImageGenerationRequest represents the request format for Stability AI models on Bedrock +// (e.g. 
stability.stable-image-core-v1:1, stability.stable-image-ultra-v1:1) +type StabilityAIImageGenerationRequest struct { + Prompt string `json:"prompt"` + AspectRatio *string `json:"aspect_ratio,omitempty"` + OutputFormat *string `json:"output_format,omitempty"` + Seed *int `json:"seed,omitempty"` + NegativePrompt *string `json:"negative_prompt,omitempty"` + ExtraParams map[string]interface{} `json:"-"` +} + +// GetExtraParams implements the RequestBodyWithExtraParams interface +func (req *StabilityAIImageGenerationRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// StabilityAIImageEditRequest is the flat JSON body for Stability AI image-edit models on Bedrock. +// Only the fields valid for the detected task type are populated. +type StabilityAIImageEditRequest struct { + // Shared params + Image *string `json:"image,omitempty"` // base64, primary input image + Prompt *string `json:"prompt,omitempty"` + NegativePrompt *string `json:"negative_prompt,omitempty"` + Seed *int `json:"seed,omitempty"` + OutputFormat *string `json:"output_format,omitempty"` + StylePreset *string `json:"style_preset,omitempty"` + Mask *string `json:"mask,omitempty"` // base64 mask image + GrowMask *int `json:"grow_mask,omitempty"` + + // Outpaint + Left *int `json:"left,omitempty"` + Right *int `json:"right,omitempty"` + Up *int `json:"up,omitempty"` + Down *int `json:"down,omitempty"` + + // Upscale-creative / upscale-conservative / outpaint + Creativity *float64 `json:"creativity,omitempty"` + + // Recolor + SelectPrompt *string `json:"select_prompt,omitempty"` + + // Search-replace + SearchPrompt *string `json:"search_prompt,omitempty"` + + // Control-sketch / control-structure + ControlStrength *float64 `json:"control_strength,omitempty"` + + // Style-guide + AspectRatio *string `json:"aspect_ratio,omitempty"` + Fidelity *float64 `json:"fidelity,omitempty"` + + // Style-transfer (uses different image field names) + InitImage *string 
`json:"init_image,omitempty"` + StyleImage *string `json:"style_image,omitempty"` + StyleStrength *float64 `json:"style_strength,omitempty"` + CompositionFidelity *float64 `json:"composition_fidelity,omitempty"` + ChangeStrength *float64 `json:"change_strength,omitempty"` + + ExtraParams map[string]interface{} `json:"-"` +} + +func (req *StabilityAIImageEditRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// BedrockImageGenerationResponse represents a Bedrock image generation response. +// The Seeds and FinishReasons fields are populated by Stability AI edit models only. type BedrockImageGenerationResponse struct { - Images []string `json:"images"` // list of Base64 encoded images - MaskImage string `json:"maskImage"` // Base64 encoded mask image (optional) - Error string `json:"error"` // error message (if present) + Images []string `json:"images"` // list of Base64 encoded images + MaskImage string `json:"maskImage,omitempty"` // Base64 encoded mask image (optional) + Error string `json:"error,omitempty"` // error message (if present) + Seeds []int `json:"seeds,omitempty"` // Stability AI: seeds used per image + FinishReasons []*string `json:"finish_reasons,omitempty"` // Stability AI: finish reason per image (may be null) } // ==================== MODELS TYPES ==================== @@ -960,6 +1075,37 @@ type BedrockInvokeRequest struct { FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` PresencePenalty *float64 `json:"presence_penalty,omitempty"` + // ==================== BEDROCK IMAGE GEN / EDIT / VARIATION (Titan/Nova Canvas) ==================== + + TaskType *string `json:"taskType,omitempty"` + TextToImageParams *BedrockTextToImageParams `json:"textToImageParams,omitempty"` + ImageVariationParams *BedrockImageVariationParams `json:"imageVariationParams,omitempty"` + InPaintingParams *BedrockInPaintingParams `json:"inPaintingParams,omitempty"` + OutPaintingParams *BedrockOutPaintingParams 
`json:"outPaintingParams,omitempty"` + BackgroundRemovalParams *BedrockBackgroundRemovalParams `json:"backgroundRemovalParams,omitempty"` + ImageGenerationConfig *ImageGenerationConfig `json:"imageGenerationConfig,omitempty"` + + // ==================== STABILITY AI IMAGE ==================== + + // Image is the base64-encoded input image (SA edit / variation) + Image *string `json:"image,omitempty"` + Mask *string `json:"mask,omitempty"` // base64 mask for inpainting + NegativePrompt *string `json:"negative_prompt,omitempty"` // SA gen / edit + AspectRatio *string `json:"aspect_ratio,omitempty"` // SA gen + OutputFormat *string `json:"output_format,omitempty"` // SA gen + Seed *int `json:"seed,omitempty"` // SA gen / edit + + // ==================== EMBEDDINGS ==================== + + InputText string `json:"inputText,omitempty"` // Titan embed + Texts []string `json:"texts,omitempty"` // Cohere embed + InputType *string `json:"input_type,omitempty"` // Cohere embed + Normalize *bool `json:"normalize,omitempty"` // Titan embed v2 + Dimensions *int `json:"dimensions,omitempty"` // Titan embed v2 + EmbeddingTypes []string `json:"embedding_types,omitempty"` // Cohere embed: ["float","int8","uint8","binary","ubinary"] + OutputDimension *int `json:"output_dimension,omitempty"` // Cohere embed: 256, 512, 1024, 1536 + Inputs []BedrockCohereEmbeddingInput `json:"inputs,omitempty"` // Cohere embed: mixed text+image inputs + // ==================== INTERNAL ==================== Stream bool `json:"-"` ExtraParams map[string]interface{} `json:"-"` @@ -970,3 +1116,15 @@ type BedrockCohereRMessage struct { Role string `json:"role"` // "USER" or "CHATBOT" Message string `json:"message"` // Message content } + +// BedrockInvokeEmbeddingResp is the Titan single-embedding invoke response format. 
+type BedrockInvokeEmbeddingResp struct { + Embedding []float32 `json:"embedding"` + InputTextTokenCount int `json:"inputTextTokenCount"` +} + +// BedrockInvokeCohereEmbeddingResp is the Cohere multi-embedding invoke response format. +type BedrockInvokeCohereEmbeddingResp struct { + Embeddings [][]float32 `json:"embeddings"` + ResponseType string `json:"response_type"` +} diff --git a/core/providers/cerebras/cerebras.go b/core/providers/cerebras/cerebras.go index 665c2e81bc..c32dcd7374 100644 --- a/core/providers/cerebras/cerebras.go +++ b/core/providers/cerebras/cerebras.go @@ -178,9 +178,6 @@ func (provider *CerebrasProvider) Responses(ctx *schemas.BifrostContext, key sch } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/cohere/chat.go b/core/providers/cohere/chat.go index 1623d23737..fc19d18919 100644 --- a/core/providers/cohere/chat.go +++ b/core/providers/cohere/chat.go @@ -372,8 +372,6 @@ func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas }, Created: int(time.Now().Unix()), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Cohere, }, } diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index fd50f6d6d5..1e5d50e087 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -155,7 +155,7 @@ func (provider *CohereProvider) buildRequestURL(ctx *schemas.BifrostContext, def // completeRequest sends a request to Cohere's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. 
-func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { +func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -199,10 +199,10 @@ func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jso // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, latency, providerResponseHeaders, parseCohereError(resp, meta) + return nil, latency, providerResponseHeaders, parseCohereError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, latency, providerResponseHeaders, decodeErr } @@ -217,8 +217,6 @@ func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jso // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. 
func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -234,7 +232,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Parse and add query parameters u, err := url.Parse(baseURL) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to parse request URL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to parse request url", err) } q := u.Query() @@ -269,15 +267,12 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.ListModelsRequest, - }) + return nil, parseCohereError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse Cohere list models response @@ -288,7 +283,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert Cohere v2 response to Bifrost response - response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := cohereResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -352,17 +347,12 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key request, func() 
(providerUtils.RequestBodyWithExtraParams, error) { return ToCohereChatCompletionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -375,9 +365,6 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -396,9 +383,6 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key bifrostResponse := response.ToBifrostChatResponse(request.Model) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -424,7 +408,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, err } - providerName := provider.GetProviderKey() jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( 
ctx, request, @@ -435,8 +418,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } reqBody.Stream = schemas.Ptr(true) return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -486,9 +468,9 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -497,11 +479,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — 
pipe raw upstream SSE to client @@ -521,9 +499,9 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -561,7 +539,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -583,11 +561,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream() if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break @@ -595,11 +568,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext if response != nil { response.ID = 
responseID response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -639,18 +609,13 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereResponsesRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Convert to Cohere v2 request - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -663,9 +628,6 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -686,9 +648,6 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - 
bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -712,7 +671,6 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return nil, err } - providerName := provider.GetProviderKey() // Convert to Cohere v2 request and add streaming jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -726,8 +684,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos reqBody.Stream = schemas.Ptr(true) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -775,9 +732,9 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -786,11 +743,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Check for 
HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -810,9 +763,9 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -854,8 +807,8 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + provider.logger.Warn("Error reading stream: %v", readErr) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -875,11 +828,6 @@ func (provider *CohereProvider) ResponsesStream(ctx 
*schemas.BifrostContext, pos responses, bifrostErr, isLastChunk := event.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break @@ -888,11 +836,8 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() chunkIndex++ @@ -936,18 +881,13 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereEmbeddingRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Create Bifrost request for conversion - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.EmbeddingRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, 
providerResponseHeaders) } @@ -960,9 +900,6 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -981,9 +918,6 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem bifrostResponse := response.ToBifrostEmbeddingResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1012,17 +946,12 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereRerankRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.RerankRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1035,9 +964,6 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
 	return &schemas.BifrostRerankResponse{
 		Model: request.Model,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:                provider.GetProviderKey(),
-			ModelRequested:          request.Model,
-			RequestType:             schemas.RerankRequest,
 			Latency:                 latency.Milliseconds(),
 			ProviderResponseHeaders: providerResponseHeaders,
 		},
@@ -1058,9 +984,6 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas.
 	bifrostResponse.Model = request.Model
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
@@ -1218,16 +1141,12 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToCohereCountTokensRequest(request)
-		},
-		providerName,
-	)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1237,11 +1156,6 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch
 		jsonBody,
 		provider.buildRequestURL(ctx, "/v1/tokenize", schemas.CountTokensRequest),
 		key.Value.GetValue(),
-		&providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.CountTokensRequest,
-		},
 	)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
@@ -1255,9 +1169,6 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch
 	return &schemas.BifrostCountTokensResponse{
 		Model: request.Model,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:       providerName,
-			ModelRequested: request.Model,
-			RequestType:    schemas.CountTokensRequest,
 			Latency:                 latency.Milliseconds(),
 			ProviderResponseHeaders: providerResponseHeaders,
 		},
@@ -1279,12 +1190,9 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch
 	bifrostResponse := cohereResponse.ToBifrostCountTokensResponse(request.Model)
 	if bifrostResponse == nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, fmt.Errorf("nil Cohere count tokens response"), providerName)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, fmt.Errorf("nil cohere count tokens response")), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.RequestType = schemas.CountTokensRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
diff --git a/core/providers/cohere/cohere_test.go b/core/providers/cohere/cohere_test.go
index f2fca1028f..23c73911bc 100644
--- a/core/providers/cohere/cohere_test.go
+++ b/core/providers/cohere/cohere_test.go
@@ -32,27 +32,27 @@ func TestCohere(t *testing.T) {
 		RerankModel:    "rerank-v3.5",
 		ReasoningModel: "command-a-reasoning-08-2025",
 		Scenarios: llmtests.TestScenarios{
-			TextCompletion:        false, // Not typical for Cohere
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
-			ToolCallsStreaming:    true,
+			TextCompletion:             false, // Not typical for Cohere
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,  // May not support automatic
-			ImageURL:              false, // Supported by c4ai-aya-vision-8b model
-			ImageBase64:           true,  // Supported by c4ai-aya-vision-8b model
-			MultipleImages:        false, // Supported by c4ai-aya-vision-8b model
-			FileBase64:            false, // Not supported
-			FileURL:               false, // Not supported
-			CompleteEnd2End:       false,
-			Embedding:             true,
-			Rerank:                true,
-			Reasoning:             true,
-			ListModels:            true,
-			CountTokens:           true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,  // May not support automatic
+			ImageURL:                   false, // Supported by c4ai-aya-vision-8b model
+			ImageBase64:                true,  // Supported by c4ai-aya-vision-8b model
+			MultipleImages:             false, // Supported by c4ai-aya-vision-8b model
+			FileBase64:                 false, // Not supported
+			FileURL:                    false, // Not supported
+			CompleteEnd2End:            false,
+			Embedding:                  true,
+			Rerank:                     true,
+			Reasoning:                  true,
+			ListModels:                 true,
+			CountTokens:                true,
 		},
 	}
diff --git a/core/providers/cohere/errors.go b/core/providers/cohere/errors.go
index e9183b1b34..e444d86650 100644
--- a/core/providers/cohere/errors.go
+++ b/core/providers/cohere/errors.go
@@ -6,7 +6,7 @@ import (
 	"github.com/valyala/fasthttp"
 )
 
-func parseCohereError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
+func parseCohereError(resp *fasthttp.Response) *schemas.BifrostError {
 	var errorResp CohereError
 	bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
 	bifrostErr.Type = &errorResp.Type
@@ -17,10 +17,5 @@ func parseCohereError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 	if errorResp.Code != nil {
 		bifrostErr.Error.Code = errorResp.Code
 	}
-	if meta != nil {
-		bifrostErr.ExtraFields.Provider = meta.Provider
-		bifrostErr.ExtraFields.ModelRequested = meta.Model
-		bifrostErr.ExtraFields.RequestType = meta.RequestType
-	}
 	return bifrostErr
 }
diff --git a/core/providers/cohere/models.go b/core/providers/cohere/models.go
index 3df2aab89a..3b285f97b6 100644
--- a/core/providers/cohere/models.go
+++ b/core/providers/cohere/models.go
@@ -2,8 +2,9 @@ package cohere
 
 import (
 	"encoding/json"
-	"slices"
+	"strings"
 
+	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
 	"github.com/maximhq/bifrost/core/schemas"
 )
@@ -44,7 +45,7 @@ type CohereRerankMeta struct {
 	Tokens *CohereTokenUsage `json:"tokens,omitempty"`
 }
 
-func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
+func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -53,37 +54,39 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe
 		Data: make([]schemas.Model, 0, len(response.Models)),
 	}
 
-	includedModels := make(map[string]bool)
-	for _, model := range response.Models {
-		if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.Name) {
-			continue
-		}
-		if !unfiltered && slices.Contains(blacklistedModels, model.Name) {
-			continue
-		}
-		bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-			ID:               string(providerKey) + "/" + model.Name,
-			Name:             schemas.Ptr(model.Name),
-			ContextLength:    schemas.Ptr(int(model.ContextLength)),
-			SupportedMethods: model.Endpoints,
-		})
-		includedModels[model.Name] = true
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       providerKey,
+		MatchFns:          providerUtils.DefaultMatchFns(),
+	}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
 	}
 
-	// Backfill allowed models that were not in the response
-	if !unfiltered && len(allowedModels) > 0 {
-		for _, allowedModel := range allowedModels {
-			if slices.Contains(blacklistedModels, allowedModel) {
-				continue
+	included := make(map[string]bool)
+
+	for _, model := range response.Models {
+		// Cohere uses model.Name as the model identifier
+		for _, result := range pipeline.FilterModel(model.Name) {
+			entry := schemas.Model{
+				ID:               string(providerKey) + "/" + result.ResolvedID,
+				Name:             schemas.Ptr(model.Name),
+				ContextLength:    schemas.Ptr(int(model.ContextLength)),
+				SupportedMethods: model.Endpoints,
 			}
-			if !includedModels[allowedModel] {
-				bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-					ID:   string(providerKey) + "/" + allowedModel,
-					Name: schemas.Ptr(allowedModel),
-				})
+			if result.AliasValue != "" {
+				entry.Alias = schemas.Ptr(result.AliasValue)
 			}
+			bifrostResponse.Data = append(bifrostResponse.Data, entry)
+			included[strings.ToLower(result.ResolvedID)] = true
 		}
 	}
 
+	bifrostResponse.Data = append(bifrostResponse.Data,
+		pipeline.BackfillModels(included)...)
+
 	return bifrostResponse
 }
diff --git a/core/providers/cohere/types.go b/core/providers/cohere/types.go
index a4d78f9a48..8e5aa31402 100644
--- a/core/providers/cohere/types.go
+++ b/core/providers/cohere/types.go
@@ -9,8 +9,11 @@ import (
 	"github.com/maximhq/bifrost/core/schemas"
 )
 
-const MinimumReasoningMaxTokens = 1
-const DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default
+const (
+	MinimumReasoningMaxTokens  = 1
+	DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default
+)
+
 // Limits for tokenize input api call https://docs.cohere.com/reference/tokenize#request
 const (
 	cohereTokenizeMinTextLength = 1
diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go
index 31d3a4761a..bcd3e5cfc7 100644
--- a/core/providers/elevenlabs/elevenlabs.go
+++ b/core/providers/elevenlabs/elevenlabs.go
@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"context"
 	"errors"
-	"fmt"
 	"io"
 	"mime/multipart"
 	"net/http"
@@ -74,8 +73,6 @@ func (provider *ElevenlabsProvider) GetProviderKey() schemas.ModelProvider {
 // listModelsByKey performs a list models request for a single key.
 // Returns the response and latency, or an error if the request fails.
 func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Create request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -103,10 +100,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext,
 	// Extract and set provider response headers so they're available on error paths
 	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, parseElevenlabsError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			RequestType: schemas.ListModelsRequest,
-		})
+		return nil, parseElevenlabsError(resp)
 	}
 
 	var elevenlabsResponse ElevenlabsListModelsResponse
@@ -115,7 +109,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext,
 		return nil, bifrostErr
 	}
 
-	response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)
+	response := elevenlabsResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
 
@@ -188,8 +182,6 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	// Create request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -211,7 +203,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 			endpoint = "/v1/text-to-speech/" + voice
 		}
 	} else {
-		return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil)
 	}
 
 	requestURL := provider.buildBaseSpeechRequestURL(ctx, endpoint, schemas.SpeechRequest, request)
@@ -228,8 +220,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToElevenlabsSpeechRequest(request), nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
@@ -250,26 +241,18 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.SpeechRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Get the response body
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Create response based on whether timestamps were requested
 	bifrostResponse := &schemas.BifrostSpeechResponse{
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:             schemas.SpeechRequest,
-			Provider:                providerName,
-			ModelRequested:          request.Model,
 			Latency:                 latency.Milliseconds(),
 			ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 		},
@@ -282,7 +265,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 	if withTimestampsRequest {
 		var timestampResponse ElevenlabsSpeechWithTimestampsResponse
 		if err := sonic.Unmarshal(body, &timestampResponse); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err)
 		}
 
 		bifrostResponse.AudioBase64 = &timestampResponse.AudioBase64
@@ -326,15 +309,12 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToElevenlabsSpeechRequest(request), nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -350,7 +330,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
 
 	if request.Params == nil || request.Params.VoiceConfig == nil || request.Params.VoiceConfig.Voice == nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	req.SetRequestURI(provider.buildBaseSpeechRequestURL(ctx, "/v1/text-to-speech/"+*request.Params.VoiceConfig.Voice+"/stream", schemas.SpeechStreamRequest, request))
@@ -381,9 +361,9 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 			}, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 		if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Extract provider response headers before status check so error responses also forward them
@@ -392,11 +372,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.SpeechStreamRequest,
-		}), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Create response channel
@@ -407,9 +383,9 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 	go func() {
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -451,7 +427,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				provider.logger.Warn("Error reading stream: %v", err)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 				return
 			}
 
@@ -464,11 +440,8 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 				Type:  schemas.SpeechStreamResponseTypeDelta,
 				Audio: audioChunk,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.SpeechStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					ChunkIndex:     chunkIndex,
-					Latency:        time.Since(lastChunkTime).Milliseconds(),
+					ChunkIndex: chunkIndex,
+					Latency:    time.Since(lastChunkTime).Milliseconds(),
 				},
 			}
 
@@ -487,11 +460,8 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 				Type:  schemas.SpeechStreamResponseTypeDone,
 				Audio: []byte{},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.SpeechStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					ChunkIndex:     chunkIndex + 1,
-					Latency:        time.Since(startTime).Milliseconds(),
+					ChunkIndex: chunkIndex + 1,
+					Latency:    time.Since(startTime).Milliseconds(),
 				},
 			}
 
@@ -512,32 +482,30 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	reqBody := ToElevenlabsTranscriptionRequest(request)
 	if reqBody == nil {
-		return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil)
 	}
 
 	hasFile := len(reqBody.File) > 0
 	hasURL := reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != ""
 	if hasFile && hasURL {
-		return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil)
 	}
 	if !hasFile && !hasURL {
-		return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil)
 	}
 
 	var body bytes.Buffer
 	writer := multipart.NewWriter(&body)
-	if bifrostErr := writeTranscriptionMultipart(writer, reqBody, providerName); bifrostErr != nil {
+	if bifrostErr := writeTranscriptionMultipart(writer, reqBody); bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	contentType := writer.FormDataContentType()
 	if err := writer.Close(); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err)
 	}
 
 	req := fasthttp.AcquireRequest()
@@ -568,17 +536,12 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 	// Extract and set provider response headers so they're available on error paths
 	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
 	if resp.StatusCode() != fasthttp.StatusOK {
-		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, parseElevenlabsError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.TranscriptionRequest,
-		})
+		return nil, parseElevenlabsError(resp)
 	}
 
 	responseBody, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	// Check for empty response
@@ -594,18 +557,15 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 	chunks, err := parseTranscriptionResponse(responseBody)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError(err.Error(), nil)
 	}
 	if len(chunks) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil)
 	}
 
 	response := ToBifrostTranscriptionResponse(chunks)
 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType:             schemas.TranscriptionRequest,
-		Provider:                providerName,
-		ModelRequested:          request.Model,
 		Latency:                 latency.Milliseconds(),
 		ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 	}
@@ -613,7 +573,7 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		var rawResponse interface{}
 		if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName)
+			rawResponse = string(responseBody)
 		}
 		response.ExtraFields.RawResponse = rawResponse
 	}
@@ -621,9 +581,9 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 	return response, nil
 }
 
-func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
+func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest) *schemas.BifrostError {
 	if err := writer.WriteField("model_id", reqBody.ModelID); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to write model_id field", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to write model_id field", err)
 	}
 
 	if len(reqBody.File) > 0 {
@@ -633,98 +593,98 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr
 		}
 		fileWriter, err := writer.CreateFormFile("file", filename)
 		if err != nil {
-			return providerUtils.NewBifrostOperationError("failed to create file field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to create file field", err)
 		}
 		if _, err := fileWriter.Write(reqBody.File); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write file data", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write file data", err)
 		}
 	}
 
 	if reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != "" {
 		if err := writer.WriteField("cloud_storage_url", *reqBody.CloudStorageURL); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err)
 		}
 	}
 
 	if reqBody.LanguageCode != nil && strings.TrimSpace(*reqBody.LanguageCode) != "" {
 		if err := writer.WriteField("language_code", *reqBody.LanguageCode); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write language_code field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write language_code field", err)
 		}
 	}
 
 	if reqBody.TagAudioEvents != nil {
 		if err := writer.WriteField("tag_audio_events", strconv.FormatBool(*reqBody.TagAudioEvents)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err)
 		}
 	}
 
 	if reqBody.NumSpeakers != nil && *reqBody.NumSpeakers > 0 {
 		if err := writer.WriteField("num_speakers", strconv.Itoa(*reqBody.NumSpeakers)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err)
 		}
 	}
 
 	if reqBody.TimestampsGranularity != nil && *reqBody.TimestampsGranularity != "" {
 		if err := writer.WriteField("timestamps_granularity", string(*reqBody.TimestampsGranularity)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err)
 		}
 	}
 
 	if reqBody.Diarize != nil {
 		if err := writer.WriteField("diarize", strconv.FormatBool(*reqBody.Diarize)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write diarize field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write diarize field", err)
 		}
 	}
 
 	if reqBody.DiarizationThreshold != nil {
 		if err := writer.WriteField("diarization_threshold", strconv.FormatFloat(*reqBody.DiarizationThreshold, 'f', -1, 64)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err)
 		}
 	}
 
 	if len(reqBody.AdditionalFormats) > 0 {
 		payload, err := providerUtils.MarshalSorted(reqBody.AdditionalFormats)
 		if err != nil {
-			return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err)
 		}
 		if err := writer.WriteField("additional_formats", string(payload)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err)
 		}
 	}
 
 	if reqBody.FileFormat != nil && *reqBody.FileFormat != "" {
 		if err := writer.WriteField("file_format", string(*reqBody.FileFormat)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write file_format field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write file_format field", err)
 		}
 	}
 
 	if reqBody.Webhook != nil {
 		if err := writer.WriteField("webhook", strconv.FormatBool(*reqBody.Webhook)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write webhook field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write webhook field", err)
 		}
 	}
 
 	if reqBody.WebhookID != nil && strings.TrimSpace(*reqBody.WebhookID) != "" {
 		if err := writer.WriteField("webhook_id", *reqBody.WebhookID); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err)
 		}
 	}
 
 	if reqBody.Temperature != nil {
 		if err := writer.WriteField("temperature", strconv.FormatFloat(*reqBody.Temperature, 'f', -1, 64)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write temperature field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write temperature field", err)
 		}
 	}
 
 	if reqBody.Seed != nil {
 		if err := writer.WriteField("seed", strconv.Itoa(*reqBody.Seed)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write seed field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write seed field", err)
 		}
 	}
 
 	if reqBody.UseMultiChannel != nil {
 		if err := writer.WriteField("use_multi_channel", strconv.FormatBool(*reqBody.UseMultiChannel)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err)
 		}
 	}
 
@@ -733,16 +693,16 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr
 		case string:
 			if strings.TrimSpace(v) != "" {
 				if err := writer.WriteField("webhook_metadata", v); err != nil {
-					return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err, providerName)
+					return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err)
 				}
 			}
 		default:
 			payload, err := providerUtils.MarshalSorted(v)
 			if err != nil {
-				return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err, providerName)
+				return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err)
 			}
 			if err := writer.WriteField("webhook_metadata", string(payload)); err != nil {
-				return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err, providerName)
+				return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err)
 			}
 		}
 	}
diff --git a/core/providers/elevenlabs/errors.go b/core/providers/elevenlabs/errors.go
index 374e251958..f30807efd5 100644
--- a/core/providers/elevenlabs/errors.go
+++ b/core/providers/elevenlabs/errors.go
@@ -9,7 +9,7 @@ import (
 	schemas "github.com/maximhq/bifrost/core/schemas"
 )
 
-func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
+func parseElevenlabsError(resp *fasthttp.Response) *schemas.BifrostError {
 	var errorResp ElevenlabsError
 	bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
 	if errorResp.Detail != nil {
@@ -64,11 +64,6 @@ func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMe
 				Message: message,
 			},
 		}
-		if meta != nil {
-			result.ExtraFields.Provider = meta.Provider
-			result.ExtraFields.ModelRequested = meta.Model
-			result.ExtraFields.RequestType = meta.RequestType
-		}
 		return result
 	}
 }
@@ -91,10 +86,5 @@ func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMe
 			bifrostErr.Error.Message = message
 		}
 	}
-	if meta != nil {
-		bifrostErr.ExtraFields.Provider = meta.Provider
-		bifrostErr.ExtraFields.ModelRequested = meta.Model
-		bifrostErr.ExtraFields.RequestType = meta.RequestType
-	}
 	return bifrostErr
 }
diff --git a/core/providers/elevenlabs/models.go b/core/providers/elevenlabs/models.go
index c211e85196..f762d97ee8 100644
--- a/core/providers/elevenlabs/models.go
+++ b/core/providers/elevenlabs/models.go
@@ -1,12 +1,13 @@
 package elevenlabs
 
 import (
-	"slices"
+	"strings"
 
+	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
 	"github.com/maximhq/bifrost/core/schemas"
 )
 
-func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
+func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -15,35 +16,36 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid
 		Data: make([]schemas.Model, 0, len(*response)),
 	}
 
-	includedModels := make(map[string]bool)
-	for _, model := range *response {
-		if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) {
-			continue
-		}
-		if !unfiltered && slices.Contains(blacklistedModels, model.ModelID) {
-			continue
-		}
-		bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-			ID:   string(providerKey) + "/" + model.ModelID,
-			Name: schemas.Ptr(model.Name),
-		})
-		includedModels[model.ModelID] = true
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       providerKey,
+		MatchFns:          providerUtils.DefaultMatchFns(),
+	}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
 	}
 
-	// Backfill allowed models that were not in the response
-	if !unfiltered && len(allowedModels) > 0 {
-		for _, allowedModel := range allowedModels {
-			if slices.Contains(blacklistedModels, allowedModel) {
-				continue
+	included := make(map[string]bool)
+
+	for _, model := range *response {
+		for _, result := range pipeline.FilterModel(model.ModelID) {
+			entry := schemas.Model{
+				ID:   string(providerKey) + "/" + result.ResolvedID,
+				Name: schemas.Ptr(model.Name),
 			}
-			if !includedModels[allowedModel] {
-				bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-					ID:   string(providerKey) + "/" + allowedModel,
-					Name: schemas.Ptr(allowedModel),
-				})
+			if result.AliasValue != "" {
+				entry.Alias = schemas.Ptr(result.AliasValue)
 			}
+			bifrostResponse.Data = append(bifrostResponse.Data, entry)
+			included[strings.ToLower(result.ResolvedID)] = true
 		}
 	}
 
+	bifrostResponse.Data = append(bifrostResponse.Data,
+		pipeline.BackfillModels(included)...)
+
 	return bifrostResponse
 }
diff --git a/core/providers/elevenlabs/realtime.go b/core/providers/elevenlabs/realtime.go
index f124f58339..a18e1cd514 100644
--- a/core/providers/elevenlabs/realtime.go
+++ b/core/providers/elevenlabs/realtime.go
@@ -39,6 +39,44 @@ func (provider *ElevenlabsProvider) RealtimeHeaders(key schemas.Key) map[string]
 	return headers
 }
 
+// SupportsRealtimeWebRTC returns false — ElevenLabs WebRTC SDP exchange is not yet implemented.
+func (provider *ElevenlabsProvider) SupportsRealtimeWebRTC() bool {
+	return false
+}
+
+// ExchangeRealtimeWebRTCSDP is not yet implemented for ElevenLabs.
+func (provider *ElevenlabsProvider) ExchangeRealtimeWebRTCSDP(_ *schemas.BifrostContext, _ schemas.Key, _ string, _ string, _ json.RawMessage) (string, *schemas.BifrostError) {
+	return "", &schemas.BifrostError{
+		IsBifrostError: true,
+		StatusCode:     schemas.Ptr(400),
+		Error:          &schemas.ErrorField{Type: schemas.Ptr("invalid_request_error"), Message: "WebRTC SDP exchange is not yet implemented for ElevenLabs"},
+	}
+}
+
+func (provider *ElevenlabsProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool {
+	return false
+}
+
+func (provider *ElevenlabsProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType {
+	return schemas.RTEventResponseDone
+}
+
+func (provider *ElevenlabsProvider) RealtimeWebRTCDataChannelLabel() string {
+	return ""
+}
+
+func (provider *ElevenlabsProvider) RealtimeWebSocketSubprotocol() string {
+	return ""
+}
+
+func (provider *ElevenlabsProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool {
+	return true
+}
+
+func (provider *ElevenlabsProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool {
+	return eventType == schemas.RTEventResponseDone
+}
+
 // ElevenLabs Conversational AI WebSocket event types
 const (
 	elConversationInitMetadata = "conversation_initiation_metadata"
@@ -50,8 +88,8 @@ const (
 	elInterruption     = "interruption"
 	elClientToolCall   = "client_tool_call"
-	elUserAudioChunk = "user_audio_chunk"
-	elPong           = "pong"
+	elUserAudioChunk   = "user_audio_chunk"
+	elPong             = "pong"
 	elClientToolResult = "client_tool_result"
 	elContextualUpdate = "contextual_update"
 )
@@ -134,7 +172,7 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra
 		}
 
 	case elAgentResponse:
-		event.Type = schemas.RTEventResponseTextDone
+		event.Type = schemas.RTEventResponseDone
 		if raw.AgentResponse != nil {
 			var agentResp elevenlabsTranscriptEvent
 			if err := json.Unmarshal(raw.AgentResponse, &agentResp); err == nil {
@@ -194,10 +232,6 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra
 
 // ToProviderRealtimeEvent converts a unified Bifrost Realtime event to ElevenLabs' native JSON.
 func (provider *ElevenlabsProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) {
-	if bifrostEvent.RawData != nil {
-		return bifrostEvent.RawData, nil
-	}
-
 	switch bifrostEvent.Type {
 	case schemas.RTEventInputAudioAppend:
 		if bifrostEvent.Delta == nil {
diff --git a/core/providers/gemini/batch.go b/core/providers/gemini/batch.go
index e3d92383f6..8f0405e524 100644
--- a/core/providers/gemini/batch.go
+++ b/core/providers/gemini/batch.go
@@ -249,8 +249,6 @@ func extractBatchIDFromName(name string) string {
 // downloadBatchResultsFile downloads and parses a batch results file from Gemini.
 // Returns the parsed result items from the JSONL file and any parse errors encountered.
 func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, key schemas.Key, fileName string) ([]schemas.BatchResultItem, []schemas.BatchError, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Create request to download the file
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -287,15 +285,12 @@ func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, ke
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			RequestType: schemas.BatchResultsRequest,
-		})
+		return nil, nil, parseGeminiError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	// Parse JSONL content - each line is a separate JSON object
diff --git a/core/providers/gemini/errors.go b/core/providers/gemini/errors.go
index adf217a141..2d60a7bcd3 100644
--- a/core/providers/gemini/errors.go
+++ b/core/providers/gemini/errors.go
@@ -36,7 +36,7 @@ func ToGeminiError(bifrostErr *schemas.BifrostError) *GeminiGenerationError {
 }
 
 // parseGeminiError parses Gemini error responses
-func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
+func parseGeminiError(resp *fasthttp.Response) *schemas.BifrostError {
 	// Try to parse as []GeminiGenerationError
 	var errorResps []GeminiGenerationError
 	bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResps)
@@ -62,11 +62,6 @@ func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 	}
 	// Set Message to trimmed concatenated message
 	bifrostErr.Error.Message = message
-	if meta != nil {
-		bifrostErr.ExtraFields.Provider = meta.Provider
bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } @@ -80,10 +75,5 @@ func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetada bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(errorResp.Error.Code)) bifrostErr.Error.Message = errorResp.Error.Message } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index f4c4a32c01..cad4216534 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -97,9 +97,7 @@ func (provider *GeminiProvider) GetProviderKey() schemas.ModelProvider { // completeRequest handles the common HTTP request pattern for Gemini API calls. // When large response streaming is activated (BifrostContextKeyLargeResponseMode set in ctx), // returns (nil, nil, latency, nil) — callers must check the context flag. 
-func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string, meta *providerUtils.RequestMetadata) (*GenerateContentResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - +func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string) (*GenerateContentResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) { // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -146,10 +144,10 @@ func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, mod // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, nil, latency, providerResponseHeaders, parseGeminiError(resp, meta) + return nil, nil, latency, providerResponseHeaders, parseGeminiError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, nil, latency, providerResponseHeaders, decodeErr } @@ -161,13 +159,13 @@ func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, mod // Parse Gemini's response var geminiResponse GenerateContentResponse if err := sonic.Unmarshal(body, &geminiResponse); err != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var rawResponse interface{} if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { if err := 
sonic.Unmarshal(body, &rawResponse); err != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } } @@ -208,10 +206,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ListModelsRequest, - }) + return nil, parseGeminiError(resp) } // Parse Gemini's response @@ -227,7 +222,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key } } - response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -282,24 +277,17 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - jsonData, err := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiChatCompletionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, 
jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -312,9 +300,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -323,9 +308,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key bifrostResponse := geminiResponse.ToBifrostChatResponse() - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -363,8 +345,7 @@ func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("chat completion request is not provided or could not be converted to gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -450,9 +431,9 @@ func HandleGeminiChatCompletionStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -462,11 +443,7 @@ func HandleGeminiChatCompletionStream( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) respBody := append([]byte(nil), resp.Body()...) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.ChatCompletionStreamRequest, - }), jsonBody, respBody, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, respBody, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -484,9 +461,9 @@ func HandleGeminiChatCompletionStream( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -496,7 +473,6 @@ func HandleGeminiChatCompletionStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -558,7 +534,7 @@ func HandleGeminiChatCompletionStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } // Process chunk using shared function @@ -573,11 +549,6 @@ func HandleGeminiChatCompletionStream( Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -598,11 +569,6 @@ func HandleGeminiChatCompletionStream( // Convert to Bifrost stream response response, bifrostErr, isLastChunk := geminiResponse.ToBifrostChatCompletionStream(streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -614,11 +580,8 @@ func HandleGeminiChatCompletionStream( response.Model = modelName } response.ExtraFields = schemas.BifrostResponseExtraFields{ - 
RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -693,8 +656,7 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem return nil, fmt.Errorf("responses input is not provided or could not be converted to gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -706,11 +668,7 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem } // Use struct directly for JSON marshaling - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -723,9 +681,6 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -736,9 +691,6 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - 
bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -768,13 +720,6 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( bodyReader io.Reader, // Optional: for large payload request streaming (pass nil for normal path) bodySize int, // Required if bodyReader is non-nil ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - meta := &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - } - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -808,14 +753,14 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( // Handle error response — materialize stream body for error parsing if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - bifrostErr := parseGeminiError(resp, meta) + bifrostErr := parseGeminiError(resp) wait() fasthttp.ReleaseResponse(resp) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Delegate large response detection + normal buffered path to shared utility - responseBody, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if respErr != nil { wait() fasthttp.ReleaseResponse(resp) @@ -831,9 +776,6 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( Model: request.Model, Usage: usage, } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - 
bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // resp owned by reader in context — don't release wait() @@ -845,12 +787,9 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( // Normal parse-and-convert path var geminiResponse GenerateContentResponse if unmarshalErr := sonic.Unmarshal(responseBody, &geminiResponse); unmarshalErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, unmarshalErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, unmarshalErr) } bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -902,8 +841,7 @@ func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return nil, fmt.Errorf("responses input is not provided or could not be converted to gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -988,9 +926,9 @@ func HandleGeminiResponsesStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), jsonBody, nil, sendBackRawRequest, 
sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -999,11 +937,7 @@ func HandleGeminiResponsesStream( // Check for HTTP errors — use parseGeminiError to preserve upstream error details if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.ResponsesStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1021,9 +955,9 @@ func HandleGeminiResponsesStream( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1034,7 +968,6 @@ func HandleGeminiResponsesStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", 
fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError( @@ -1103,7 +1036,7 @@ func HandleGeminiResponsesStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } @@ -1119,11 +1052,6 @@ func HandleGeminiResponsesStream( Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -1141,11 +1069,6 @@ func HandleGeminiResponsesStream( // Convert to Bifrost responses stream response responses, bifrostErr := geminiResponse.ToBifrostResponsesStream(sequenceNumber, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -1154,11 +1077,8 @@ func HandleGeminiResponsesStream( for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - 
ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -1211,11 +1131,8 @@ func HandleGeminiResponsesStream( continue } finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -1260,8 +1177,7 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiEmbeddingRequest(request), nil - }, - providerName) + }) if err != nil { return nil, err } @@ -1312,17 +1228,13 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.EmbeddingRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) wait() fasthttp.ReleaseResponse(resp) return nil, parsedErr } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { wait() fasthttp.ReleaseResponse(resp) @@ -1335,9 +1247,6 @@ func 
(provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1359,12 +1268,9 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem bifrostResponse := ToBifrostEmbeddingResponse(&geminiResponse, request.Model) if bifrostResponse == nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, - fmt.Errorf("failed to convert Gemini embedding response to Bifrost format"), providerName) + fmt.Errorf("failed to convert Gemini embedding response to Bifrost format")) } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -1393,18 +1299,13 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. 
request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.SpeechRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1416,9 +1317,6 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostSpeechResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.SpeechRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1430,13 +1328,10 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. 
} response, convErr := geminiResponse.ToBifrostSpeechResponse(ctx) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.SpeechRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1468,16 +1363,13 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo return nil, err } - providerName := provider.GetProviderKey() - // Prepare request body using speech-specific function jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1523,9 +1415,9 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo }, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Extract provider response headers before status check so error responses also forward them
@@ -1534,11 +1426,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.SpeechStreamRequest,
-		}), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -1558,9 +1446,9 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo
 		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1600,7 +1488,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				provider.logger.Warn("Error reading stream: %v", readErr)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger)
 				return
 			}
 			break
@@ -1620,11 +1508,6 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo
 						Message: err.Error(),
 						Error:   err,
 					},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						RequestType:    schemas.SpeechStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-					},
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
@@ -1675,11 +1558,8 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo
 				Type:  schemas.SpeechStreamResponseTypeDelta,
 				Audio: audioChunk,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.SpeechStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					ChunkIndex:     chunkIndex,
-					Latency:        time.Since(lastChunkTime).Milliseconds(),
+					ChunkIndex: chunkIndex,
+					Latency:    time.Since(lastChunkTime).Milliseconds(),
 				},
 			}
 			lastChunkTime = time.Now()
@@ -1696,11 +1576,8 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo
 			Type:  schemas.SpeechStreamResponseTypeDone,
 			Usage: usage,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.SpeechStreamRequest,
-				Provider:       providerName,
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex + 1,
-				Latency:        time.Since(startTime).Milliseconds(),
+				ChunkIndex: chunkIndex + 1,
+				Latency:    time.Since(startTime).Milliseconds(),
 			},
 		}
 		response.BackfillParams(request)
@@ -1728,18 +1605,13 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiTranscriptionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Use common request function
-	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{
-		Provider:    provider.GetProviderKey(),
-		Model:       request.Model,
-		RequestType: schemas.TranscriptionRequest,
-	})
+	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent")
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -1751,9 +1623,6 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s
 	if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp {
 		return &schemas.BifrostTranscriptionResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				RequestType:             schemas.TranscriptionRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -1763,9 +1632,6 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s
 	response := geminiResponse.ToBifrostTranscriptionResponse()
 
 	// Set ExtraFields
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.TranscriptionRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
@@ -1787,16 +1653,13 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	// Prepare request body using transcription-specific function
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiTranscriptionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1842,9 +1705,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 			}, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 		if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Extract provider response headers before status check so error responses also forward them
@@ -1853,11 +1716,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.TranscriptionStreamRequest,
-		}), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -1877,9 +1736,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1919,7 +1778,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				provider.logger.Warn("Error reading stream: %v", readErr)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger)
 				return
 			}
 			break
@@ -1938,11 +1797,6 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 						Message: err.Error(),
 						Error:   err,
 					},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						RequestType:    schemas.TranscriptionStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-					},
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
@@ -1987,11 +1841,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 				Type:  schemas.TranscriptionStreamResponseTypeDelta,
 				Delta: &deltaText, // Delta text for this chunk
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.TranscriptionStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					ChunkIndex:     chunkIndex,
-					Latency:        time.Since(lastChunkTime).Milliseconds(),
+					ChunkIndex: chunkIndex,
+					Latency:    time.Since(lastChunkTime).Milliseconds(),
 				},
 			}
 			lastChunkTime = time.Now()
@@ -2014,11 +1865,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 				TotalTokens:  usage.TotalTokens,
 			},
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.TranscriptionStreamRequest,
-				Provider:       providerName,
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex + 1,
-				Latency:        time.Since(startTime).Milliseconds(),
+				ChunkIndex: chunkIndex + 1,
+				Latency:    time.Since(startTime).Milliseconds(),
 			},
 		}
 
@@ -2051,18 +1899,13 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiImageGenerationRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Use common request function
-	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{
-		Provider:    provider.GetProviderKey(),
-		Model:       request.Model,
-		RequestType: schemas.ImageGenerationRequest,
-	})
+	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent")
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -2074,9 +1917,6 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 	if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp {
 		return &schemas.BifrostImageGenerationResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				RequestType:             schemas.ImageGenerationRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -2085,25 +1925,16 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 	response, bifrostErr := geminiResponse.ToBifrostImageGenerationResponse()
 	if bifrostErr != nil {
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			Provider:       provider.GetProviderKey(),
-			ModelRequested: request.Model,
-			RequestType:    schemas.ImageGenerationRequest,
-		}
 		return nil, bifrostErr
 	}
 	if response == nil {
 		return nil, providerUtils.NewBifrostOperationError(
 			"failed to convert Gemini image generation response",
 			fmt.Errorf("ToBifrostImageGenerationResponse returned nil response"),
-			provider.GetProviderKey(),
 		)
 	}
 
 	// Set ExtraFields
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageGenerationRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
@@ -2120,16 +1951,13 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 
 // handleImagenImageGeneration handles Imagen model requests using Vertex AI endpoint with API key auth
 func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Prepare Imagen request body
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToImagenImageGenerationRequest(request), nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2169,16 +1997,11 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.ImageGenerationRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Parse Imagen response
-	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
 		return nil, decodeErr
 	}
@@ -2186,10 +2009,7 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost
 		respOwned = false
 		return &schemas.BifrostImageGenerationResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:       providerName,
-				ModelRequested: request.Model,
-				RequestType:    schemas.ImageGenerationRequest,
-				Latency:        latency.Milliseconds(),
+				Latency: latency.Milliseconds(),
 			},
 		}, nil
 	}
@@ -2201,9 +2021,6 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost
 	}
 	// Convert to Bifrost format
 	response := imagenResponse.ToBifrostImageGenerationResponse()
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageGenerationRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
@@ -2228,8 +2045,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	// Handle Imagen models using :predict endpoint
 	if schemas.IsImagenModel(request.Model) {
 		jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
@@ -2237,8 +2052,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 			request,
 			func() (providerUtils.RequestBodyWithExtraParams, error) {
 				return ToImagenImageEditRequest(request), nil
-			},
-			providerName)
+			})
 		if bifrostErr != nil {
 			return nil, bifrostErr
 		}
@@ -2273,15 +2087,10 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 
 		if resp.StatusCode() != fasthttp.StatusOK {
 			providerUtils.MaterializeStreamErrorBody(ctx, resp)
-			provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-			return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-				Provider:    providerName,
-				Model:       request.Model,
-				RequestType: schemas.ImageEditRequest,
-			}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 
-		body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+		body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 		if decodeErr != nil {
 			return nil, decodeErr
 		}
@@ -2289,10 +2098,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 			imagenRespOwned = false
 			return &schemas.BifrostImageGenerationResponse{
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.ImageEditRequest,
-					Latency:        latency.Milliseconds(),
+					Latency: latency.Milliseconds(),
 				},
 			}, nil
 		}
@@ -2304,9 +2110,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 		}
 		response := imagenResponse.ToBifrostImageGenerationResponse()
-		response.ExtraFields.Provider = providerName
-		response.ExtraFields.ModelRequested = request.Model
-		response.ExtraFields.RequestType = schemas.ImageEditRequest
 		response.ExtraFields.Latency = latency.Milliseconds()
 
 		if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
@@ -2325,18 +2128,13 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiImageEditRequest(request), nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Use common request function
-	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{
-		Provider:    providerName,
-		Model:       request.Model,
-		RequestType: schemas.ImageEditRequest,
-	})
+	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent")
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -2348,9 +2146,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 	if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp {
 		return &schemas.BifrostImageGenerationResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                providerName,
-				ModelRequested:          request.Model,
-				RequestType:             schemas.ImageEditRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -2359,25 +2154,16 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 	response, bifrostErr := geminiResponse.ToBifrostImageGenerationResponse()
 	if bifrostErr != nil {
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			Provider:       providerName,
-			ModelRequested: request.Model,
-			RequestType:    schemas.ImageEditRequest,
-		}
 		return nil, bifrostErr
 	}
 	if response == nil {
 		return nil, providerUtils.NewBifrostOperationError(
 			"failed to convert Gemini image edit response",
 			fmt.Errorf("ToBifrostImageGenerationResponse returned nil response"),
-			providerName,
 		)
 	}
 
 	// Set ExtraFields
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageEditRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
@@ -2409,7 +2195,6 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
 	model := bifrostReq.Model
 
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
@@ -2418,7 +2203,6 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiVideoGenerationRequest(bifrostReq)
 		},
-		providerName,
 	)
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2451,17 +2235,13 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       model,
-			RequestType: schemas.VideoGenerationRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// use handle provider response
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Parse response
@@ -2477,12 +2257,9 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		return nil, bifrostErr
 	}
 
-	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName)
+	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey())
 
 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.ModelRequested = model
-	bifrostResp.ExtraFields.RequestType = schemas.VideoGenerationRequest
 
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		bifrostResp.ExtraFields.RawRequest = rawRequest
@@ -2500,10 +2277,9 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
 	operationID := bifrostReq.ID
 
-	operationID = providerUtils.StripVideoIDProviderSuffix(operationID, providerName)
+	operationID = providerUtils.StripVideoIDProviderSuffix(operationID, provider.GetProviderKey())
 
 	// Create HTTP request
 	req := fasthttp.AcquireRequest()
@@ -2528,10 +2304,8 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			RequestType: schemas.VideoRetrieveRequest,
-		}), nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		respBody := append([]byte(nil), resp.Body()...)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Parse response
@@ -2545,12 +2319,10 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
-	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName)
+	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey())
 
 	// Add extra fields
 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest
 
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		bifrostResp.ExtraFields.RawResponse = rawResponse
@@ -2564,9 +2336,8 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil {
 		return nil, err
 	}
-	providerName := provider.GetProviderKey()
 	if request == nil || request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 
 	// Retrieve operation first so download behavior follows retrieve status.
 	bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
@@ -2581,11 +2352,10 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("video not ready, current status: %s", videoResp.Status),
 			nil,
-			providerName,
 		)
 	}
 	if len(videoResp.Videos) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
 	}
 	var content []byte
 	contentType := "video/mp4"
@@ -2596,7 +2366,7 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 		startTime := time.Now()
 		decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err)
 		}
 		content = decoded
 		latency = time.Since(startTime)
@@ -2627,17 +2397,16 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 			return nil, providerUtils.NewBifrostOperationError(
 				fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
 				nil,
-				providerName,
 			)
 		}
 
 		body, err := providerUtils.CheckAndDecodeBody(resp)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 		}
 		contentType = string(resp.Header.ContentType())
 		content = append([]byte(nil), body...)
 	} else {
-		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
 	}
 	bifrostResp := &schemas.BifrostVideoDownloadResponse{
 		VideoID: request.ID,
@@ -2646,8 +2415,6 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	}
 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest
 
 	return bifrostResp, nil
 }
@@ -2677,18 +2444,16 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	// Validate that either InputFileID or Requests is provided, but not both
 	hasFileInput := request.InputFileID != ""
 	hasInlineRequests := len(request.Requests) > 0
 
 	if !hasFileInput && !hasInlineRequests {
-		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests must be provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests must be provided", nil)
 	}
 
 	if hasFileInput && hasInlineRequests {
-		return nil, providerUtils.NewBifrostOperationError("cannot specify both input_file_id and requests", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("cannot specify both input_file_id and requests", nil)
 	}
 
 	// Build the batch request with proper nested structure
@@ -2721,12 +2486,12 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 			if rawMessages, ok := body["messages"]; ok {
 				messagesBytes, err := providerUtils.MarshalSorted(rawMessages)
 				if err != nil {
-					return nil, providerUtils.NewBifrostOperationError("failed to marshal messages", err, providerName)
+					return nil, providerUtils.NewBifrostOperationError("failed to marshal messages", err)
 				}
 				var chatMessages []schemas.ChatMessage
 				err = sonic.Unmarshal(messagesBytes, &chatMessages)
 				if err != nil {
-					return nil, providerUtils.NewBifrostOperationError("failed to unmarshal messages", err, providerName)
+					return nil, providerUtils.NewBifrostOperationError("failed to unmarshal messages", err)
 				}
 
 				contents, systemInstruction := convertBifrostMessagesToGemini(chatMessages)
@@ -2736,11 +2501,11 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 				// If no "messages" key, try direct unmarshal (already in Gemini format)
 				requestBytes, err := providerUtils.MarshalSorted(body)
 				if err != nil {
-					return nil, providerUtils.NewBifrostOperationError("failed to marshal gemini request", err, providerName)
+					return nil, providerUtils.NewBifrostOperationError("failed to marshal gemini request", err)
 				}
 				err = sonic.Unmarshal(requestBytes, &geminiReq)
 				if err != nil {
-					return nil, providerUtils.NewBifrostOperationError("failed to unmarshal gemini request", err, providerName)
+					return nil, providerUtils.NewBifrostOperationError("failed to unmarshal gemini request", err)
 				}
 			}
 
@@ -2764,7 +2529,7 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 
 	jsonData, err := providerUtils.MarshalSorted(batchReq)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 	}
 
 	// Create HTTP request
@@ -2802,31 +2567,27 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       model,
-			RequestType: schemas.BatchCreateRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Parse the batch job response
 	var geminiResp GeminiBatchJobResponse
 	if err := sonic.Unmarshal(body, &geminiResp); err != nil {
 		provider.logger.Error("gemini batch create unmarshal error: " + err.Error())
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName), jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
 
 	// Check for metadata
 	if geminiResp.Metadata == nil {
-		return nil, providerUtils.NewBifrostOperationError("gemini batch response missing metadata", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("gemini batch response missing metadata", nil)
 	}
 
 	// Check for batch stats
 	if geminiResp.Metadata.BatchStats == nil {
-		return nil, providerUtils.NewBifrostOperationError("gemini batch response missing batch stats", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("gemini batch response missing batch stats", nil)
 	}
 
 	// Calculate request counts based on response
 	totalRequests := geminiResp.Metadata.BatchStats.RequestCount
@@ -2869,9 +2630,7 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 			Failed:    failedCount,
 		},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchCreateRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 
@@ -2890,8 +2649,6 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 
 // batchListByKey lists batch jobs for Gemini for a single key.
 func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, time.Duration, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Create HTTP request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -2939,26 +2696,21 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key
 				Data:    []schemas.BifrostBatchRetrieveResponse{},
 				HasMore: false,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType: schemas.BatchListRequest,
-					Provider:    providerName,
-					Latency:     latency.Milliseconds(),
+					Latency: latency.Milliseconds(),
 				},
 			}, latency, nil
 		}
-		return nil, latency, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			RequestType: schemas.BatchListRequest,
-		})
+		return nil, latency, parseGeminiError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	var geminiResp GeminiBatchListResponse
 	if err := sonic.Unmarshal(body, &geminiResp); err != nil {
-		return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+		return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 	}
 
 	// Convert to Bifrost format
@@ -2970,10 +2722,7 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key
 			Status:        ToBifrostBatchStatus(batch.Metadata.State),
 			CreatedAt:     parseGeminiTimestamp(batch.Metadata.CreateTime),
 			OperationName: &batch.Name,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.BatchListRequest,
-				Provider:    providerName,
-			},
+			ExtraFields:   schemas.BifrostResponseExtraFields{},
 		})
 	}
 
@@ -2989,9 +2738,7 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key
 		HasMore:    hasMore,
 		NextCursor: nextCursor,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchListRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}, latency, nil
 }
@@ -3005,16 +2752,14 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	if len(keys) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch list", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch list", nil)
 	}
 
 	// Initialize serial pagination helper (Gemini uses PageToken for pagination)
 	helper, err := providerUtils.NewSerialListHelper(keys, request.PageToken, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}
 
 	// Get current key to query
@@ -3025,10 +2770,6 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 			Object:  "list",
 			Data:    []schemas.BifrostBatchRetrieveResponse{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.BatchListRequest,
-				Provider:    providerName,
-			},
 		}, nil
 	}
 
@@ -3060,9 +2801,7 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 		Data:    resp.Data,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-
RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -3074,8 +2813,6 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc // batchRetrieveByKey retrieves a specific batch job for Gemini for a single key. func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create HTTP request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3108,20 +2845,17 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchRetrieveRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiBatchJobResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var completedCount, failedCount int @@ -3150,9 +2884,7 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, Failed: failedCount, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3163,14 
+2895,12 @@ func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch retrieve", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch retrieve", nil) } // Try each key until we find the batch @@ -3189,8 +2919,6 @@ func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys // batchCancelByKey cancels a batch job for Gemini for a single key. func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create HTTP request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3228,15 +2956,9 @@ func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, ke if resp.StatusCode() == fasthttp.StatusNotFound || resp.StatusCode() == fasthttp.StatusMethodNotAllowed { // 404 could mean batch not found or cancel not supported // Return the error instead of assuming completed - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchCancelRequest, - }) + return nil, parseGeminiError(resp) } - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchCancelRequest, - }) + return nil, parseGeminiError(resp) } now := time.Now().Unix() @@ -3246,9 +2968,7 @@ func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, ke Status: schemas.BatchStatusCancelling, 
CancellingAt: &now, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3260,14 +2980,12 @@ func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch cancel", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch cancel", nil) } // Try each key until cancellation succeeds @@ -3289,8 +3007,6 @@ func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] // batches.delete indicates the client is no longer interested in the operation result. // It does not cancel the operation. If the server doesn't support this method, it returns UNIMPLEMENTED. 
func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) @@ -3319,10 +3035,7 @@ func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, ke } if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchDeleteRequest, - }) + return nil, parseGeminiError(resp) } return &schemas.BifrostBatchDeleteResponse{ @@ -3330,9 +3043,7 @@ func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, ke Object: "batch", Status: schemas.BatchStatusDeleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3345,14 +3056,12 @@ func (provider *GeminiProvider) BatchDelete(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch delete", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch delete", nil) } var lastError *schemas.BifrostError @@ -3500,8 +3209,6 @@ func readNextSSEDataLine(reader *bufio.Reader, skipInlineData bool) ([]byte, err // batchResultsByKey retrieves batch results for Gemini for a single key. 
func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // We need to get the full batch response with results, so make the API call directly req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3535,20 +3242,17 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchResultsRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiBatchJobResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Check if batch is still processing @@ -3556,7 +3260,6 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("batch %s is still processing (state: %s), results not yet available", request.BatchID, geminiResp.Metadata.State), nil, - providerName, ) } @@ -3644,9 +3347,7 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: 
latency.Milliseconds(), }, } @@ -3665,14 +3366,12 @@ func (provider *GeminiProvider) BatchResults(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch results", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch results", nil) } // Try each key until we get results @@ -3696,10 +3395,8 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart request @@ -3709,14 +3406,14 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Add file metadata as JSON metadataField, err := writer.CreateFormField("metadata") if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create metadata field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create metadata field", err) } metadataJSON, err := providerUtils.SetJSONField([]byte(`{}`), "file.displayName", request.Filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err) } if _, err := metadataField.Write(metadataJSON); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err, providerName) + return nil, 
providerUtils.NewBifrostOperationError("failed to write metadata", err) } // Add file content @@ -3726,14 +3423,14 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -3764,15 +3461,12 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileUploadRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse response - wrapped in "file" object @@ -3780,7 +3474,7 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche File GeminiFileResponse `json:"file"` } if err := sonic.Unmarshal(body, &responseWrapper); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, 
providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } geminiResp := responseWrapper.File @@ -3816,17 +3510,13 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche StorageURI: geminiResp.URI, ExpiresAt: expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } // fileListByKey lists files from Gemini for a single key. func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, time.Duration, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3863,20 +3553,17 @@ func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, latency, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileListRequest, - }) + return nil, latency, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiFileListResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost response @@ -3885,9 +3572,7 @@ func (provider *GeminiProvider) fileListByKey(ctx 
*schemas.BifrostContext, key s Data: make([]schemas.FileObject, len(geminiResp.Files)), HasMore: geminiResp.NextPageToken != "", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3940,16 +3625,14 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch return nil, err } - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file list", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file list", nil) } // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3960,10 +3643,6 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -3995,9 +3674,7 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch Data: resp.Data, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -4009,8 +3686,6 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch // fileRetrieveByKey retrieves file metadata from Gemini for a single key. 
func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -4041,20 +3716,17 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileRetrieveRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiFileResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var sizeBytes int64 @@ -4091,9 +3763,7 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, k StorageURI: geminiResp.URI, ExpiresAt: expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -4104,14 +3774,12 @@ func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is 
required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file retrieve", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file retrieve", nil) } // Try each key until we find the file @@ -4131,8 +3799,6 @@ func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ // fileDeleteByKey deletes a file from Gemini for a single key. func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -4163,10 +3829,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key // Handle error response - DELETE returns 200 with empty body on success if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileDeleteRequest, - }) + return nil, parseGeminiError(resp) } return &schemas.BifrostFileDeleteResponse{ @@ -4174,9 +3837,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -4187,14 +3848,12 @@ func (provider *GeminiProvider) FileDelete(ctx *schemas.BifrostContext, keys []s return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 
0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file delete", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file delete", nil) } // Try each key until deletion succeeds @@ -4220,14 +3879,11 @@ func (provider *GeminiProvider) FileContent(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - // Gemini doesn't support direct file content download // Files are referenced by their URI in requests return nil, providerUtils.NewBifrostOperationError( "Gemini Files API doesn't support direct content download. Use the file URI in your requests instead.", nil, - providerName, ) } @@ -4258,7 +3914,6 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiResponsesRequest(request) }, - provider.GetProviderKey(), ) if bifrostErr != nil { return nil, bifrostErr @@ -4270,14 +3925,13 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch jsonData, _ = providerUtils.DeleteJSONField(jsonData, "systemInstruction") } - providerName := provider.GetProviderKey() req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) if strings.TrimSpace(request.Model) == "" { - return nil, providerUtils.NewBifrostOperationError("model is required for Gemini count tokens request", fmt.Errorf("missing model"), providerName) + return nil, providerUtils.NewBifrostOperationError("model is required for Gemini count tokens request", fmt.Errorf("missing model")) } // Determine native model name (e.g., parse any provider prefix) @@ -4310,15 +3964,12 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch } if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - 
Provider: providerName, - RequestType: schemas.CountTokensRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } responseBody := append([]byte(nil), body...) @@ -4338,9 +3989,6 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch response := geminiResponse.ToBifrostCountTokensResponse(request.Model) // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.CountTokensRequest response.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -4443,7 +4091,7 @@ func (provider *GeminiProvider) Passthrough( headers := providerUtils.ExtractProviderResponseHeaders(resp) body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { @@ -4457,9 +4105,6 @@ func (provider *GeminiProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - 
bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -4523,9 +4168,9 @@ func (provider *GeminiProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -4536,7 +4181,6 @@ func (provider *GeminiProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } @@ -4548,11 +4192,7 @@ func (provider *GeminiProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -4564,9 +4204,9 @@ func (provider *GeminiProvider) PassthroughStream( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -4615,7 +4255,7 @@ func (provider *GeminiProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/gemini/images.go b/core/providers/gemini/images.go index b390537b3e..881988392a 100644 --- a/core/providers/gemini/images.go +++ b/core/providers/gemini/images.go @@ -408,10 +408,10 @@ func ToGeminiImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRe // Handle size conversion if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != 
"auto" { - imageSize, aspectRatio := convertSizeToImagenFormat(*bifrostReq.Params.Size) + aspectRatio, imageSize := utils.ConvertSizeToAspectRatioAndResolution(*bifrostReq.Params.Size) if imageSize != "" && aspectRatio != "" { geminiReq.GenerationConfig.ImageConfig = &GeminiImageConfig{ - ImageSize: imageSize, + ImageSize: strings.ToLower(imageSize), AspectRatio: aspectRatio, } } @@ -513,9 +513,10 @@ func ToImagenImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRe // Handle size conversion if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" { - imageSize, aspectRatio := convertSizeToImagenFormat(*bifrostReq.Params.Size) + aspectRatio, imageSize := utils.ConvertSizeToAspectRatioAndResolution(*bifrostReq.Params.Size) if imageSize != "" { - req.Parameters.SampleImageSize = &imageSize + imageSizeLower := strings.ToLower(imageSize) + req.Parameters.SampleImageSize = &imageSizeLower } if aspectRatio != "" { req.Parameters.AspectRatio = &aspectRatio @@ -638,55 +639,6 @@ func convertOutputFormatToMimeType(outputFormat string) string { } } -// convertSizeToImagenFormat converts standard size format (e.g., "1024x1024") to Imagen format -// Returns (imageSize, aspectRatio) where imageSize is "1k", "2k", "4k" and aspectRatio is one of: -// "1:1", "3:4", "4:3", "9:16", or "16:9" -func convertSizeToImagenFormat(size string) (string, string) { - // Parse size string (format: "WIDTHxHEIGHT") - parts := strings.Split(size, "x") - if len(parts) != 2 { - return "", "" - } - - width, err1 := strconv.Atoi(parts[0]) - height, err2 := strconv.Atoi(parts[1]) - if err1 != nil || err2 != nil { - return "", "" - } - - // Validate width and height are positive integers - if width <= 0 || height <= 0 { - return "", "" - } - - var imageSize string - if width <= 1024 && height <= 1024 { - imageSize = "1k" - } else if width <= 2048 && height <= 2048 { - imageSize = "2k" - } else if width <= 4096 && height <= 4096 { - imageSize = "4k" - } - - // 
Calculate aspect ratio - var aspectRatio string - ratio := float64(width) / float64(height) - - // Common aspect ratios with tolerance - if ratio >= 0.99 && ratio <= 1.01 { - aspectRatio = "1:1" - } else if ratio >= 0.74 && ratio <= 0.76 { - aspectRatio = "3:4" - } else if ratio >= 1.32 && ratio <= 1.34 { - aspectRatio = "4:3" - } else if ratio >= 0.56 && ratio <= 0.57 { - aspectRatio = "9:16" - } else if ratio >= 1.77 && ratio <= 1.78 { - aspectRatio = "16:9" - } - - return imageSize, aspectRatio -} // ToBifrostImageGenerationResponse converts an Imagen response to Bifrost format func (response *GeminiImagenResponse) ToBifrostImageGenerationResponse() *schemas.BifrostImageGenerationResponse { diff --git a/core/providers/gemini/models.go b/core/providers/gemini/models.go index 4c8f83c364..7b9f6410eb 100644 --- a/core/providers/gemini/models.go +++ b/core/providers/gemini/models.go @@ -1,9 +1,9 @@ package gemini import ( - "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -17,7 +17,7 @@ func toGeminiModelResourceName(modelID string) string { return "models/" + modelID } -func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -26,45 +26,47 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - includedModels := make(map[string]bool) - for _, model := range response.Models { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + 
BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + included := make(map[string]bool) + + for _, model := range response.Models { contextLength := model.InputTokenLimit + model.OutputTokenLimit - // Remove prefix models/ from model.Name + // Gemini returns model names with a "models/" prefix — strip it before filtering + // so that allowedModels entries like "gemini-1.5-pro" match correctly. modelName := strings.TrimPrefix(model.Name, "models/") - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelName) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, modelName) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + modelName, - Name: schemas.Ptr(model.DisplayName), - Description: schemas.Ptr(model.Description), - ContextLength: schemas.Ptr(int(contextLength)), - MaxInputTokens: schemas.Ptr(model.InputTokenLimit), - MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), - SupportedMethods: model.SupportedGenerationMethods, - }) - includedModels[modelName] = true - } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + for _, result := range pipeline.FilterModel(modelName) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Description: schemas.Ptr(model.Description), + ContextLength: schemas.Ptr(int(contextLength)), + MaxInputTokens: schemas.Ptr(model.InputTokenLimit), + MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), + SupportedMethods: model.SupportedGenerationMethods, } - if !includedModels[allowedModel] { - bifrostResponse.Data = 
append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index 29935ca23b..a30eadcaf5 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -17,11 +17,13 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -const MinReasoningMaxTokens = 1 // Minimum max tokens for reasoning - used for estimation of effort level -const DefaultCompletionMaxTokens = 8192 // Default max output tokens for Gemini - used for relative reasoning max token calculation -const DefaultReasoningMinBudget = 1024 // Default minimum reasoning budget for Gemini -const DynamicReasoningBudget = -1 // Special value for dynamic reasoning budget in Gemini -const skipThoughtSignatureValidator = "skip_thought_signature_validator" +const ( + MinReasoningMaxTokens = 1 // Minimum max tokens for reasoning - used for estimation of effort level + DefaultCompletionMaxTokens = 8192 // Default max output tokens for Gemini - used for relative reasoning max token calculation + DefaultReasoningMinBudget = 1024 // Default minimum reasoning budget for Gemini + DynamicReasoningBudget = -1 // Special value for dynamic reasoning budget in Gemini + skipThoughtSignatureValidator = "skip_thought_signature_validator" +) type thinkingBudgetRange struct { Min int @@ -523,8 +525,7 @@ type GoogleMaps struct { } // URLContext is a tool to support URL context retrieval. -type URLContext struct { -} +type URLContext struct{} // ToolComputerUse is a tool to support computer use. 
type ToolComputerUse struct { @@ -569,8 +570,7 @@ type ExternalAPIElasticSearchParams struct { } // ExternalAPISimpleSearchParams represents the search parameters to use for SIMPLE_SEARCH spec. -type ExternalAPISimpleSearchParams struct { -} +type ExternalAPISimpleSearchParams struct{} // ExternalAPI retrieves from data source powered by external API for grounding. The external API // is not owned by Google, but needs to follow the pre-defined API spec. @@ -728,8 +728,7 @@ type Retrieval struct { // ToolCodeExecution is a tool that executes code generated by the model, and automatically returns the result // to the model. See also [ExecutableCode]and [CodeExecutionResult] which are input // and output to this tool. -type ToolCodeExecution struct { -} +type ToolCodeExecution struct{} // Tool details of a tool that the model may use to generate a response. type Tool struct { diff --git a/core/providers/gemini/videos.go b/core/providers/gemini/videos.go index 62ce110c26..43ece90be4 100644 --- a/core/providers/gemini/videos.go +++ b/core/providers/gemini/videos.go @@ -217,7 +217,7 @@ func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe // ToBifrostVideoGenerationResponse converts Gemini operation response to Bifrost format func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if operation == nil { - return nil, providerUtils.NewBifrostOperationError("operation is nil", nil, schemas.Gemini) + return nil, providerUtils.NewBifrostOperationError("operation is nil", nil) } response := &schemas.BifrostVideoGenerationResponse{ diff --git a/core/providers/groq/groq.go b/core/providers/groq/groq.go index c1362fbece..f152a10e94 100644 --- a/core/providers/groq/groq.go +++ b/core/providers/groq/groq.go @@ -149,9 +149,6 @@ func (provider *GroqProvider) Responses(ctx *schemas.BifrostContext, key schemas } response := 
chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/groq/groq_test.go b/core/providers/groq/groq_test.go index 54b7945d76..3bf59588db 100644 --- a/core/providers/groq/groq_test.go +++ b/core/providers/groq/groq_test.go @@ -38,28 +38,28 @@ func TestGroq(t *testing.T) { TranscriptionModel: "whisper-large-v3", SpeechSynthesisModel: "canopylabs/orpheus-v1-english", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, - TextCompletionStream: false, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, + TextCompletionStream: false, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - FileBase64: false, // Not supported - FileURL: false, // Not supported - CompleteEnd2End: true, - Embedding: false, - ListModels: true, - Reasoning: true, - Transcription: true, - SpeechSynthesis: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + FileBase64: false, // Not supported + FileURL: false, // Not supported + CompleteEnd2End: true, + Embedding: false, + ListModels: true, + Reasoning: true, + Transcription: true, + SpeechSynthesis: true, }, } t.Run("GroqTests", func(t *testing.T) { diff --git a/core/providers/huggingface/errors.go b/core/providers/huggingface/errors.go index 49ce427df7..d98357e0a8 100644 --- a/core/providers/huggingface/errors.go +++ b/core/providers/huggingface/errors.go @@ -10,7 +10,7 @@ import ( ) // 
parseHuggingFaceImageError parses HuggingFace error responses -func parseHuggingFaceImageError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseHuggingFaceImageError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp HuggingFaceResponseError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -53,13 +53,5 @@ func parseHuggingFaceImageError(resp *fasthttp.Response, meta *providerUtils.Req bifrostErr.Error.Message = errorResp.Error } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } - return bifrostErr } diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index 4039f2d36b..110f5d6574 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -254,12 +254,12 @@ func (provider *HuggingFaceProvider) completeRequest(ctx *schemas.BifrostContext // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, latency, providerResponseHeaders, parseHuggingFaceImageError(resp, nil) + return nil, latency, providerResponseHeaders, parseHuggingFaceImageError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -325,7 +325,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - resultsChan <- providerResult{provider: inferProvider, err: 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)} + resultsChan <- providerResult{provider: inferProvider, err: providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)} return } @@ -384,7 +384,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext } if result.response != nil { - providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, request.Unfiltered) + providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if providerResponse != nil { aggregatedResponse.Data = append(aggregatedResponse.Data, providerResponse.Data...) totalLatency += result.latency @@ -459,10 +459,6 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ChatCompletionRequest, - }, } } if inferenceProvider != "" { @@ -483,8 +479,7 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, reqBody.Stream = schemas.Ptr(false) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -518,9 +513,6 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, bifrostResponse.Object = "chat.completion" } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -550,10 +542,6 @@ func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostCo Message: 
nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ChatCompletionStreamRequest, - }, } } if inferenceProvider != "" { @@ -610,9 +598,6 @@ func (provider *HuggingFaceProvider) Responses(ctx *schemas.BifrostContext, key } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -644,10 +629,6 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.EmbeddingRequest, - }, } } @@ -657,8 +638,7 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceEmbeddingRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -698,13 +678,10 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key // Unmarshal directly to BifrostEmbeddingResponse with custom logic bifrostResponse, convErr := UnmarshalHuggingFaceEmbeddingResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - 
bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -735,10 +712,6 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.SpeechRequest, - }, } } @@ -747,8 +720,7 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -784,18 +756,15 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch // Download the audio file from the URL audioData, downloadErr := provider.downloadAudioFromURL(ctx, response.Audio.URL) if downloadErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, downloadErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, downloadErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse, convErr := response.ToBifrostSpeechResponse(request.Model, audioData) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - 
bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.SpeechRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -838,10 +807,6 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.TranscriptionRequest, - }, } } @@ -851,7 +816,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, isHFInferenceAudioRequest := inferenceProvider == hfInference if inferenceProvider == hfInference { if request.Input == nil || len(request.Input.File) == 0 { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderCreateRequest, fmt.Errorf("input file data is required for hf-inference transcription requests"), provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderCreateRequest, fmt.Errorf("input file data is required for hf-inference transcription requests")) } jsonData = request.Input.File } else { @@ -861,8 +826,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceTranscriptionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -910,13 +874,10 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, bifrostResponse, convErr := response.ToBifrostTranscriptionResponse(request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.TranscriptionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -950,10 +911,6 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageGenerationRequest, - }, } } @@ -963,8 +920,7 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceImageGenerationRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -1004,15 +960,12 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext // Unmarshal response using Nebius converter bifrostResponse, convErr := UnmarshalHuggingFaceImageGenerationResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.Created = time.Now().Unix() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = 
request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1044,10 +997,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageGenerationStreamRequest, - }, } } @@ -1055,11 +1004,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC if inferenceProvider != falAI { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("image generation streaming is only supported for fal-ai inference provider, got: %s", inferenceProvider), - nil, - provider.GetProviderKey(), - ) + nil) } - providerName := provider.GetProviderKey() // Set headers headers := map[string]string{ @@ -1077,8 +1023,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceImageStreamRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1110,9 +1055,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC req.SetBody(jsonBody) } - // Capture start time before making the HTTP request for latency calculation - startTime := time.Now() - // Make the request err := provider.client.Do(req, resp) if err != nil { @@ -1128,9 +1070,9 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers before status check so error responses also forward them @@ -1139,11 +1081,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1167,9 +1105,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - providerName, - ) + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1190,6 +1126,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC sseReader := providerUtils.GetSSEDataReader(ctx, reader) + // Initialize latency timers post-handshake so chunk latency reflects pure streaming time. 
+ startTime := time.Now() lastChunkTime := startTime chunkIndex := 0 var lastB64Data, lastURLData, lastJsonData string @@ -1208,14 +1146,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } bifrostErr := providerUtils.NewBifrostOperationError( fmt.Sprintf("Error reading fal-ai stream: %v", readErr), - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1238,11 +1169,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Error: &schemas.ErrorField{ Message: errorResp.Message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } if errorResp.Error != "" { bifrostErr.Error.Message = errorResp.Error @@ -1268,11 +1194,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC chunk := &schemas.BifrostImageGenerationStreamResponse{ Type: schemas.ImageGenerationEventTypePartial, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1312,11 +1235,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Type: schemas.ImageGenerationEventTypeCompleted, Index: lastIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - 
ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } finalChunk.BackfillParams(&schemas.BifrostRequest{ @@ -1360,10 +1280,6 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageEditRequest, - }, } } @@ -1378,8 +1294,7 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceImageEditRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -1415,15 +1330,12 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key // Unmarshal response bifrostResponse, convErr := UnmarshalHuggingFaceImageGenerationResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.Created = time.Now().Unix() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1455,10 +1367,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: 
provider.GetProviderKey(), - RequestType: schemas.ImageEditStreamRequest, - }, } } @@ -1466,9 +1374,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext if inferenceProvider != falAI { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("image edit streaming is only supported for fal-ai inference provider, got: %s", inferenceProvider), - nil, - provider.GetProviderKey(), - ) + nil) } var authHeader map[string]string @@ -1494,15 +1400,13 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - providerName := provider.GetProviderKey() jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceImageEditRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1530,9 +1434,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext req.SetBody(jsonBody) } - // Capture start time before making the HTTP request for latency calculation - startTime := time.Now() - // Make the request err := provider.client.Do(req, resp) if err != nil { @@ -1548,9 +1449,9 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers 
before status check so error responses also forward them @@ -1559,11 +1460,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1587,9 +1484,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - providerName, - ) + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1610,6 +1505,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext sseReader := providerUtils.GetSSEDataReader(ctx, reader) + // Initialize latency timers post-handshake so chunk latency reflects pure streaming time. 
+ startTime := time.Now() lastChunkTime := startTime chunkIndex := 0 var lastB64Data, lastURLData, lastJsonData string @@ -1628,14 +1525,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } bifrostErr := providerUtils.NewBifrostOperationError( fmt.Sprintf("Error reading fal-ai stream: %v", readErr), - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1658,11 +1548,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Error: &schemas.ErrorField{ Message: errorResp.Message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }, } if errorResp.Error != "" { bifrostErr.Error.Message = errorResp.Error @@ -1688,11 +1573,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext chunk := &schemas.BifrostImageGenerationStreamResponse{ Type: schemas.ImageEditEventTypePartial, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1732,11 +1614,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Type: schemas.ImageEditEventTypeCompleted, Index: lastIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: 
time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } finalChunk.BackfillParams(&schemas.BifrostRequest{ diff --git a/core/providers/huggingface/huggingface_test.go b/core/providers/huggingface/huggingface_test.go index 4a1a91802d..4c98cec1f0 100644 --- a/core/providers/huggingface/huggingface_test.go +++ b/core/providers/huggingface/huggingface_test.go @@ -51,7 +51,7 @@ func TestHuggingface(t *testing.T) { ImageBase64: true, MultipleImages: true, CompleteEnd2End: true, - Embedding: true, + Embedding: false, Transcription: true, TranscriptionStream: false, SpeechSynthesis: true, diff --git a/core/providers/huggingface/models.go b/core/providers/huggingface/models.go index c637c3b6f6..de615ccec2 100644 --- a/core/providers/huggingface/models.go +++ b/core/providers/huggingface/models.go @@ -5,6 +5,7 @@ import ( "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -13,7 +14,7 @@ const ( maxModelFetchLimit = 1000 ) -func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -22,15 +23,20 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi Data: make([]schemas.Model, 0, len(response.Models)), } - var blacklisted map[string]struct{} - if !unfiltered && len(blacklistedModels) > 0 { - blacklisted = make(map[string]struct{}, len(blacklistedModels)) - for _, m := range 
blacklistedModels { - blacklisted[m] = struct{}{} - } + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - includedModels := make(map[string]bool) + included := make(map[string]bool) + for _, model := range response.Models { if model.ModelID == "" { continue @@ -41,39 +47,33 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi continue } - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) { - continue - } - if _, ok := blacklisted[model.ModelID]; ok { - continue - } - - newModel := schemas.Model{ - ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, model.ModelID), - Name: schemas.Ptr(model.ModelID), - SupportedMethods: supported, - HuggingFaceID: schemas.Ptr(model.ID), - } - - bifrostResponse.Data = append(bifrostResponse.Data, newModel) - includedModels[model.ModelID] = true - } - - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if _, ok := blacklisted[allowedModel]; ok { - continue + // Aliases apply at the model level (model.ModelID), not at the compound + // "{providerKey}/{inferenceProvider}/{modelID}" level. 
+ for _, result := range pipeline.FilterModel(model.ModelID) { + newModel := schemas.Model{ + // inferenceProvider stays in the compound ID; aliases rename only the model segment + ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, result.ResolvedID), + Name: schemas.Ptr(model.ModelID), + SupportedMethods: supported, + HuggingFaceID: schemas.Ptr(model.ID), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, allowedModel), - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + newModel.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, newModel) + included[strings.ToLower(result.ResolvedID)] = true } } + // Backfill via the standard pipeline. The pipeline has no notion of inferenceProvider, so it + emits two-segment "{providerKey}/{modelID}" IDs; re-wrap each one with the current inferenceProvider. + for _, m := range pipeline.BackfillModels(included) { + // Re-wrap the backfill ID to include the inferenceProvider segment + rawID := strings.TrimPrefix(m.ID, string(providerKey)+"/") + m.ID = fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, rawID) + bifrostResponse.Data = append(bifrostResponse.Data, m) + } + return bifrostResponse } diff --git a/core/providers/huggingface/responses.go b/core/providers/huggingface/responses.go index fd68aa76a8..35ad2c336d 100644 --- a/core/providers/huggingface/responses.go +++ b/core/providers/huggingface/responses.go @@ -43,9 +43,6 @@ func ToBifrostResponsesResponseFromHuggingFace(resp *schemas.BifrostChatResponse responsesResp := resp.ToBifrostResponsesResponse() if responsesResp != nil { - responsesResp.ExtraFields.Provider = schemas.HuggingFace - responsesResp.ExtraFields.ModelRequested = requestedModel - responsesResp.ExtraFields.RequestType = schemas.ResponsesRequest } return responsesResp, nil diff --git a/core/providers/huggingface/speech.go
b/core/providers/huggingface/speech.go index 65c0ba6e12..f702d1f39f 100644 --- a/core/providers/huggingface/speech.go +++ b/core/providers/huggingface/speech.go @@ -125,10 +125,6 @@ func (response *HuggingFaceSpeechResponse) ToBifrostSpeechResponse(requestedMode // Create the base Bifrost response with the downloaded audio data bifrostResponse := &schemas.BifrostSpeechResponse{ Audio: audioData, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.HuggingFace, - ModelRequested: requestedModel, - }, } // Note: HuggingFace TTS API typically doesn't return usage information diff --git a/core/providers/huggingface/transcription.go b/core/providers/huggingface/transcription.go index 0d892cb07c..f3ff5c293a 100644 --- a/core/providers/huggingface/transcription.go +++ b/core/providers/huggingface/transcription.go @@ -144,10 +144,6 @@ func (response *HuggingFaceTranscriptionResponse) ToBifrostTranscriptionResponse // Create the base Bifrost response bifrostResponse := &schemas.BifrostTranscriptionResponse{ Text: response.Text, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.HuggingFace, - ModelRequested: requestedModel, - }, } // Map chunks to segments if available diff --git a/core/providers/huggingface/utils.go b/core/providers/huggingface/utils.go index ad68e4c6ac..b96210c832 100644 --- a/core/providers/huggingface/utils.go +++ b/core/providers/huggingface/utils.go @@ -221,8 +221,6 @@ func convertToInferenceProviderMappings(resp *HuggingFaceInferenceProviderMappin } func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx context.Context, huggingfaceModelName string) (map[inferenceProvider]HuggingFaceInferenceProviderMapping, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Check cache first if cached, ok := provider.modelProviderMappingCache.Load(huggingfaceModelName); ok { if mappings, ok := cached.(map[inferenceProvider]HuggingFaceInferenceProviderMapping); ok { @@ -259,12 +257,12 @@ 
func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx contex body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var mappingResp HuggingFaceInferenceProviderMappingResponse if err := sonic.Unmarshal(body, &mappingResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } mappings := convertToInferenceProviderMappings(&mappingResp) diff --git a/core/providers/mistral/custom_provider_test.go b/core/providers/mistral/custom_provider_test.go index 0d8e283f51..9015230544 100644 --- a/core/providers/mistral/custom_provider_test.go +++ b/core/providers/mistral/custom_provider_test.go @@ -25,7 +25,7 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { resp.SetStatusCode(http.StatusBadRequest) resp.SetBodyString(`{"message":"invalid request","type":"invalid_request_error","code":"bad_request"}`) - bifrostErr := ParseMistralError(resp, schemas.OCRRequest, customMistralProviderName, "mistral-ocr-latest") + bifrostErr := ParseMistralError(resp) require.NotNil(t, bifrostErr) require.NotNil(t, bifrostErr.Error) @@ -33,8 +33,6 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { assert.Equal(t, schemas.Ptr("invalid_request_error"), bifrostErr.Error.Type) assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code) assert.Equal(t, customMistralProviderName, bifrostErr.ExtraFields.Provider) - assert.Equal(t, schemas.OCRRequest, bifrostErr.ExtraFields.RequestType) - assert.Equal(t, "mistral-ocr-latest", bifrostErr.ExtraFields.ModelRequested) } func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t 
*testing.T) { @@ -156,5 +154,5 @@ func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T) require.NotNil(t, response) assert.Equal(t, customMistralProviderName, response.ExtraFields.Provider) - assert.Equal(t, "codestral-embed", response.ExtraFields.ModelRequested) + assert.Equal(t, "codestral-embed", response.ExtraFields.ResolvedModelUsed) } diff --git a/core/providers/mistral/errors.go b/core/providers/mistral/errors.go index 2ae260eb9a..cbbd40a560 100644 --- a/core/providers/mistral/errors.go +++ b/core/providers/mistral/errors.go @@ -19,7 +19,7 @@ type MistralErrorResponse struct { } // ParseMistralError parses Mistral-specific error responses. -func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseMistralError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp MistralErrorResponse bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if bifrostErr == nil { @@ -67,9 +67,5 @@ func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, } } - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - return bifrostErr } diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index ec699ddcc2..1999cbb5fb 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -76,8 +76,6 @@ func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. 
func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -103,7 +101,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := ParseMistralError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := ParseMistralError(resp) return nil, bifrostErr } @@ -118,7 +116,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Create final response - response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels) + response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -228,9 +226,6 @@ func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key sche } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -278,25 +273,23 @@ func (provider *MistralProvider) Rerank(ctx *schemas.BifrostContext, key schemas // OCR performs an OCR request to the Mistral API. // It sends a JSON request to Mistral's OCR endpoint and returns the extracted content. 
func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Convert Bifrost request to Mistral format mistralReq := ToMistralOCRRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil) } // Marshal request body requestBody, err := sonic.Marshal(mistralReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Merge extra params into JSON payload if len(mistralReq.ExtraParams) > 0 { requestBody, err = providerUtils.MergeExtraParamsIntoJSON(requestBody, mistralReq.ExtraParams) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } @@ -332,13 +325,12 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.OCRRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -366,20 +358,17 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, 
key schemas.Ke }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format response := mistralResponse.ToBifrostOCRResponse() if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil) } // Set extra fields response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.RequestType = schemas.OCRRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -401,16 +390,14 @@ func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postH // It creates a multipart form with the audio file and sends it to Mistral's transcription endpoint. // Returns the transcribed text and metadata, or an error if the request fails. 
func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Convert Bifrost request to Mistral format mistralReq := ToMistralTranscriptionRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } // Create multipart form body - body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, providerName) + body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, provider.GetProviderKey()) if bifrostErr != nil { return nil, bifrostErr } @@ -442,13 +429,12 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -476,20 +462,17 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format response := mistralResponse.ToBifrostTranscriptionResponse() if response == 
nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil) } // Set extra fields response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.RequestType = schemas.TranscriptionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -511,7 +494,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Convert Bifrost request to Mistral format mistralReq := ToMistralTranscriptionRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } mistralReq.Stream = schemas.Ptr(true) @@ -566,9 +549,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -577,8 +560,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - provider.logger.Debug("error from %s 
provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -597,9 +579,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -639,7 +621,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) } break } @@ -687,11 +669,6 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) 
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) return @@ -720,11 +697,8 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( // Set extra fields response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(*lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(*lastChunkTime).Milliseconds(), } *lastChunkTime = time.Now() diff --git a/core/providers/mistral/models.go b/core/providers/mistral/models.go index ef3e5934c1..8d5fd7f3d6 100644 --- a/core/providers/mistral/models.go +++ b/core/providers/mistral/models.go @@ -1,12 +1,13 @@ package mistral import ( - "slices" + "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, blacklistedModels []string) *schemas.BifrostListModelsResponse { +func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,40 +16,40 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo Data: make([]schemas.Model, 0, len(response.Data)), } - includedModels := make(map[string]bool) - for _, model := range response.Data { - if len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { - continue - } - if slices.Contains(blacklistedModels, model.ID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Mistral) + "/" + model.ID, - Name: schemas.Ptr(model.Name), - Description: schemas.Ptr(model.Description), - Created: schemas.Ptr(model.Created), - 
ContextLength: schemas.Ptr(int(model.MaxContextLength)), - OwnedBy: schemas.Ptr(model.OwnedBy), - }) - includedModels[model.ID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Mistral, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(schemas.Mistral) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), + Description: schemas.Ptr(model.Description), + Created: schemas.Ptr(model.Created), + ContextLength: schemas.Ptr(int(model.MaxContextLength)), + OwnedBy: schemas.Ptr(model.OwnedBy), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Mistral) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - includedModels[allowedModel] = true + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
+ return bifrostResponse } diff --git a/core/providers/mistral/ocr_test.go b/core/providers/mistral/ocr_test.go index c88bca5f48..ccf7223e4c 100644 --- a/core/providers/mistral/ocr_test.go +++ b/core/providers/mistral/ocr_test.go @@ -436,9 +436,6 @@ func TestOCRWithMockServer(t *testing.T) { assert.Equal(t, 1, resp.Pages[1].Index) require.NotNil(t, resp.UsageInfo) assert.Equal(t, 2, resp.UsageInfo.PagesProcessed) - assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) - assert.Equal(t, "mistral-ocr-latest", resp.ExtraFields.ModelRequested) }, }, { @@ -503,9 +500,6 @@ func TestOCRWithMockServer(t *testing.T) { assert.Equal(t, "server_error", *err.Error.Type) require.NotNil(t, err.Error.Code) assert.Equal(t, "internal_error", *err.Error.Code) - assert.Equal(t, schemas.Mistral, err.ExtraFields.Provider) - assert.Equal(t, schemas.OCRRequest, err.ExtraFields.RequestType) - assert.Equal(t, "mistral-ocr-latest", err.ExtraFields.ModelRequested) }, }, { @@ -757,7 +751,5 @@ func TestMistralOCRIntegration(t *testing.T) { require.NotEmpty(t, resp.Pages, "Expected at least one page") assert.Equal(t, 0, resp.Pages[0].Index) assert.NotEmpty(t, resp.Pages[0].Markdown, "Expected non-empty markdown for page 0") - assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) assert.Greater(t, resp.ExtraFields.Latency, int64(0)) } diff --git a/core/providers/mistral/transcription.go b/core/providers/mistral/transcription.go index a4a018e5c6..fe9b262126 100644 --- a/core/providers/mistral/transcription.go +++ b/core/providers/mistral/transcription.go @@ -109,58 +109,58 @@ func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, req *Mi } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return 
providerUtils.NewBifrostOperationError("failed to create form file", err)
 	}
 	if _, err := fileWriter.Write(req.File); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to write file data", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to write file data", err)
 	}
 
 	// Add model field (required)
 	if err := writer.WriteField("model", req.Model); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to write model field", err)
 	}
 
 	// Add stream field if streaming
 	if req.Stream != nil && *req.Stream {
 		if err := writer.WriteField("stream", "true"); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write stream field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write stream field", err)
 		}
 	}
 
 	// Add optional fields
 	if req.Language != nil {
 		if err := writer.WriteField("language", *req.Language); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write language field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write language field", err)
 		}
 	}
 	if req.Prompt != nil {
 		if err := writer.WriteField("prompt", *req.Prompt); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write prompt field", err)
 		}
 	}
 	if req.ResponseFormat != nil {
 		if err := writer.WriteField("response_format", *req.ResponseFormat); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write response_format field", err)
 		}
 	}
 	if req.Temperature != nil {
 		if err := writer.WriteField("temperature", formatFloat64(*req.Temperature)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write temperature field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write temperature field", err)
 		}
 	}
 	for _, granularity := range req.TimestampGranularities {
 		if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err)
 		}
 	}
 
 	// Close the multipart writer to finalize the form
 	if err := writer.Close(); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
 	}
 
 	return nil
diff --git a/core/providers/nebius/errors.go b/core/providers/nebius/errors.go
index 98d0fb78d8..de8bcf0d84 100644
--- a/core/providers/nebius/errors.go
+++ b/core/providers/nebius/errors.go
@@ -9,7 +9,7 @@ import (
 )
 
 // parseNebiusImageError parses Nebius error responses
-func parseNebiusImageError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
+func parseNebiusImageError(resp *fasthttp.Response) *schemas.BifrostError {
 	var nebiusErr NebiusError
 
 	bifrostErr := providerUtils.HandleProviderAPIError(resp, &nebiusErr)
@@ -60,13 +60,5 @@ func parseNebiusImageError(resp *fasthttp.Response, meta *providerUtils.RequestM
 		bifrostErr.Error.Message = message
 	}
 
-	if meta != nil {
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			Provider:       meta.Provider,
-			ModelRequested: meta.Model,
-			RequestType:    meta.RequestType,
-		}
-	}
-
 	return bifrostErr
 }
diff --git a/core/providers/nebius/nebius.go b/core/providers/nebius/nebius.go
index 998588ce71..eac617df6e 100644
--- a/core/providers/nebius/nebius.go
+++ b/core/providers/nebius/nebius.go
@@ -193,9 +193,6 @@ func (provider *NebiusProvider) Responses(ctx *schemas.BifrostContext, key schem
 	}
 
 	response := chatResponse.ToBifrostResponsesResponse()
-	response.ExtraFields.RequestType = schemas.ResponsesRequest
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
 
 	return response, nil
 }
@@ -265,16 +262,15 @@ func (provider *NebiusProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
 	// Validate request is not nil
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("image generation request is nil", nil, provider.GetProviderKey())
+		return nil, providerUtils.NewBifrostOperationError("image generation request is nil", nil)
 	}
 
 	// Validate input and prompt are not nil/empty
 	if request.Input == nil || strings.TrimSpace(request.Input.Prompt) == "" {
-		return nil, providerUtils.NewBifrostOperationError("prompt cannot be empty", nil, provider.GetProviderKey())
+		return nil, providerUtils.NewBifrostOperationError("prompt cannot be empty", nil)
 	}
 
 	path := providerUtils.GetPathFromContext(ctx, "/v1/images/generations")
-	providerName := schemas.Nebius
 
 	// Create request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -309,8 +305,7 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return provider.ToNebiusImageGenerationRequest(request)
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -328,16 +323,12 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseNebiusImageError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.ImageGenerationRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseNebiusImageError(resp), jsonData, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	response := &schemas.BifrostImageGenerationResponse{}
@@ -357,9 +348,6 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 		return nil, bifrostErr
 	}
 
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageGenerationRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 
 	// Set raw request if enabled
diff --git a/core/providers/nebius/nebius_test.go b/core/providers/nebius/nebius_test.go
index 898617cb1e..da6c4065e8 100644
--- a/core/providers/nebius/nebius_test.go
+++ b/core/providers/nebius/nebius_test.go
@@ -32,25 +32,25 @@ func TestNebius(t *testing.T) {
 		EmbeddingModel:       "BAAI/bge-en-icl",
 		ImageGenerationModel: "black-forest-labs/flux-schnell",
 		Scenarios: llmtests.TestScenarios{
-			TextCompletion:        true,
-			TextCompletionStream:  true,
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
-			ToolCallsStreaming:    true,
+			TextCompletion:             true,
+			TextCompletionStream:       true,
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			ImageGeneration:       true,
-			CompleteEnd2End:       true,
-			ImageGenerationStream: false,
-			Embedding:             true, // Nebius supports embeddings
-			ListModels:            true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			ImageGeneration:            true,
+			CompleteEnd2End:            true,
+			ImageGenerationStream:      false,
+			Embedding:                  true, // Nebius supports embeddings
+			ListModels:                 true,
 		},
 	}
diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go
index 6710169160..b84d3f7c9c 100644
--- a/core/providers/ollama/ollama.go
+++ b/core/providers/ollama/ollama.go
@@ -3,7 +3,6 @@ package ollama
 
 import (
-	"fmt"
 	"strings"
 	"time"
 
@@ -50,11 +49,7 @@ func NewOllamaProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*
 	client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
 	config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
 
-	// BaseURL is required for Ollama
-	if config.NetworkConfig.BaseURL == "" {
-		return nil, fmt.Errorf("base_url is required for ollama provider")
-	}
-
+	// BaseURL is optional when keys have ollama_key_config with per-key URLs
 	return &OllamaProvider{
 		logger: logger,
 		client: client,
@@ -69,17 +64,14 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider {
 	return schemas.Ollama
 }
 
-// ListModels performs a list models request to Ollama's API.
-func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
-	if provider.networkConfig.BaseURL == "" {
-		return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey())
-	}
-	return openai.HandleOpenAIListModelsRequest(
+// listModelsByKey performs a list models request for a single Ollama key.
+func (provider *OllamaProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
+	return openai.ListModelsByKey(
 		ctx,
 		provider.client,
-		request,
-		provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"),
-		keys,
+		key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/models"),
+		key,
+		request.Unfiltered,
 		provider.networkConfig.ExtraHeaders,
 		provider.GetProviderKey(),
 		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
@@ -87,12 +79,24 @@ func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []s
 	)
 }
 
+// ListModels performs a list models request to Ollama's API.
+// Requests are made concurrently per key so that each backend is queried
+// with its own URL (from ollama_key_config).
+func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
+	return providerUtils.HandleMultipleListModelsRequests(
+		ctx,
+		keys,
+		request,
+		provider.listModelsByKey,
+	)
+}
+
 // TextCompletion performs a text completion request to the Ollama API.
 func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
 	return openai.HandleOpenAITextCompletionRequest(
 		ctx,
 		provider.client,
-		provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
+		key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
 		request,
 		key,
 		provider.networkConfig.ExtraHeaders,
@@ -112,7 +116,7 @@ func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext
 	return openai.HandleOpenAITextCompletionStreaming(
 		ctx,
 		provider.client,
-		provider.networkConfig.BaseURL+"/v1/completions",
+		key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
 		request,
 		nil,
 		provider.networkConfig.ExtraHeaders,
@@ -132,7 +136,7 @@ func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 	return openai.HandleOpenAIChatCompletionRequest(
 		ctx,
 		provider.client,
-		provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
+		key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
 		request,
 		key,
 		provider.networkConfig.ExtraHeaders,
@@ -154,7 +158,7 @@ func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext
 	return openai.HandleOpenAIChatCompletionStreaming(
 		ctx,
 		provider.client,
-		provider.networkConfig.BaseURL+"/v1/chat/completions",
+		key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
 		request,
 		nil,
 		provider.networkConfig.ExtraHeaders,
@@ -179,9 +183,6 @@ func (provider *OllamaProvider) Responses(ctx *schemas.BifrostContext, key schem
 	}
 
 	response := chatResponse.ToBifrostResponsesResponse()
-	response.ExtraFields.RequestType = schemas.ResponsesRequest
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
 
 	return response, nil
 }
@@ -202,7 +203,7 @@ func (provider *OllamaProvider) Embedding(ctx *schemas.BifrostContext, key schem
 	return openai.HandleOpenAIEmbeddingRequest(
 		ctx,
 		provider.client,
-		provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
+		key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
 		request,
 		key,
 		provider.networkConfig.ExtraHeaders,
diff --git a/core/providers/ollama/ollama_test.go b/core/providers/ollama/ollama_test.go
index ad31297046..a9005c1220 100644
--- a/core/providers/ollama/ollama_test.go
+++ b/core/providers/ollama/ollama_test.go
@@ -29,24 +29,24 @@ func TestOllama(t *testing.T) {
 		TextModel:      "", // Ollama doesn't support text completion in newer models
 		EmbeddingModel: "", // Ollama doesn't support embedding
 		Scenarios: llmtests.TestScenarios{
-			TextCompletion:        false, // Not supported
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
-			ToolCallsStreaming:    true,
+			TextCompletion:             false, // Not supported
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              false,
-			ImageBase64:           false,
-			MultipleImages:        false,
-			FileBase64:            false,
-			FileURL:               false,
-			CompleteEnd2End:       true,
-			Embedding:             false,
-			ListModels:            true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   false,
+			ImageBase64:                false,
+			MultipleImages:             false,
+			FileBase64:                 false,
+			FileURL:                    false,
+			CompleteEnd2End:            true,
+			Embedding:                  false,
+			ListModels:                 true,
 		},
 	}
diff --git a/core/providers/openai/batch.go b/core/providers/openai/batch.go
index ae095e5c77..ec8ce468bb 100644
--- a/core/providers/openai/batch.go
+++ b/core/providers/openai/batch.go
@@ -10,10 +10,10 @@ import (
 
 // OpenAIBatchRequest represents the request body for creating a batch.
 type OpenAIBatchRequest struct {
-	InputFileID      string            `json:"input_file_id"`
-	Endpoint         string            `json:"endpoint"`
-	CompletionWindow string            `json:"completion_window"`
-	Metadata         map[string]string `json:"metadata,omitempty"`
+	InputFileID        string                     `json:"input_file_id"`
+	Endpoint           string                     `json:"endpoint"`
+	CompletionWindow   string                     `json:"completion_window"`
+	Metadata           map[string]string          `json:"metadata,omitempty"`
 	OutputExpiresAfter *schemas.BatchExpiresAfter `json:"output_expires_after,omitempty"`
 }
 
@@ -82,7 +82,7 @@ func ToBifrostBatchStatus(status string) schemas.BatchStatus {
 }
 
 // ToBifrostBatchCreateResponse converts OpenAI batch response to Bifrost batch response.
-func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse {
+func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse {
 	resp := &schemas.BifrostBatchCreateResponse{
 		ID:     r.ID,
 		Object: r.Object,
@@ -95,9 +95,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.
 		OutputFileID: r.OutputFileID,
 		ErrorFileID:  r.ErrorFileID,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchCreateRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 
@@ -125,7 +123,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.
 }
 
 // ToBifrostBatchRetrieveResponse converts OpenAI batch response to Bifrost batch retrieve response.
-func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse {
+func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse {
 	resp := &schemas.BifrostBatchRetrieveResponse{
 		ID:     r.ID,
 		Object: r.Object,
@@ -146,9 +144,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schema
 		ErrorFileID:  r.ErrorFileID,
 		Errors:       r.Errors,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchRetrieveRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 
@@ -174,35 +170,3 @@ func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schema
 
 	return resp
 }
-
-// splitJSONL splits JSONL content into individual lines.
-func splitJSONL(data []byte) [][]byte {
-	var lines [][]byte
-	start := 0
-	for i, b := range data {
-		if b == '\n' {
-			if i > start {
-				end := i
-				// Strip trailing \r if present (handle CRLF)
-				if end > start && data[end-1] == '\r' {
-					end--
-				}
-				if end > start {
-					lines = append(lines, data[start:end])
-				}
-			}
-			start = i + 1
-		}
-	}
-	if start < len(data) {
-		end := len(data)
-		// Strip trailing \r if present
-		if end > start && data[end-1] == '\r' {
-			end--
-		}
-		if end > start {
-			lines = append(lines, data[start:end])
-		}
-	}
-	return lines
-}
diff --git a/core/providers/openai/chat_test.go b/core/providers/openai/chat_test.go
index f7f0e15e95..724c438d91 100644
--- a/core/providers/openai/chat_test.go
+++ b/core/providers/openai/chat_test.go
@@ -305,7 +305,6 @@ func TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation(t *tes
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToOpenAIChatRequest(ctx, bifrostReq), nil
 		},
-		schemas.Fireworks,
 	)
 	if bifrostErr != nil {
 		t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message)
diff --git a/core/providers/openai/errors.go b/core/providers/openai/errors.go
index 6a5bc1ce08..69d0aff407 100644
--- a/core/providers/openai/errors.go
+++ b/core/providers/openai/errors.go
@@ -10,10 +10,10 @@ import (
 )
 
 // ErrorConverter is a function that converts provider-specific error responses to BifrostError.
-type ErrorConverter func(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError
+type ErrorConverter func(resp *fasthttp.Response) *schemas.BifrostError
 
 // ParseOpenAIError parses OpenAI error responses.
-func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError {
+func ParseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError {
 	var errorResp schemas.BifrostError
 
 	bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
@@ -49,11 +49,6 @@ func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType,
 	}
 
 	// Set ExtraFields unconditionally so provider/model/request metadata is always attached
-	if bifrostErr != nil {
-		bifrostErr.ExtraFields.Provider = providerName
-		bifrostErr.ExtraFields.ModelRequested = model
-		bifrostErr.ExtraFields.RequestType = requestType
-	}
 	return bifrostErr
 }
diff --git a/core/providers/openai/errors_test.go b/core/providers/openai/errors_test.go
index f33008600b..1132a92723 100644
--- a/core/providers/openai/errors_test.go
+++ b/core/providers/openai/errors_test.go
@@ -12,7 +12,7 @@ func TestParseOpenAIError_FallbackMessageWhenProviderBodyIsNonOpenAIShape(t *tes
 	resp.SetStatusCode(fasthttp.StatusUnprocessableEntity)
 	resp.SetBodyString(`{"detail":[{"loc":["body","messages",0,"role"],"msg":"value is not a valid enumeration member"}]}`)
 
-	errResp := ParseOpenAIError(&resp, schemas.ResponsesStreamRequest, schemas.Cerebras, "llama3.1-8b")
+	errResp := ParseOpenAIError(&resp)
 	if errResp == nil || errResp.Error == nil {
 		t.Fatal("expected non-nil error response")
 	}
@@ -29,7 +29,7 @@ func TestParseOpenAIError_PreservesProviderMessageWhenPresent(t *testing.T) {
 	resp.SetStatusCode(fasthttp.StatusUnprocessableEntity)
 	resp.SetBodyString(`{"error":{"message":"unsupported role: developer","type":"invalid_request_error","param":"messages.0.role","code":"invalid_value"}}`)
 
-	errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o")
+	errResp := ParseOpenAIError(&resp)
 	if errResp == nil || errResp.Error == nil {
 		t.Fatal("expected non-nil error response")
 	}
@@ -43,7 +43,7 @@ func TestParseOpenAIError_FallbackMessageWhenBodyIsEmpty(t *testing.T) {
 	resp.SetStatusCode(fasthttp.StatusBadRequest)
 	resp.SetBody(nil)
 
-	errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o")
+	errResp := ParseOpenAIError(&resp)
 	if errResp == nil || errResp.Error == nil {
 		t.Fatal("expected non-nil error response")
 	}
@@ -59,7 +59,7 @@ func TestParseOpenAIError_WhitespaceProviderMessageFallsBack(t *testing.T) {
 	resp.SetStatusCode(fasthttp.StatusBadRequest)
 	resp.SetBodyString(`{"error":{"message":" ","type":"invalid_request_error"}}`)
 
-	errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o")
+	errResp := ParseOpenAIError(&resp)
 	if errResp == nil || errResp.Error == nil {
 		t.Fatal("expected non-nil error response")
 	}
@@ -73,7 +73,7 @@ func TestParseOpenAIError_DefaultStatusCodeFallsBackWithStatusNumber(t *testing.
 	// fasthttp defaults zero-value response status code to 200.
 	resp.SetBodyString(`{"error":{"message":""}}`)
 
-	errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o")
+	errResp := ParseOpenAIError(&resp)
 	if errResp == nil || errResp.Error == nil {
 		t.Fatal("expected non-nil error response")
 	}
diff --git a/core/providers/openai/files.go b/core/providers/openai/files.go
index bbaf2b2f70..133250cac7 100644
--- a/core/providers/openai/files.go
+++ b/core/providers/openai/files.go
@@ -55,7 +55,7 @@ func ToBifrostFileStatus(status string) schemas.FileStatus {
 }
 
 // ToBifrostFileUploadResponse converts OpenAI file response to Bifrost file upload response.
-func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse {
+func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse {
 	resp := &schemas.BifrostFileUploadResponse{
 		ID:     r.ID,
 		Object: r.Object,
@@ -67,9 +67,7 @@ func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(providerName schemas.Mo
 		StatusDetails:  r.StatusDetails,
 		StorageBackend: schemas.FileStorageAPI,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.FileUploadRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 
@@ -97,9 +95,7 @@ func (r *OpenAIFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.
 		StatusDetails:  r.StatusDetails,
 		StorageBackend: schemas.FileStorageAPI,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.FileRetrieveRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
diff --git a/core/providers/openai/images.go b/core/providers/openai/images.go
index f183e17ec5..9176f1e1e7 100644
--- a/core/providers/openai/images.go
+++ b/core/providers/openai/images.go
@@ -125,18 +125,18 @@ func ToOpenAIImageEditRequest(bifrostReq *schemas.BifrostImageEditRequest) *Open
 func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageEditRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
 	// Add model field (required)
 	if err := writer.WriteField("model", openaiReq.Model); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to write model field", err)
 	}
 
 	// Add prompt field (required)
 	if err := writer.WriteField("prompt", openaiReq.Input.Prompt); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to write prompt field", err)
 	}
 
 	// Add stream field when requesting streaming
 	if openaiReq.Stream != nil && *openaiReq.Stream {
 		if err := writer.WriteField("stream", "true"); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write stream field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write stream field", err)
 		}
 	}
 
@@ -168,71 +168,71 @@ func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *
 			"Content-Type": {mimeType},
 		})
 		if err != nil {
-			return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create form part for image %d", i), err, providerName)
+			return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create form part for image %d", i), err)
 		}
 		if _, err := part.Write(imageInput.Image); err != nil {
-			return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err, providerName)
+			return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err)
 		}
 	}
 
 	// Add optional parameters
 	if openaiReq.N != nil {
 		if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write n field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write n field", err)
 		}
 	}
 	if openaiReq.Size != nil {
 		if err := writer.WriteField("size", *openaiReq.Size); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write size field", err)
 		}
 	}
 	if openaiReq.ResponseFormat != nil {
 		if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write response_format field", err)
 		}
 	}
 	if openaiReq.Quality != nil {
 		if err := writer.WriteField("quality", *openaiReq.Quality); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write quality field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write quality field", err)
 		}
 	}
 	if openaiReq.Background != nil {
 		if err := writer.WriteField("background", *openaiReq.Background); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write background field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write background field", err)
 		}
 	}
 	if openaiReq.InputFidelity != nil {
 		if err := writer.WriteField("input_fidelity", *openaiReq.InputFidelity); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err)
 		}
 	}
 	if openaiReq.PartialImages != nil {
 		if err := writer.WriteField("partial_images", strconv.Itoa(*openaiReq.PartialImages)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write partial_images field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write partial_images field", err)
 		}
 	}
 	if openaiReq.OutputFormat != nil {
 		if err := writer.WriteField("output_format", *openaiReq.OutputFormat); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write output_format field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write output_format field", err)
 		}
 	}
 	if openaiReq.OutputCompression != nil {
 		if err := writer.WriteField("output_compression", strconv.Itoa(*openaiReq.OutputCompression)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write output_compression field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write output_compression field", err)
 		}
 	}
 	if openaiReq.User != nil {
 		if err := writer.WriteField("user", *openaiReq.User); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write user field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write user field", err)
 		}
 	}
 
@@ -260,16 +260,16 @@ func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *
 			"Content-Type": {maskMimeType},
 		})
 		if err != nil {
-			return providerUtils.NewBifrostOperationError("failed to create mask form part", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to create mask form part", err)
 		}
 		if _, err := maskPart.Write(openaiReq.Mask); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write mask data", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write mask data", err)
 		}
 	}
 
 	// Close the multipart writer
 	if err := writer.Close(); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
 	}
 
 	return nil
@@ -299,12 +299,12 @@ func ToOpenAIImageVariationRequest(bifrostReq *schemas.BifrostImageVariationRequ
 func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageVariationRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
 	// Add model field (required)
 	if err := writer.WriteField("model", openaiReq.Model); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to write model field", err)
 	}
 
 	// Add image file (required)
 	if openaiReq.Input == nil || openaiReq.Input.Image.Image == nil || len(openaiReq.Input.Image.Image) == 0 {
-		return providerUtils.NewBifrostOperationError("image is required", nil, providerName)
+		return providerUtils.NewBifrostOperationError("image is required", nil)
 	}
 
 	// Detect MIME type
@@ -320,41 +320,41 @@ func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openai
 		"Content-Type": {mimeType},
 	})
 	if err != nil {
-		return providerUtils.NewBifrostOperationError("failed to create image part", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to create image part", err)
 	}
 	if _, err := part.Write(openaiReq.Input.Image.Image); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to write image data", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to write image data", err)
 	}
 
 	// Add optional parameters
 	if openaiReq.N != nil {
 		if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write n field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write n field", err)
 		}
 	}
 	if openaiReq.ResponseFormat != nil {
 		if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write response_format field", err)
 		}
 	}
 	if openaiReq.Size != nil {
 		if err := writer.WriteField("size", *openaiReq.Size); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write size field", err)
 		}
 	}
 	if openaiReq.User != nil {
 		if err := writer.WriteField("user", *openaiReq.User); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write user field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write user field", err)
 		}
 	}
 
 	// Close the multipart writer
 	if err := writer.Close(); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
 	}
 
 	return nil
diff --git a/core/providers/openai/large_payload.go b/core/providers/openai/large_payload.go
index 461f3417de..fe3aaf1812 100644
--- a/core/providers/openai/large_payload.go
+++ b/core/providers/openai/large_payload.go
@@ -42,8 +42,6 @@ func handleOpenAILargePayloadPassthrough(
 	key schemas.Key,
 	extraHeaders map[string]string,
 	providerName schemas.ModelProvider,
-	model string,
-	requestType schemas.RequestType,
 	logger schemas.Logger,
 ) (*largePayloadResult, *schemas.BifrostError, bool) {
 	isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
@@ -91,7 +89,7 @@ func handleOpenAILargePayloadPassthrough(
 	// Error responses are always small — materialize stream body for error parsing
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		parsedErr := ParseOpenAIError(resp, requestType, providerName, model)
+		parsedErr := ParseOpenAIError(resp)
 		fasthttp.ReleaseResponse(resp)
 		return nil, parsedErr, true
 	}
@@ -126,7 +124,7 @@ func finalizeOpenAIResponse(
 	providerName schemas.ModelProvider,
 	logger schemas.Logger,
 ) ([]byte, *largePayloadResult, *schemas.BifrostError) {
-	body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, logger)
+	body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, logger)
 	if bifrostErr != nil {
 		fasthttp.ReleaseResponse(resp)
 		return nil, nil, bifrostErr
diff --git a/core/providers/openai/models.go b/core/providers/openai/models.go
index d00a8af112..a76d350d28 100644
--- a/core/providers/openai/models.go
+++ b/core/providers/openai/models.go
@@ -1,13 +1,14 @@
 package openai
 
 import (
-	"slices"
+	"strings"
 
+	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
 	"github.com/maximhq/bifrost/core/schemas"
 )
 
 // ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response
-func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
+func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -16,38 +17,39 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe
 		Data: make([]schemas.Model, 0, len(response.Data)),
 	}
 
-	includedModels := make(map[string]bool)
-	for _, model := range response.Data {
-		if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) {
-			continue
-		}
-		if !unfiltered && slices.Contains(blacklistedModels, model.ID) {
-			continue
-		}
-		bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-			ID:            string(providerKey) + "/" + model.ID,
-			Created:       model.Created,
-			OwnedBy:       schemas.Ptr(model.OwnedBy),
-			ContextLength: model.ContextWindow,
-		})
-		includedModels[model.ID] = true
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       providerKey,
+		MatchFns:          providerUtils.DefaultMatchFns(),
+	}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
 	}
 
-	// Backfill allowed models that were not in the response
-	if !unfiltered && len(allowedModels) > 0 {
-		for _, allowedModel := range allowedModels {
-			if slices.Contains(blacklistedModels, allowedModel) {
-				continue
+	included := make(map[string]bool)
+
+	for _, model := range response.Data {
+		for _, result := range pipeline.FilterModel(model.ID) {
+			entry := schemas.Model{
+				ID:            string(providerKey) + "/" + result.ResolvedID,
+				Created:       model.Created,
+				OwnedBy:       schemas.Ptr(model.OwnedBy),
+				ContextLength: model.ContextWindow,
 			}
-			if !includedModels[allowedModel] {
-				bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-					ID:   string(providerKey) + "/" + allowedModel,
-					Name: schemas.Ptr(allowedModel),
-				})
+			if result.AliasValue != "" {
+				entry.Alias = schemas.Ptr(result.AliasValue)
 			}
+			bifrostResponse.Data = append(bifrostResponse.Data, entry)
+			included[strings.ToLower(result.ResolvedID)] = true
 		}
 	}
 
+	bifrostResponse.Data = append(bifrostResponse.Data, pipeline.BackfillModels(included)...)
+
 	return bifrostResponse
 }
diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go
index 810ee583d7..a4e06dac47 100644
--- a/core/providers/openai/openai.go
+++ b/core/providers/openai/openai.go
@@ -166,7 +166,7 @@ func ListModelsByKey(
 
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		bifrostErr := ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "")
+		bifrostErr := ParseOpenAIError(resp)
 		return nil, bifrostErr
 	}
 
@@ -181,10 +181,8 @@ func ListModelsByKey(
 		return nil, bifrostErr
 	}
 
-	response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, unfiltered)
+	response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, unfiltered)
 
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.RequestType = schemas.ListModelsRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
@@ -289,22 +287,22 @@ func HandleOpenAITextCompletionRequest(
 	}
 
 	// Large payload passthrough: stream body directly without JSON marshaling
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.TextCompletionRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
 		if len(lpResult.ResponseBody) > 0 {
 			response := &schemas.BifrostTextCompletionResponse{}
 			if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 			}
-			response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested:
request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostTextCompletionResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -313,8 +311,7 @@ func HandleOpenAITextCompletionRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAITextCompletionRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -335,9 +332,9 @@ func HandleOpenAITextCompletionRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.TextCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.TextCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -349,7 +346,7 @@ func HandleOpenAITextCompletionRequest( return &schemas.BifrostTextCompletionResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: 
schemas.TextCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -367,9 +364,6 @@ func HandleOpenAITextCompletionRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.TextCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -455,8 +449,7 @@ func HandleOpenAITextCompletionStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr @@ -501,9 +494,9 @@ func HandleOpenAITextCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -514,9 +507,9 @@ func HandleOpenAITextCompletionStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if 
customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.TextCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.TextCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -534,9 +527,9 @@ func HandleOpenAITextCompletionStreaming( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -558,7 +551,7 @@ func HandleOpenAITextCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -585,7 +578,7 @@ func HandleOpenAITextCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -596,11 +589,6 @@ func HandleOpenAITextCompletionStreaming( rawRequest, rawResponse, handlerErr := customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse) if handlerErr != nil { // TODO fix this - handlerErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } if sendBackRawRequest { handlerErr.ExtraFields.RawRequest = rawRequest } @@ -619,11 +607,6 @@ func HandleOpenAITextCompletionStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, 
nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -700,9 +683,6 @@ func HandleOpenAITextCompletionStreaming( if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil { chunkIndex++ - response.ExtraFields.RequestType = schemas.TextCompletionStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() @@ -720,7 +700,7 @@ func HandleOpenAITextCompletionStreaming( } } - response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest, providerName, request.Model) + response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest) if postResponseConverter != nil { response = postResponseConverter(response) if response == nil { @@ -812,22 +792,22 @@ func HandleOpenAIChatCompletionRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ChatCompletionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostChatResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: 
providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostChatResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -836,8 +816,7 @@ func HandleOpenAIChatCompletionRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIChatRequest(ctx, request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -859,9 +838,9 @@ func HandleOpenAIChatCompletionRequest( providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ChatCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ChatCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -873,7 +852,7 @@ func HandleOpenAIChatCompletionRequest( return &schemas.BifrostChatResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: 
request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } response := &schemas.BifrostChatResponse{} @@ -891,9 +870,6 @@ func HandleOpenAIChatCompletionRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ChatCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -1009,8 +985,7 @@ func HandleOpenAIChatCompletionStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1055,9 +1030,9 @@ func HandleOpenAIChatCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1068,9 +1043,9 @@ func HandleOpenAIChatCompletionStreaming( defer providerUtils.ReleaseStreamingResponse(resp) 
providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1083,20 +1058,14 @@ func HandleOpenAIChatCompletionStreaming( // Create response channel responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) - // Determine request type for cleanup - streamRequestType := schemas.ChatCompletionStreamRequest - if isResponsesToChatCompletionsFallback { - streamRequestType = schemas.ResponsesStreamRequest - } - // Start streaming in a goroutine go func() { defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } // Release the responses stream state if it was acquired (for ResponsesToChatCompletions fallback) schemas.ReleaseChatToResponsesStreamState(responsesStreamState) @@ -1120,7 +1089,7 @@ func 
HandleOpenAIChatCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, streamRequestType, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -1134,6 +1103,8 @@ func HandleOpenAIChatCompletionStreaming( var finishReason *string var messageID string + var modelName string + var created int forwardedTerminalFinishReason := false for { @@ -1149,7 +1120,7 @@ func HandleOpenAIChatCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, streamRequestType, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -1162,11 +1133,6 @@ func HandleOpenAIChatCompletionStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: streamRequestType, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -1180,11 +1146,6 @@ func HandleOpenAIChatCompletionStreaming( if customResponseHandler != nil { rawRequest, rawResponse, handlerErr := 
customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse) if handlerErr != nil { - handlerErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: streamRequestType, - } if sendBackRawRequest { handlerErr.ExtraFields.RawRequest = rawRequest } @@ -1215,11 +1176,6 @@ func HandleOpenAIChatCompletionStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: streamRequestType, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Message != nil { @@ -1237,9 +1193,6 @@ func HandleOpenAIChatCompletionStreaming( return } - response.ExtraFields.RequestType = streamRequestType - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber if sendBackRawResponse { @@ -1302,6 +1255,10 @@ func HandleOpenAIChatCompletionStreaming( response.Usage = nil } + if response.Model != "" { + modelName = response.Model + } + // Skip empty responses or responses without choices if len(response.Choices) == 0 { continue @@ -1317,6 +1274,9 @@ func HandleOpenAIChatCompletionStreaming( if response.ID != "" && messageID == "" { messageID = response.ID } + if response.Created != 0 && created == 0 { + created = response.Created + } // Handle regular content chunks, including reasoning if choice.ChatStreamResponseChoice != nil && @@ -1331,9 +1291,6 @@ func HandleOpenAIChatCompletionStreaming( } chunkIndex++ - response.ExtraFields.RequestType = schemas.ChatCompletionStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() @@ -1357,7 +1314,7 @@ 
func HandleOpenAIChatCompletionStreaming( if forwardedTerminalFinishReason { finalFinishReason = nil } - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finalFinishReason, chunkIndex, streamRequestType, providerName, request.Model) + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finalFinishReason, chunkIndex, modelName, created) if postResponseConverter != nil { response = postResponseConverter(response) } @@ -1444,21 +1401,21 @@ func HandleOpenAIResponsesRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ResponsesRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostResponsesResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostResponsesResponse{ Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1468,8 +1425,7 @@ func 
HandleOpenAIResponsesRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIResponsesRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1491,9 +1447,9 @@ func HandleOpenAIResponsesRequest( providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ResponsesRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ResponsesRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -1504,7 +1460,7 @@ func HandleOpenAIResponsesRequest( if lpResult != nil { return &schemas.BifrostResponsesResponse{ Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1522,9 +1478,6 @@ func HandleOpenAIResponsesRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ResponsesRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = 
providerResponseHeaders @@ -1621,8 +1574,7 @@ func HandleOpenAIResponsesStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1666,9 +1618,9 @@ func HandleOpenAIResponsesStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1679,9 +1631,9 @@ func HandleOpenAIResponsesStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ResponsesStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ResponsesStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, 
ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1699,9 +1651,9 @@ func HandleOpenAIResponsesStreaming( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1723,7 +1675,7 @@ func HandleOpenAIResponsesStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -1745,7 +1697,7 @@ func HandleOpenAIResponsesStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -1757,11 +1709,6 @@ func HandleOpenAIResponsesStreaming( if customResponseHandler != nil { rawRequest, 
rawResponse, bifrostErr := customResponseHandler([]byte(jsonData), &response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesStreamRequest, - } if sendBackRawRequest { bifrostErr.ExtraFields.RawRequest = rawRequest } @@ -1795,11 +1742,6 @@ func HandleOpenAIResponsesStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Message != nil { @@ -1832,11 +1774,6 @@ func HandleOpenAIResponsesStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeFailed)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Response != nil && response.Response.Error != nil { bifrostErr.Error.Message = response.Response.Error.Message @@ -1847,11 +1784,7 @@ func HandleOpenAIResponsesStreaming( return } - response.ExtraFields.RequestType = schemas.ResponsesStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber - if response.Type == schemas.ResponsesStreamResponseTypeCompleted || response.Type == schemas.ResponsesStreamResponseTypeIncomplete { // Set raw request if enabled if sendBackRawRequest { @@ -1939,22 +1872,22 @@ func HandleOpenAIEmbeddingRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.EmbeddingRequest, logger); 
handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostEmbeddingResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostEmbeddingResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1964,8 +1897,7 @@ func HandleOpenAIEmbeddingRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIEmbeddingRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1986,7 +1918,7 @@ func HandleOpenAIEmbeddingRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.EmbeddingRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, 
lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -1998,7 +1930,7 @@ func HandleOpenAIEmbeddingRequest( return &schemas.BifrostEmbeddingResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2016,9 +1948,6 @@ func HandleOpenAIEmbeddingRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.EmbeddingRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2097,22 +2026,21 @@ func HandleOpenAISpeechRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.SpeechRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } // Speech response is raw audio bytes (MP3/WAV), not JSON return &schemas.BifrostSpeechResponse{ Audio: lpResult.ResponseBody, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.SpeechRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return 
ToOpenAISpeechRequest(request), nil }, - providerName) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAISpeechRequest(request), nil }) if bifrostErr != nil { return nil, bifrostErr } @@ -2133,7 +2061,7 @@ func HandleOpenAISpeechRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.SpeechRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Get the binary audio data from the response body @@ -2144,7 +2072,7 @@ func HandleOpenAISpeechRequest( } if lpResult != nil { return &schemas.BifrostSpeechResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.SpeechRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2154,9 +2082,6 @@ func HandleOpenAISpeechRequest( bifrostResponse := &schemas.BifrostSpeechResponse{ Audio: body, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -2179,7 +2104,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHo for _, model := range providerUtils.UnsupportedSpeechStreamModels { if model == request.Model { - return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming speech synthesis", model), nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming 
speech synthesis", model), nil) } } @@ -2264,8 +2189,7 @@ func HandleOpenAISpeechStreamRequest( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2291,9 +2215,9 @@ func HandleOpenAISpeechStreamRequest( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -2303,7 +2227,7 @@ func HandleOpenAISpeechStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.SpeechStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -2321,9 +2245,9 @@ func HandleOpenAISpeechStreamRequest( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { 
- providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -2345,7 +2269,7 @@ func HandleOpenAISpeechStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -2369,7 +2293,7 @@ func HandleOpenAISpeechStreamRequest( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -2381,11 +2305,6 @@ func HandleOpenAISpeechStreamRequest( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.SpeechStreamRequest, - } 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -2411,11 +2330,8 @@ func HandleOpenAISpeechStreamRequest( chunkIndex++ response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -2476,7 +2392,7 @@ func HandleOpenAITranscriptionRequest( logger schemas.Logger, ) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.TranscriptionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } @@ -2484,13 +2400,13 @@ func HandleOpenAITranscriptionRequest( if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostTranscriptionResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } 
return &schemas.BifrostTranscriptionResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2519,7 +2435,7 @@ func HandleOpenAITranscriptionRequest( // Use centralized converter reqBody := ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } // Create multipart form @@ -2546,7 +2462,7 @@ func HandleOpenAITranscriptionRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, ParseOpenAIError(resp) } responseBody, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -2556,7 +2472,7 @@ func HandleOpenAITranscriptionRequest( } if lpResult != nil { return &schemas.BifrostTranscriptionResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2576,7 +2492,12 @@ func HandleOpenAITranscriptionRequest( // Parse OpenAI's transcription response directly into BifrostTranscribe response := &schemas.BifrostTranscriptionResponse{} var rawResponse interface{} - if customResponseHandler != nil { + if request.Params != nil && schemas.IsPlainTextTranscriptionFormat(request.Params.ResponseFormat) { + response.Text = string(copiedResponseBody) + if sendBackRawResponse { + 
rawResponse = string(copiedResponseBody) + } + } else if customResponseHandler != nil { _, rawResponse, bifrostErr = customResponseHandler(copiedResponseBody, response, nil, false, sendBackRawResponse) } else { if err := sonic.Unmarshal(copiedResponseBody, response); err != nil { @@ -2590,7 +2511,7 @@ func HandleOpenAITranscriptionRequest( }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // TODO: add HandleProviderResponse here @@ -2598,7 +2519,7 @@ func HandleOpenAITranscriptionRequest( // Parse raw response for RawResponse field if sendBackRawResponse { if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err) } } } @@ -2608,9 +2529,6 @@ func HandleOpenAITranscriptionRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -2672,7 +2590,7 @@ func HandleOpenAITranscriptionStreamRequest( // Use centralized converter reqBody := ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) if postRequestConverter != nil { @@ -2733,9 +2651,9 @@ func HandleOpenAITranscriptionStreamRequest( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, 
providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -2745,7 +2663,7 @@ func HandleOpenAITranscriptionStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, ParseOpenAIError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -2763,9 +2681,9 @@ func HandleOpenAITranscriptionStreamRequest( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -2787,7 +2705,7 @@ func HandleOpenAITranscriptionStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -2812,7 +2730,7 @@ func HandleOpenAITranscriptionStreamRequest( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -2823,11 +2741,6 @@ func HandleOpenAITranscriptionStreamRequest( if customResponseHandler != nil { _, _, bifrostErr = customResponseHandler([]byte(jsonData), response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } if sendBackRawResponse { bifrostErr.ExtraFields.RawResponse = jsonData } @@ -2842,13 +2755,9 @@ func HandleOpenAITranscriptionStreamRequest( var bifrostErrVal schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErrVal); err == nil { if bifrostErrVal.Error != nil && bifrostErrVal.Error.Message != "" { - bifrostErrVal.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, nil, nil, 
false, sendBackRawResponse), responseChan, logger) + respBody := append([]byte(nil), resp.Body()...) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, body.Bytes(), respBody, false, sendBackRawResponse), responseChan, logger) return } } @@ -2872,11 +2781,8 @@ func HandleOpenAITranscriptionStreamRequest( chunkIndex++ response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -2966,20 +2872,20 @@ func HandleOpenAIImageGenerationRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageGenerationRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: 
schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2989,8 +2895,7 @@ func HandleOpenAIImageGenerationRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIImageGenerationRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -3011,7 +2916,7 @@ func HandleOpenAIImageGenerationRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageGenerationRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -3021,7 +2926,7 @@ func HandleOpenAIImageGenerationRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -3033,9 +2938,6 @@ func HandleOpenAIImageGenerationRequest( return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -3062,7 +2964,7 @@ func (provider *OpenAIProvider) 
ImageGenerationStream( request *schemas.BifrostImageGenerationRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } // Check if image generation stream is allowed for this provider @@ -3136,8 +3038,7 @@ func HandleOpenAIImageGenerationStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -3182,9 +3083,9 @@ func HandleOpenAIImageGenerationStreaming( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -3194,7 +3095,7 @@ func HandleOpenAIImageGenerationStreaming( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageGenerationStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -3212,9 +3113,9 @@ func HandleOpenAIImageGenerationStreaming( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - 
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -3236,7 +3137,7 @@ func HandleOpenAIImageGenerationStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -3263,7 +3164,7 @@ func HandleOpenAIImageGenerationStreaming( if readErr != nil { if readErr != io.EOF { logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -3275,11 +3176,6 @@ func HandleOpenAIImageGenerationStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -3299,11 +3195,6 @@ func HandleOpenAIImageGenerationStreaming( bifrostErr := &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } // Guard access to response.Error fields if response.Error != nil { @@ -3404,11 +3295,8 @@ func HandleOpenAIImageGenerationStreaming( Background: response.Background, OutputFormat: response.OutputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, // Chunk order within this image - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, // Chunk order within this image + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -3514,7 +3402,7 @@ func (provider *OpenAIProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3543,7 +3431,7 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := 
providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3584,12 +3472,12 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoDownloadRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Get content type from response @@ -3607,8 +3495,6 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoDownloadRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3624,7 +3510,7 @@ func (provider *OpenAIProvider) VideoDelete(ctx *schemas.BifrostContext, key sch providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3694,10 +3580,10 @@ func HandleOpenAIVideoGenerationRequest( // Use centralized converter reqBody, err := ToOpenAIVideoGenerationRequest(request) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err) } if reqBody == nil { - return nil, 
providerUtils.NewBifrostOperationError("video generation input is not provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video generation input is not provided", nil)
 	}
 
 	// Create multipart form
@@ -3723,12 +3609,12 @@ func HandleOpenAIVideoGenerationRequest(
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoGenerationRequest, providerName, request.Model)
+		return nil, ParseOpenAIError(resp)
 	}
 
 	responseBody, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	// Check for empty response
@@ -3754,9 +3640,6 @@ func HandleOpenAIVideoGenerationRequest(
 	}
 
 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType:             schemas.VideoGenerationRequest,
-		Provider:                providerName,
-		ModelRequested:          request.Model,
 		Latency:                 latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -3821,12 +3704,12 @@ func HandleOpenAIVideoRetrieveRequest(
 	if resp.StatusCode() != fasthttp.StatusOK {
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoRetrieveRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	response := &schemas.BifrostVideoGenerationResponse{}
@@ -3868,8 +3751,6 @@ func HandleOpenAIVideoRetrieveRequest(
 	}
 
 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType:             schemas.VideoRetrieveRequest,
-		Provider:                providerName,
 		Latency:                 latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -3921,12 +3802,12 @@ func HandleOpenAIVideoDeleteRequest(
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoDeleteRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	// Parse OpenAI's video response
@@ -3940,8 +3821,6 @@ func HandleOpenAIVideoDeleteRequest(
 	}
 
 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType:             schemas.VideoDeleteRequest,
-		Provider:                providerName,
 		Latency:                 latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -4014,12 +3893,12 @@ func HandleOpenAIVideoListRequest(
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoListRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	response := &schemas.BifrostVideoListResponse{}
@@ -4046,8 +3925,6 @@ func HandleOpenAIVideoListRequest(
 	}
 
 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType:             schemas.VideoListRequest,
-		Provider:                providerName,
 		Latency:                 latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -4117,20 +3994,20 @@ func HandleOpenAICountTokensRequest(
 	}
 
 	// Large payload passthrough: stream body directly without JSON marshaling
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.CountTokensRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
 		if len(lpResult.ResponseBody) > 0 {
 			response := &schemas.BifrostCountTokensResponse{}
 			if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 			}
-			response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency}
+			response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
 			return response, nil
 		}
 		return &schemas.BifrostCountTokensResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -4139,9 +4016,7 @@ func HandleOpenAICountTokensRequest(
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToOpenAIResponsesRequest(request), nil
-		},
-		providerName,
-	)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -4162,7 +4037,7 @@ func HandleOpenAICountTokensRequest(
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
 		logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.CountTokensRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
@@ -4172,7 +4047,7 @@ func HandleOpenAICountTokensRequest(
 	}
 	if lpResult != nil {
 		return &schemas.BifrostCountTokensResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -4185,9 +4060,6 @@ func HandleOpenAICountTokensRequest(
 	}
 
 	response.Model = request.Model
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.RequestType = schemas.CountTokensRequest
-	response.ExtraFields.ModelRequested = request.Model
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -4235,26 +4107,26 @@ func HandleOpenAIImageEditRequest(
 	logger schemas.Logger,
 ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
 	// Large payload passthrough: stream multipart body directly without parsing
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageEditRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
 		if len(lpResult.ResponseBody) > 0 {
 			response := &schemas.BifrostImageGenerationResponse{}
 			if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 			}
-			response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency}
+			response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
 			return response, nil
 		}
 		return &schemas.BifrostImageGenerationResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
 
 	openaiReq := ToOpenAIImageEditRequest(request)
 	if openaiReq == nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil)
 	}
 
 	// Create request
@@ -4301,7 +4173,7 @@ func HandleOpenAIImageEditRequest(
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageEditRequest, providerName, request.Model), bodyData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
@@ -4311,7 +4183,7 @@ func HandleOpenAIImageEditRequest(
 	}
 	if lpResult != nil {
 		return &schemas.BifrostImageGenerationResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -4320,9 +4192,6 @@ func HandleOpenAIImageEditRequest(
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageEditRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -4386,7 +4255,7 @@ func HandleOpenAIImageEditStreamRequest(
 ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
 	reqBody := ToOpenAIImageEditRequest(request)
 	if reqBody == nil {
-		return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil)
 	}
 	reqBody.Stream = schemas.Ptr(true)
@@ -4446,9 +4315,9 @@ func HandleOpenAIImageEditStreamRequest(
 			}
 		}
 		if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName)
+			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 		}
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
 	}
 
 	// Store provider response headers in context before status check so error responses also forward them
 	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
@@ -4457,7 +4326,7 @@ func HandleOpenAIImageEditStreamRequest(
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageEditStreamRequest, providerName, request.Model), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -4475,9 +4344,9 @@ func HandleOpenAIImageEditStreamRequest(
 		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger)
 			}
 			close(responseChan)
 		}()
@@ -4499,7 +4368,7 @@ func HandleOpenAIImageEditStreamRequest(
 		// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
 		if providerUtils.DrainNonSSEStreamResponse(resp) {
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ImageEditStreamRequest, providerName, request.Model, logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger)
 			return
 		}
@@ -4526,7 +4395,7 @@ func HandleOpenAIImageEditStreamRequest(
 			if readErr != nil {
 				if readErr != io.EOF {
 					logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr))
-					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageEditStreamRequest, providerName, request.Model, logger)
+					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger)
 				}
 				break
 			}
@@ -4538,11 +4407,6 @@ func HandleOpenAIImageEditStreamRequest(
 			var bifrostErr schemas.BifrostError
 			if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
 				if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
-					bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-						Provider:       providerName,
-						ModelRequested: request.Model,
-						RequestType:    schemas.ImageEditStreamRequest,
-					}
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger)
 					return
@@ -4562,11 +4426,6 @@ func HandleOpenAIImageEditStreamRequest(
 				bifrostErr := &schemas.BifrostError{
 					IsBifrostError: false,
 					Error:          &schemas.ErrorField{},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						Provider:       providerName,
-						ModelRequested: request.Model,
-						RequestType:    schemas.ImageEditStreamRequest,
-					},
 				}
 				// Guard access to response.Error fields
 				if response.Error != nil {
@@ -4667,11 +4526,8 @@ func HandleOpenAIImageEditStreamRequest(
 					Background:   response.Background,
 					OutputFormat: response.OutputFormat,
 					ExtraFields: schemas.BifrostResponseExtraFields{
-						RequestType:    schemas.ImageEditStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-						ChunkIndex:     chunkIndex, // Chunk order within this image
-						Latency:        time.Since(lastChunkTime).Milliseconds(),
+						ChunkIndex: chunkIndex, // Chunk order within this image
+						Latency:    time.Since(lastChunkTime).Milliseconds(),
 					},
 				}
@@ -4771,26 +4627,26 @@ func HandleOpenAIImageVariationRequest(
 	logger schemas.Logger,
 ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
 	// Large payload passthrough: stream multipart body directly without parsing
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageVariationRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
 		if len(lpResult.ResponseBody) > 0 {
 			response := &schemas.BifrostImageGenerationResponse{}
 			if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 			}
-			response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency}
+			response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
 			return response, nil
 		}
 		return &schemas.BifrostImageGenerationResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
 
 	openaiReq := ToOpenAIImageVariationRequest(request)
 	if openaiReq == nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil)
 	}
 
 	// Create request
@@ -4836,7 +4692,7 @@ func HandleOpenAIImageVariationRequest(
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageVariationRequest, providerName, request.Model), bodyData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
@@ -4846,7 +4702,7 @@ func HandleOpenAIImageVariationRequest(
 	}
 	if lpResult != nil {
 		return &schemas.BifrostImageGenerationResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -4855,9 +4711,6 @@ func HandleOpenAIImageVariationRequest(
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageVariationRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -4874,14 +4727,12 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	if len(request.File) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file content is required", nil)
 	}
 
 	if request.Purpose == "" {
-		return nil, providerUtils.NewBifrostOperationError("purpose is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("purpose is required", nil)
 	}
 
 	// Create multipart form data
@@ -4890,16 +4741,16 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche
 	// Add purpose field
 	if err := writer.WriteField("purpose", string(request.Purpose)); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err)
 	}
 
 	// Add expires_after fields if provided
 	if request.ExpiresAfter != nil {
 		if err := writer.WriteField("expires_after[anchor]", request.ExpiresAfter.Anchor); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err)
 		}
 		if err := writer.WriteField("expires_after[seconds]", fmt.Sprintf("%d", request.ExpiresAfter.Seconds)); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[seconds] field", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[seconds] field", err)
 		}
 	}
@@ -4910,14 +4761,14 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche
 	}
 	part, err := writer.CreateFormFile("file", filename)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to create form file", err)
 	}
 
 	if _, err := part.Write(request.File); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to write file content", err)
 	}
 
 	if err := writer.Close(); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
 	}
 
 	// Create request
@@ -4947,13 +4798,13 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.FileUploadRequest, providerName, "")
+		provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))
+		return nil, ParseOpenAIError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	var openAIResp OpenAIFileResponse
@@ -4964,7 +4815,7 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche
 		return nil, bifrostErr
 	}
 
-	fileResponse := openAIResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
+	fileResponse := openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
 	fileResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
 	return fileResponse, nil
 }
@@ -4983,7 +4834,7 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch
 	// Initialize serial pagination helper
 	helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}
 
 	// Get current key to query
@@ -4994,10 +4845,6 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch
 			Object:  "list",
 			Data:    []schemas.FileObject{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.FileListRequest,
-				Provider:    providerName,
-			},
 		}, nil
 	}
@@ -5047,12 +4894,12 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.FileListRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 
 	body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
 	if decodeErr != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
 	}
 
 	var openAIResp OpenAIFileListResponse
@@ -5088,8 +4935,6 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch
 		Data:    files,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:             schemas.FileListRequest,
-			Provider:                providerName,
 			Latency:                 latency.Milliseconds(),
 			ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 		},
@@ -5110,7 +4955,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [
 	providerName := provider.GetProviderKey()
 
 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
 	}
 
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
@@ -5145,7 +4990,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
 			provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-			lastErr = ParseOpenAIError(resp, schemas.FileRetrieveRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
 		}
@@ -5155,7 +5000,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}
@@ -5186,7 +5031,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s
 	providerName := provider.GetProviderKey()
 
 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
 	}
 
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
@@ -5221,7 +5066,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
 			provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-			lastErr = ParseOpenAIError(resp, schemas.FileDeleteRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -5231,7 +5076,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}
@@ -5252,9 +5097,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s
 			Object:  openAIResp.Object,
 			Deleted: openAIResp.Deleted,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.FileDeleteRequest,
-				Provider:    providerName,
-				Latency:     latency.Milliseconds(),
+				Latency: latency.Milliseconds(),
 			},
 		}
@@ -5281,7 +5124,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys []
 	providerName := provider.GetProviderKey()
 
 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
 	}
 
 	var lastErr *schemas.BifrostError
@@ -5312,7 +5155,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys []
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
 			provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-			lastErr = ParseOpenAIError(resp, schemas.FileContentRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -5322,7 +5165,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys []
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}
@@ -5341,9 +5184,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys []
 			Content:     content,
 			ContentType: contentType,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.FileContentRequest,
-				Provider:    providerName,
-				Latency:     latency.Milliseconds(),
+				Latency: latency.Milliseconds(),
 			},
 		}, nil
 	}
@@ -5360,10 +5201,10 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche
 	providerName := provider.GetProviderKey()
 
 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 	if request.Input == nil || request.Input.Prompt == "" {
-		return nil, providerUtils.NewBifrostOperationError("prompt is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("prompt is required", nil)
 	}
 
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
@@ -5371,8 +5212,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToOpenAIVideoRemixRequest(request)
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -5410,12 +5250,12 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoRemixRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	// Parse OpenAI's video response
@@ -5433,9 +5273,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche
 	}
 
 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType: schemas.VideoRemixRequest,
-		Provider:    providerName,
-		Latency:     latency.Milliseconds(),
+		Latency: latency.Milliseconds(),
 	}
 
 	if sendBackRawResponse {
@@ -5454,8 +5292,6 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	inputFileID := request.InputFileID
 
 	// If no file_id provided but inline requests are available, upload them first
@@ -5463,7 +5299,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		// Convert inline requests to JSONL format
 		jsonlData, err := ConvertRequestsToJSONL(request.Requests)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err)
 		}
 
 		// Upload the file with purpose "batch"
@@ -5482,12 +5318,12 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 	// Validate that we have a file ID (either provided or uploaded)
 	if inputFileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil)
 	}
 
 	// Validate that we have an endpoint
 	if request.Endpoint == "" {
-		return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil)
 	}
 
 	// Create request
@@ -5522,7 +5358,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 	jsonData, err := providerUtils.MarshalSorted(openAIReq)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 	}
 	req.SetBody(jsonData)
@@ -5538,12 +5374,12 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.BatchCreateRequest, providerName, ""), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	var openAIResp OpenAIBatchResponse
@@ -5552,7 +5388,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
 	}
 
-	return openAIResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
+	return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
 }
 
 // BatchList lists batch jobs using serial pagination across keys.
@@ -5562,14 +5398,13 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
 
 	// Initialize serial pagination helper
 	helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}
 
 	// Get current key to query
@@ -5580,10 +5415,6 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 			Object:  "list",
 			Data:    []schemas.BifrostBatchRetrieveResponse{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.BatchListRequest,
-				Provider:    providerName,
-			},
 		}, nil
 	}
@@ -5627,12 +5458,12 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, ParseOpenAIError(resp, schemas.BatchListRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 
 	body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
 	if decodeErr != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
 	}
 
 	var openAIResp OpenAIBatchListResponse
@@ -5645,7 +5476,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 	batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data))
 	var lastBatchID string
 	for _, batch := range openAIResp.Data {
-		batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse))
+		batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse))
 		lastBatchID = batch.ID
 	}
@@ -5659,9 +5490,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 		Data:    batches,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchListRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 	if nextCursor != "" {
@@ -5678,10 +5507,9 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 	}
 
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, request.Provider)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
 
-	providerName := provider.GetProviderKey()
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
@@ -5713,7 +5541,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			lastErr = ParseOpenAIError(resp, schemas.BatchRetrieveRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -5723,7 +5551,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}
@@ -5739,8 +5567,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 		fasthttp.ReleaseRequest(req)
 		fasthttp.ReleaseResponse(resp)
 
-		result := openAIResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
-		result.ExtraFields.RequestType = schemas.BatchRetrieveRequest
+		result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
 		return result, nil
 	}
@@ -5754,10 +5581,9 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 	}
 
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.OpenAI)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
 
-	providerName := provider.GetProviderKey()
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
@@ -5789,7 +5615,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			lastErr = ParseOpenAIError(resp, schemas.BatchCancelRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -5799,7 +5625,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr =
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5822,9 +5648,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] CancellingAt: openAIResp.CancellingAt, CancelledAt: openAIResp.CancelledAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5864,11 +5688,9 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.OpenAI) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output_file_id (this already iterates over keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -5879,7 +5701,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil) } // Download the output file - try each key @@ -5909,7 +5731,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.BatchResultsRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) 
fasthttp.ReleaseResponse(resp) continue @@ -5919,7 +5741,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5943,9 +5765,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5965,14 +5785,12 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.Name == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil) } // Build request body @@ -6008,7 +5826,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key jsonBody, err := providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Create request @@ -6037,7 +5855,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { 
- return nil, ParseOpenAIError(resp, schemas.ContainerCreateRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Parse response @@ -6071,9 +5889,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key MemoryLimit: containerResp.MemoryLimit, Metadata: containerResp.Metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerCreateRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6090,16 +5906,14 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key // ContainerList lists containers via OpenAI's API. // Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next. func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } @@ -6113,7 +5927,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // Initialize serial pagination helper for multi-key support helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid 
pagination cursor", err) } // Get current key to query @@ -6124,10 +5938,6 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys Object: "list", Data: []schemas.ContainerObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerListRequest, - }, }, nil } @@ -6174,7 +5984,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, ParseOpenAIError(resp, schemas.ContainerListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Parse response @@ -6209,9 +6019,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys LastID: listResp.LastID, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerListRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6232,20 +6040,18 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // ContainerRetrieve retrieves a specific container via OpenAI's API. 
func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("container_id is required", nil) } if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerRetrieveRequest); err != nil { @@ -6280,7 +6086,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.ContainerRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6320,9 +6126,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k MemoryLimit: containerResp.MemoryLimit, Metadata: containerResp.Metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerRetrieveRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6343,20 +6147,18 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k 
// ContainerDelete deletes a container via OpenAI's API. func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("container_id is required", nil) } if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerDeleteRequest); err != nil { @@ -6391,7 +6193,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.ContainerDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6419,9 +6221,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, key Object: deleteResp.Object, Deleted: deleteResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerDeleteRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6450,14 +6250,12 @@ func (provider *OpenAIProvider) 
ContainerFileCreate(ctx *schemas.BifrostContext, return nil, err } - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } // Create request @@ -6474,7 +6272,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Handle file upload (multipart only) if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil) } // Multipart file upload @@ -6484,13 +6282,13 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Add file part, err := writer.CreateFormFile("file", "file") if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create multipart form", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create multipart form", err) } if _, err = part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err) } req.Header.Set("Content-Type", writer.FormDataContentType()) req.SetBody(body.Bytes()) @@ -6508,13 +6306,13 @@ func 
(provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() >= 400 { - return nil, ParseOpenAIError(resp, schemas.ContainerFileCreateRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -6543,9 +6341,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, Path: fileResp.Path, Source: fileResp.Source, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileCreateRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6563,21 +6359,19 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // ContainerFileList lists files in a container via OpenAI's API. // Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next. 
func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6591,7 +6385,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k // Initialize serial pagination helper for multi-key support helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -6602,10 +6396,6 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k Object: "list", Data: []schemas.ContainerFileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileListRequest, - }, }, nil } @@ -6651,13 +6441,13 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k } if resp.StatusCode() >= 400 { - return nil, 
ParseOpenAIError(resp, schemas.ContainerFileListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var listResp struct { @@ -6689,9 +6479,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k LastID: listResp.LastID, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileListRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6712,13 +6500,11 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k // ContainerFileRetrieve retrieves a file from a container via OpenAI's API. 
func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6727,15 +6513,15 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6763,7 +6549,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6772,7 +6558,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - 
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6807,9 +6593,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex Path: fileResp.Path, Source: fileResp.Source, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileRetrieveRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6830,13 +6614,11 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex // ContainerFileContent retrieves the content of a file from a container via OpenAI's API. func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6845,15 +6627,15 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, 
providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6881,7 +6663,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileContentRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6898,7 +6680,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } content := append([]byte(nil), body...) @@ -6907,9 +6689,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileContentRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6933,13 +6713,11 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext // ContainerFileDelete deletes a file from a container via OpenAI's API. 
func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6948,15 +6726,15 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6984,7 +6762,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6993,7 +6771,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -7020,9 +6798,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, Object: deleteResp.Object, Deleted: deleteResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileDeleteRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -7091,7 +6867,7 @@ func (provider *OpenAIProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } // Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body. 
@@ -7107,9 +6883,6 @@ func (provider *OpenAIProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -7177,9 +6950,9 @@ func (provider *OpenAIProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := make(map[string]string) @@ -7193,9 +6966,7 @@ func (provider *OpenAIProvider) PassthroughStream( providerUtils.ReleaseStreamingResponse(resp) return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + fmt.Errorf("provider returned an empty stream body")) } // Wrap reader with idle timeout to detect stalled streams. @@ -7204,11 +6975,7 @@ func (provider *OpenAIProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields) } @@ -7219,9 +6986,9 @@ func (provider *OpenAIProvider) PassthroughStream( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -7270,7 +7037,7 @@ func (provider *OpenAIProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/openai/openai_test.go b/core/providers/openai/openai_test.go index c37040ce62..d2173e1ac7 100644 --- a/core/providers/openai/openai_test.go +++ b/core/providers/openai/openai_test.go @@ -46,69 +46,69 @@ func TestOpenAI(t *testing.T) { ChatAudioModel: "gpt-4o-mini-audio-preview", PassthroughModel: "gpt-4o", Scenarios: 
llmtests.TestScenarios{ - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - WebSearchTool: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - FileBase64: true, - FileURL: true, - CompleteEnd2End: true, - SpeechSynthesis: true, - SpeechSynthesisStream: true, - Transcription: true, - TranscriptionStream: true, - Embedding: true, - Reasoning: true, - ListModels: true, - ImageGeneration: true, - ImageGenerationStream: true, - ImageEdit: true, - ImageEditStream: true, - ImageVariation: false, // dall-e-2 is deprecated and no other OpenAI model supports image variations - VideoGeneration: false, // disabled for now because of long running operations - VideoRetrieve: false, - VideoRemix: false, - VideoDownload: false, - VideoList: false, - VideoDelete: false, - BatchCreate: true, - BatchList: true, - BatchRetrieve: true, - BatchCancel: true, - BatchResults: true, - FileUpload: true, - FileList: true, - FileRetrieve: true, - FileDelete: true, - FileContent: true, - FileBatchInput: true, - CountTokens: true, - ChatAudio: true, - StructuredOutputs: true, // Structured outputs with nullable enum support - ContainerCreate: true, - ContainerList: true, - ContainerRetrieve: true, - ContainerDelete: true, - ContainerFileCreate: true, - ContainerFileList: true, - ContainerFileRetrieve: true, - ContainerFileContent: true, - ContainerFileDelete: true, - PromptCaching: true, - PassthroughAPI: true, - WebSocketResponses: true, - Realtime: false, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + WebSearchTool: true, + ImageURL: true, + 
ImageBase64: true, + MultipleImages: true, + FileBase64: true, + FileURL: true, + CompleteEnd2End: true, + SpeechSynthesis: true, + SpeechSynthesisStream: true, + Transcription: true, + TranscriptionStream: true, + Embedding: true, + Reasoning: true, + ListModels: true, + ImageGeneration: true, + ImageGenerationStream: true, + ImageEdit: true, + ImageEditStream: true, + ImageVariation: false, // dall-e-2 is deprecated and no other OpenAI model supports image variations + VideoGeneration: false, // disabled for now because of long running operations + VideoRetrieve: false, + VideoRemix: false, + VideoDownload: false, + VideoList: false, + VideoDelete: false, + BatchCreate: true, + BatchList: true, + BatchRetrieve: true, + BatchCancel: true, + BatchResults: true, + FileUpload: true, + FileList: true, + FileRetrieve: true, + FileDelete: true, + FileContent: true, + FileBatchInput: true, + CountTokens: true, + ChatAudio: true, + StructuredOutputs: true, // Structured outputs with nullable enum support + ContainerCreate: true, + ContainerList: true, + ContainerRetrieve: true, + ContainerDelete: true, + ContainerFileCreate: true, + ContainerFileList: true, + ContainerFileRetrieve: true, + ContainerFileContent: true, + ContainerFileDelete: true, + PromptCaching: true, + PassthroughAPI: true, + WebSocketResponses: true, + Realtime: false, }, RealtimeModel: "gpt-4o-realtime-preview", } diff --git a/core/providers/openai/realtime.go b/core/providers/openai/realtime.go index b73db4ea24..8c88382297 100644 --- a/core/providers/openai/realtime.go +++ b/core/providers/openai/realtime.go @@ -1,13 +1,17 @@ package openai import ( + "bytes" "encoding/json" "fmt" + "mime/multipart" + "net/http" "net/url" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" ) // SupportsRealtimeAPI returns true since OpenAI natively supports the Realtime API. 
@@ -28,7 +32,6 @@ func (provider *OpenAIProvider) RealtimeWebSocketURL(key schemas.Key, model stri func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]string { headers := map[string]string{ "Authorization": "Bearer " + key.Value.GetValue(), - "OpenAI-Beta": "realtime=v1", } for k, v := range provider.networkConfig.ExtraHeaders { headers[k] = v @@ -36,6 +39,380 @@ func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]stri return headers } +// SupportsRealtimeWebRTC reports that OpenAI supports WebRTC SDP exchange. +func (provider *OpenAIProvider) SupportsRealtimeWebRTC() bool { + return true +} + +// ExchangeRealtimeWebRTCSDP performs the GA SDP exchange via multipart POST to /v1/realtime/calls. +func (provider *OpenAIProvider) ExchangeRealtimeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + model string, + sdp string, + session json.RawMessage, +) (string, *schemas.BifrostError) { + path := "/v1/realtime/calls" + if session == nil && strings.TrimSpace(model) != "" { + path += "?model=" + url.QueryEscape(model) + } + return provider.exchangeWebRTCSDP(ctx, key, path, sdp, session) +} + +// ExchangeLegacyRealtimeWebRTCSDP performs the beta SDP exchange via multipart POST to /v1/realtime. +// Same multipart format but targets the legacy endpoint with model in the URL. +func (provider *OpenAIProvider) ExchangeLegacyRealtimeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + sdp string, + session json.RawMessage, + model string, +) (string, *schemas.BifrostError) { + return provider.exchangeWebRTCSDP(ctx, key, "/v1/realtime?model="+url.QueryEscape(model), sdp, session) +} + +// exchangeWebRTCSDP is the shared multipart SDP exchange implementation. +// Builds a multipart body with sdp + optional session, POSTs to the given path. 
+func (provider *OpenAIProvider) exchangeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + path string, + sdp string, + session json.RawMessage, +) (string, *schemas.BifrostError) { + bodyBuf := &bytes.Buffer{} + writer := multipart.NewWriter(bodyBuf) + if err := writer.WriteField("sdp", sdp); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream SDP body", err) + } + if session != nil { + if err := writer.WriteField("session", string(session)); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream session body", err) + } + } + if err := writer.Close(); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to finalize upstream SDP body", err) + } + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(provider.buildRequestURL(ctx, path, schemas.RealtimeRequest)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType(writer.FormDataContentType()) + req.Header.Set("Authorization", "Bearer "+key.Value.GetValue()) + for k, v := range provider.networkConfig.ExtraHeaders { + req.Header.Set(k, v) + } + if headers, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string); headers != nil { + if agentsSDK := headers["x-openai-agents-sdk"]; agentsSDK != "" { + req.Header.Set("X-OpenAI-Agents-SDK", agentsSDK) + } + } + req.SetBody(bodyBuf.Bytes()) + + _, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + defer wait() + if bifrostErr != nil { + return "", bifrostErr + } + + answerBody := resp.Body() + if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices { + return "", provider.realtimeWebRTCUpstreamError(ctx, resp.StatusCode(), 
answerBody) + } + + return string(answerBody), nil +} + +func (provider *OpenAIProvider) realtimeWebRTCUpstreamError(ctx *schemas.BifrostContext, statusCode int, body []byte) *schemas.BifrostError { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(fasthttp.StatusBadGateway), + Error: &schemas.ErrorField{ + Type: schemas.Ptr("upstream_connection_error"), + Message: fmt.Sprintf("upstream realtime WebRTC handshake failed for %s", provider.GetProviderKey()), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: provider.GetProviderKey(), + }, + } + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostErr.ExtraFields.RawResponse = map[string]any{ + "status": statusCode, + "body": string(body), + } + } + return bifrostErr +} + +func newRealtimeWebRTCSDPError(status int, errorType, message string, err error) *schemas.BifrostError { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + }, + } + if err != nil { + bifrostErr.Error.Error = err + } + return bifrostErr +} + +func (provider *OpenAIProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool { + if event == nil { + return false + } + switch event.Type { + case schemas.RTEventResponseCreate, schemas.RTEventInputAudioBufferCommitted: + return true + default: + return false + } +} + +func (provider *OpenAIProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType { + return schemas.RTEventResponseDone +} + +func (provider *OpenAIProvider) RealtimeWebRTCDataChannelLabel() string { + return "oai-events" +} + +func (provider *OpenAIProvider) RealtimeWebSocketSubprotocol() string { + return "realtime" +} + +func (provider *OpenAIProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool { + return true +} + +func (provider 
*OpenAIProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool { + switch eventType { + case schemas.RTEventResponseTextDelta, + schemas.RTEventResponseAudioTransDelta, + schemas.RealtimeEventType("response.output_text.delta"), + schemas.RealtimeEventType("response.output_audio_transcript.delta"): + return true + default: + return false + } +} + +// CreateRealtimeClientSecret mints an OpenAI Realtime client secret and returns +// the native OpenAI response body unchanged. +func (provider *OpenAIProvider) CreateRealtimeClientSecret( + ctx *schemas.BifrostContext, + key schemas.Key, + endpointType schemas.RealtimeSessionEndpointType, + rawRequest json.RawMessage, +) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.RealtimeRequest); err != nil { + return nil, err + } + + normalizedBody, requestedModel, bifrostErr := normalizeRealtimeClientSecretRequest(rawRequest, provider.GetProviderKey(), endpointType) + if bifrostErr != nil { + return nil, bifrostErr + } + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(provider.buildRequestURL(ctx, realtimeSessionUpstreamPath(endpointType), schemas.RealtimeRequest)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + for k, v := range provider.realtimeSessionHeaders(key, endpointType) { + req.Header.Set(k, v) + } + req.SetBody(normalizedBody) + + latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + defer wait() + if bifrostErr != nil { + return nil, bifrostErr + } + + headers := providerUtils.ExtractProviderResponseHeaders(resp) + ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers) + + if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= 
fasthttp.StatusMultipleChoices { + return nil, ParseOpenAIError(resp) + } + + body, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) + } + for k := range headers { + if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { + delete(headers, k) + } + } + + out := &schemas.BifrostPassthroughResponse{ + StatusCode: resp.StatusCode(), + Headers: headers, + Body: body, + } + out.ExtraFields.Provider = provider.GetProviderKey() + out.ExtraFields.OriginalModelRequested = requestedModel + out.ExtraFields.RequestType = schemas.RealtimeRequest + out.ExtraFields.Latency = latency.Milliseconds() + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequestIfJSON(req, &out.ExtraFields) + } + + return out, nil +} + +func normalizeRealtimeClientSecretRequest( + rawRequest json.RawMessage, + defaultProvider schemas.ModelProvider, + endpointType schemas.RealtimeSessionEndpointType, +) ([]byte, string, *schemas.BifrostError) { + root, bifrostErr := schemas.ParseRealtimeClientSecretBody(rawRequest) + if bifrostErr != nil { + return nil, "", bifrostErr + } + + modelValue, bifrostErr := schemas.ExtractRealtimeClientSecretModel(root) + if bifrostErr != nil { + return nil, "", bifrostErr + } + providerKey, normalizedModel := schemas.ParseModelString(modelValue, defaultProvider) + if normalizedModel == "" { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil) + } + if providerKey == "" { + providerKey = defaultProvider + } + if providerKey == "" { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "unable to determine provider from model", nil) + } + + if endpointType == schemas.RealtimeSessionEndpointSessions { + return normalizeRealtimeSessionsRequest(root, 
normalizedModel) + } + + return normalizeRealtimeClientSecretsRequest(root, normalizedModel) +} + +func normalizeRealtimeClientSecretsRequest( + root map[string]json.RawMessage, + normalizedModel string, +) ([]byte, string, *schemas.BifrostError) { + session := map[string]json.RawMessage{} + if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) { + if err := json.Unmarshal(existingSession, &session); err != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err) + } + } + + modelJSON, marshalErr := json.Marshal(normalizedModel) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + session["model"] = modelJSON + if _, ok := session["type"]; !ok { + typeJSON, marshalErr := json.Marshal("realtime") + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session type", marshalErr) + } + session["type"] = typeJSON + } + delete(root, "model") + + sessionJSON, marshalErr := json.Marshal(session) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session", marshalErr) + } + root["session"] = sessionJSON + + normalizedBody, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr) + } + + return normalizedBody, normalizedModel, nil +} + +func normalizeRealtimeSessionsRequest( + root map[string]json.RawMessage, + normalizedModel string, +) ([]byte, string, *schemas.BifrostError) { + if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && 
!bytes.Equal(existingSession, []byte("null")) { + session := map[string]json.RawMessage{} + if err := json.Unmarshal(existingSession, &session); err != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err) + } + for key, value := range session { + if _, exists := root[key]; !exists { + root[key] = value + } + } + } + + modelJSON, marshalErr := json.Marshal(normalizedModel) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + root["model"] = modelJSON + delete(root, "session") + + normalizedBody, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr) + } + + return normalizedBody, normalizedModel, nil +} + +func (provider *OpenAIProvider) realtimeSessionHeaders( + key schemas.Key, + endpointType schemas.RealtimeSessionEndpointType, +) map[string]string { + headers := map[string]string{ + "Authorization": "Bearer " + key.Value.GetValue(), + } + if endpointType == schemas.RealtimeSessionEndpointSessions { + headers["OpenAI-Beta"] = "realtime=v1" + } + for k, v := range provider.networkConfig.ExtraHeaders { + headers[k] = v + } + return headers +} + +func realtimeSessionUpstreamPath(endpointType schemas.RealtimeSessionEndpointType) string { + if endpointType == schemas.RealtimeSessionEndpointSessions { + return "/v1/realtime/sessions" + } + return "/v1/realtime/client_secrets" +} + +func newRealtimeClientSecretError(status int, errorType, message string, err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: 
schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: schemas.OpenAI, + }, + } +} + // openAIRealtimeEvent is the raw shape of an OpenAI Realtime protocol event. type openAIRealtimeEvent struct { Type string `json:"type"` @@ -44,15 +421,17 @@ type openAIRealtimeEvent struct { Conversation json.RawMessage `json:"conversation,omitempty"` Item json.RawMessage `json:"item,omitempty"` Response json.RawMessage `json:"response,omitempty"` + Part json.RawMessage `json:"part,omitempty"` Delta string `json:"delta,omitempty"` Audio string `json:"audio,omitempty"` Transcript string `json:"transcript,omitempty"` Text string `json:"text,omitempty"` Error json.RawMessage `json:"error,omitempty"` ItemID string `json:"item_id,omitempty"` - OutputIndex int `json:"output_index,omitempty"` - ContentIndex int `json:"content_index,omitempty"` + OutputIndex *int `json:"output_index,omitempty"` + ContentIndex *int `json:"content_index,omitempty"` ResponseID string `json:"response_id,omitempty"` + AudioEndMS *int `json:"audio_end_ms,omitempty"` PreviousItemID string `json:"previous_item_id,omitempty"` } @@ -105,6 +484,17 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes EventID: raw.EventID, RawData: providerEvent, } + setRealtimeExtraParam(event, "item_id", raw.ItemID) + setRealtimeExtraParam(event, "previous_item_id", raw.PreviousItemID) + setRealtimeExtraParam(event, "output_index", raw.OutputIndex) + setRealtimeExtraParam(event, "content_index", raw.ContentIndex) + setRealtimeExtraParam(event, "response_id", raw.ResponseID) + setRealtimeExtraParam(event, "audio_end_ms", raw.AudioEndMS) + setRealtimeExtraParam(event, "transcript", raw.Transcript) + setRealtimeExtraParam(event, "text", raw.Text) + setRealtimeExtraParam(event, "conversation", raw.Conversation) + setRealtimeExtraParam(event, "response", raw.Response) + setRealtimeExtraParam(event, "part", raw.Part) switch { case raw.Session != nil: @@ -123,8 +513,10 @@ func 
(provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes OutputAudioType: sess.OutputAudioType, Tools: sess.Tools, } + if extra := extractRealtimeNestedParams(raw.Session, "id", "model", "modalities", "instructions", "voice", "temperature", "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools"); len(extra) > 0 { + event.Session.ExtraParams = extra + } } - case raw.Item != nil: var item openAIRealtimeItem if err := json.Unmarshal(raw.Item, &item); err == nil { @@ -139,6 +531,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Arguments: item.Arguments, Output: item.Output, } + if extra := extractRealtimeNestedParams(raw.Item, "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output"); len(extra) > 0 { + event.Item.ExtraParams = extra + } } case raw.Error != nil: @@ -150,6 +545,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Message: rtErr.Message, Param: rtErr.Param, } + if extra := extractRealtimeNestedParams(raw.Error, "type", "code", "message", "param"); len(extra) > 0 { + event.Error.ExtraParams = extra + } } } @@ -159,8 +557,8 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Audio: raw.Audio, Transcript: raw.Transcript, ItemID: raw.ItemID, - OutputIdx: &raw.OutputIndex, - ContentIdx: &raw.ContentIndex, + OutputIdx: raw.OutputIndex, + ContentIdx: raw.ContentIndex, ResponseID: raw.ResponseID, } if raw.Delta != "" { @@ -175,19 +573,19 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes // ToProviderRealtimeEvent converts a unified Bifrost Realtime event back to OpenAI's native JSON. 
func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) { - if bifrostEvent.RawData != nil { - return bifrostEvent.RawData, nil - } - out := map[string]interface{}{ "type": string(bifrostEvent.Type), } if bifrostEvent.EventID != "" { out["event_id"] = bifrostEvent.EventID } + mergeRealtimeExtraParams(out, bifrostEvent.ExtraParams) if bifrostEvent.Session != nil { sess := map[string]interface{}{} + if bifrostEvent.Session.ID != "" && bifrostEvent.Type != schemas.RTEventSessionUpdate { + sess["id"] = bifrostEvent.Session.ID + } if bifrostEvent.Session.Model != "" { sess["model"] = bifrostEvent.Session.Model } @@ -218,6 +616,7 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Session.Tools != nil { sess["tools"] = bifrostEvent.Session.Tools } + mergeRealtimeSessionExtraParams(sess, bifrostEvent.Session.ExtraParams, bifrostEvent.Type) out["session"] = sess } @@ -231,6 +630,9 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Item.Role != "" { item["role"] = bifrostEvent.Item.Role } + if bifrostEvent.Item.Status != "" { + item["status"] = bifrostEvent.Item.Status + } if bifrostEvent.Item.Content != nil { item["content"] = bifrostEvent.Item.Content } @@ -246,9 +648,28 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Item.Output != "" { item["output"] = bifrostEvent.Item.Output } + mergeRealtimeExtraParams(item, bifrostEvent.Item.ExtraParams) out["item"] = item } + if bifrostEvent.Error != nil { + rtErr := map[string]interface{}{} + if bifrostEvent.Error.Type != "" { + rtErr["type"] = bifrostEvent.Error.Type + } + if bifrostEvent.Error.Code != "" { + rtErr["code"] = bifrostEvent.Error.Code + } + if bifrostEvent.Error.Message != "" { + rtErr["message"] = bifrostEvent.Error.Message + } + if bifrostEvent.Error.Param != "" { + rtErr["param"] = 
bifrostEvent.Error.Param + } + mergeRealtimeExtraParams(rtErr, bifrostEvent.Error.ExtraParams) + out["error"] = rtErr + } + if bifrostEvent.Delta != nil { if bifrostEvent.Delta.Text != "" { out["delta"] = bifrostEvent.Delta.Text @@ -259,16 +680,16 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Delta.Transcript != "" { out["transcript"] = bifrostEvent.Delta.Transcript } - if bifrostEvent.Delta.ItemID != "" { + if bifrostEvent.Delta.ItemID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "item_id") { out["item_id"] = bifrostEvent.Delta.ItemID } - if bifrostEvent.Delta.OutputIdx != nil { + if bifrostEvent.Delta.OutputIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "output_index") { out["output_index"] = *bifrostEvent.Delta.OutputIdx } - if bifrostEvent.Delta.ContentIdx != nil { + if bifrostEvent.Delta.ContentIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "content_index") { out["content_index"] = *bifrostEvent.Delta.ContentIdx } - if bifrostEvent.Delta.ResponseID != "" { + if bifrostEvent.Delta.ResponseID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "response_id") { out["response_id"] = bifrostEvent.Delta.ResponseID } } @@ -276,11 +697,269 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi return providerUtils.MarshalSorted(out) } +func mergeRealtimeSessionExtraParams(out map[string]interface{}, params map[string]json.RawMessage, eventType schemas.RealtimeEventType) { + filtered := params + if eventType == schemas.RTEventSessionUpdate && len(params) > 0 { + filtered = make(map[string]json.RawMessage, len(params)) + for key, value := range params { + switch key { + case "id", "object", "expires_at", "client_secret": + continue + default: + filtered[key] = value + } + } + } + mergeRealtimeExtraParams(out, filtered) +} + +func (provider *OpenAIProvider) ExtractRealtimeTurnUsage(terminalEventRaw []byte) *schemas.BifrostLLMUsage { + if 
len(terminalEventRaw) == 0 { + return nil + } + + var parsed openAIRealtimeResponseDoneEnvelope + if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil || parsed.Response.Usage == nil { + return nil + } + + usage := &schemas.BifrostLLMUsage{ + PromptTokens: parsed.Response.Usage.InputTokens, + CompletionTokens: parsed.Response.Usage.OutputTokens, + TotalTokens: parsed.Response.Usage.TotalTokens, + } + + if parsed.Response.Usage.InputTokenDetails != nil { + usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens, + ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens, + CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens, + } + } + + if parsed.Response.Usage.OutputTokenDetails != nil { + usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens, + ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens, + ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens, + CitationTokens: parsed.Response.Usage.OutputTokenDetails.CitationTokens, + NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries, + AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens, + RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens, + } + } + + return usage +} + +func (provider *OpenAIProvider) ExtractRealtimeTurnOutput(terminalEventRaw []byte) *schemas.ChatMessage { + if len(terminalEventRaw) == 0 { + return nil + } + + var parsed openAIRealtimeResponseDoneEnvelope + if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil { + return nil + } + + content := extractOpenAIRealtimeResponseDoneAssistantText(parsed.Response.Output) + toolCalls := 
extractOpenAIRealtimeResponseDoneToolCalls(parsed.Response.Output) + if content == "" && len(toolCalls) == 0 { + return nil + } + + message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant} + if content != "" { + message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(content)} + } + if len(toolCalls) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ToolCalls: toolCalls} + } + + return message +} + +type openAIRealtimeResponseDoneEnvelope struct { + Response struct { + Output []openAIRealtimeResponseDoneOutput `json:"output"` + Usage *openAIRealtimeResponseDoneUsage `json:"usage"` + } `json:"response"` +} + +type openAIRealtimeResponseDoneOutput struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + CallID string `json:"call_id"` + Arguments string `json:"arguments"` + Content []openAIRealtimeResponseDoneBlock `json:"content"` +} + +type openAIRealtimeResponseDoneBlock struct { + Text string `json:"text"` + Transcript string `json:"transcript"` + Refusal string `json:"refusal"` +} + +type openAIRealtimeResponseDoneUsage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails *openAIRealtimeResponseDoneInputTokenUsage `json:"input_token_details"` + OutputTokenDetails *openAIRealtimeResponseDoneOutputTokenUsage `json:"output_token_details"` +} + +type openAIRealtimeResponseDoneInputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` + CachedTokens int `json:"cached_tokens"` +} + +type openAIRealtimeResponseDoneOutputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + ImageTokens *int `json:"image_tokens"` + CitationTokens *int `json:"citation_tokens"` + NumSearchQueries *int `json:"num_search_queries"` + 
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + +func extractOpenAIRealtimeResponseDoneAssistantText(outputs []openAIRealtimeResponseDoneOutput) string { + var sb strings.Builder + for _, output := range outputs { + if output.Type != "message" { + continue + } + for _, block := range output.Content { + switch { + case strings.TrimSpace(block.Text) != "": + sb.WriteString(block.Text) + case strings.TrimSpace(block.Transcript) != "": + sb.WriteString(block.Transcript) + case strings.TrimSpace(block.Refusal) != "": + sb.WriteString(block.Refusal) + } + } + } + return strings.TrimSpace(sb.String()) +} + +func extractOpenAIRealtimeResponseDoneToolCalls(outputs []openAIRealtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall { + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + for _, output := range outputs { + if output.Type != "function_call" { + continue + } + + name := strings.TrimSpace(output.Name) + if name == "" { + continue + } + + toolType := "function" + id := strings.TrimSpace(output.CallID) + if id == "" { + id = strings.TrimSpace(output.ID) + } + + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr(name), + Arguments: output.Arguments, + }, + } + if id != "" { + toolCall.ID = schemas.Ptr(id) + } + + toolCalls = append(toolCalls, toolCall) + } + return toolCalls +} + +func setRealtimeExtraParam(event *schemas.BifrostRealtimeEvent, key string, value any) { + if event == nil || key == "" || value == nil { + return + } + + switch v := value.(type) { + case string: + if v == "" { + return + } + case *int: + if v == nil { + return + } + case json.RawMessage: + if len(v) == 0 || string(v) == "null" { + return + } + } + + raw, err := json.Marshal(value) + if err != nil { + return + } + if event.ExtraParams == nil { + 
event.ExtraParams = make(map[string]json.RawMessage) + } + event.ExtraParams[key] = raw +} + +func mergeRealtimeExtraParams(out map[string]interface{}, params map[string]json.RawMessage) { + for key, raw := range params { + if len(raw) == 0 { + continue + } + var value any + if err := json.Unmarshal(raw, &value); err != nil { + continue + } + out[key] = value + } +} + +func hasRealtimeExtraParam(params map[string]json.RawMessage, key string) bool { + if params == nil { + return false + } + raw, ok := params[key] + return ok && len(raw) > 0 +} + +func extractRealtimeNestedParams(raw json.RawMessage, knownKeys ...string) map[string]json.RawMessage { + if len(raw) == 0 { + return nil + } + root := map[string]json.RawMessage{} + if err := json.Unmarshal(raw, &root); err != nil { + return nil + } + for _, key := range knownKeys { + delete(root, key) + } + if len(root) == 0 { + return nil + } + return root +} + func isRealtimeDeltaEvent(eventType string) bool { switch eventType { case "response.text.delta", + "response.output_text.delta", "response.audio.delta", + "response.output_audio.delta", "response.audio_transcript.delta", + "response.output_audio_transcript.delta", "conversation.item.input_audio_transcription.delta": return true } diff --git a/core/providers/openai/realtime_test.go b/core/providers/openai/realtime_test.go new file mode 100644 index 0000000000..6b7f76f98f --- /dev/null +++ b/core/providers/openai/realtime_test.go @@ -0,0 +1,561 @@ +package openai + +import ( + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestNormalizeRealtimeClientSecretRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != 
"gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["model"]; ok { + t.Fatal("top-level model should be removed after normalization") + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeClientSecretRequestUsesDefaultProvider(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeSessionsRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := 
normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointSessions, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["session"]; ok { + t.Fatal("legacy sessions endpoint should not forward nested session object") + } + if payload["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("model = %v, want %q", payload["model"], "gpt-4o-realtime-preview") + } + if payload["voice"] != "alloy" { + t.Fatalf("voice = %v, want %q", payload["voice"], "alloy") + } +} + +func TestToProviderRealtimeEventSerializesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + contentIndex, err := json.Marshal(0) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + audioEndMS, err := json.Marshal(640) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.truncate"), + ExtraParams: map[string]json.RawMessage{ + "item_id": json.RawMessage(`"item_123"`), + "content_index": contentIndex, + "audio_end_ms": audioEndMS, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "conversation.item.truncate" { + t.Fatalf("type = %v, want %q", payload["type"], "conversation.item.truncate") + } + if payload["item_id"] != "item_123" { + t.Fatalf("item_id = %v, want 
%q", payload["item_id"], "item_123") + } + if payload["content_index"] != float64(0) { + t.Fatalf("content_index = %v, want 0", payload["content_index"]) + } + if payload["audio_end_ms"] != float64(640) { + t.Fatalf("audio_end_ms = %v, want 640", payload["audio_end_ms"]) + } +} + +func TestToBifrostRealtimeEventParsesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + var itemID string + if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil { + t.Fatalf("json.Unmarshal(item_id) error = %v", err) + } + if itemID != "item_123" { + t.Fatalf("item_id = %q, want %q", itemID, "item_123") + } + var contentIndex int + if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil { + t.Fatalf("json.Unmarshal(content_index) error = %v", err) + } + if contentIndex != 0 { + t.Fatalf("content_index = %d, want 0", contentIndex) + } + var audioEndMS int + if err := json.Unmarshal(event.ExtraParams["audio_end_ms"], &audioEndMS); err != nil { + t.Fatalf("json.Unmarshal(audio_end_ms) error = %v", err) + } + if audioEndMS != 640 { + t.Fatalf("audio_end_ms = %d, want 640", audioEndMS) + } +} + +func TestToBifrostRealtimeEventParsesCompletedInputAudioTranscript(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.input_audio_transcription.completed","event_id":"evt_123","item_id":"item_123","content_index":0,"transcript":"Who are you?"}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var transcript string + if err := json.Unmarshal(event.ExtraParams["transcript"], &transcript); err != nil { + t.Fatalf("json.Unmarshal(transcript) 
error = %v", err) + } + if transcript != "Who are you?" { + t.Fatalf("transcript = %q, want %q", transcript, "Who are you?") + } +} + +func TestToBifrostRealtimeEventParsesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.output_text.delta", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "response_id":"resp_123", + "delta":"hello" + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + if event.Delta == nil || event.Delta.Text != "hello" { + t.Fatalf("Delta = %+v, want text=hello", event.Delta) + } +} + +func TestShouldStartRealtimeTurn(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + tests := []struct { + name string + event *schemas.BifrostRealtimeEvent + want bool + }{ + { + name: "response create starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseCreate}, + want: true, + }, + { + name: "audio buffer committed starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventInputAudioBufferCommitted}, + want: true, + }, + { + name: "response done does not start turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseDone}, + want: false, + }, + { + name: "nil event does not start turn", + event: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := provider.ShouldStartRealtimeTurn(tt.event); got != tt.want { + t.Fatalf("ShouldStartRealtimeTurn() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToProviderRealtimeEventSerializesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + outputIndex := 0 + contentIndex := 0 + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("response.output_text.delta"), + Delta: 
&schemas.RealtimeDelta{ + Text: "hello", + ItemID: "item_123", + OutputIdx: &outputIndex, + ContentIdx: &contentIndex, + ResponseID: "resp_123", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "response.output_text.delta" { + t.Fatalf("type = %v, want response.output_text.delta", payload["type"]) + } + if payload["delta"] != "hello" { + t.Fatalf("delta = %v, want hello", payload["delta"]) + } +} + +func TestToProviderRealtimeEventSerializesSessionID(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionCreated, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + session, ok := payload["session"].(map[string]any) + if !ok { + t.Fatalf("session = %T, want object", payload["session"]) + } + if session["id"] != "sess_123" { + t.Fatalf("session.id = %v, want sess_123", session["id"]) + } +} + +func TestToProviderRealtimeEventSerializesMessageItemStatus(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + content := json.RawMessage(`[{"type":"input_audio","transcript":"hello"}]`) + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.retrieved"), + Item: &schemas.RealtimeItem{ + ID: "item_123", + Type: "message", + Role: "user", + Status: "completed", + Content: content, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := 
json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + item, ok := payload["item"].(map[string]any) + if !ok { + t.Fatalf("item = %T, want object", payload["item"]) + } + if item["status"] != "completed" { + t.Fatalf("item.status = %v, want completed", item["status"]) + } +} + +func TestToBifrostRealtimeEventPreservesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.done", + "event_id":"evt_123", + "response":{ + "id":"resp_123", + "output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}] + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var response map[string]any + if err := json.Unmarshal(event.ExtraParams["response"], &response); err != nil { + t.Fatalf("json.Unmarshal(response) error = %v", err) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseDone, + ExtraParams: map[string]json.RawMessage{ + "response": json.RawMessage(`{"id":"resp_123","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + response, ok := payload["response"].(map[string]any) + if !ok { + t.Fatalf("response = %T, want object", payload["response"]) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func 
TestToBifrostRealtimeEventPreservesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.content_part.added", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "part":{ + "type":"text", + "text":"hello" + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var part map[string]any + if err := json.Unmarshal(event.ExtraParams["part"], &part); err != nil { + t.Fatalf("json.Unmarshal(part) error = %v", err) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseContentPartAdded, + ExtraParams: map[string]json.RawMessage{ + "part": json.RawMessage(`{"type":"text","text":"hello"}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + part, ok := payload["part"].(map[string]any) + if !ok { + t.Fatalf("part = %T, want object", payload["part"]) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestParseRealtimeEventPreservesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + event, err := schemas.ParseRealtimeEvent([]byte(`{ + "type":"session.update", + "session":{ + "type":"realtime", + "model":"gpt-4o-realtime-preview", + "output_modalities":["text"] + } + }`)) + if err != nil { + t.Fatalf("ParseRealtimeEvent() error = %v", err) + } + if event.Session == nil { + t.Fatal("expected session to be parsed") + } + var outputModalities []string + if err := 
json.Unmarshal(event.Session.ExtraParams["output_modalities"], &outputModalities); err != nil { + t.Fatalf("json.Unmarshal(output_modalities) error = %v", err) + } + if len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("output_modalities = %v, want [text]", outputModalities) + } +} + +func TestToProviderRealtimeEventSerializesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + Model: "gpt-4o-realtime-preview", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "output_modalities": json.RawMessage(`["text"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Type string `json:"type"` + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload.Type != "session.update" { + t.Fatalf("type = %q, want %q", payload.Type, "session.update") + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + outputModalities, ok := payload.Session["output_modalities"].([]any) + if !ok || len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("session.output_modalities = %v, want [text]", payload.Session["output_modalities"]) + } +} + +func TestToProviderRealtimeEventOmitsReadOnlySessionFieldsOnSessionUpdate(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "object": 
json.RawMessage(`"realtime.session"`), + "expires_at": json.RawMessage(`1774614381`), + "client_secret": json.RawMessage(`{"value":"secret"}`), + "modalities": json.RawMessage(`["text","audio"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + for _, key := range []string{"id", "object", "expires_at", "client_secret"} { + if _, ok := payload.Session[key]; ok { + t.Fatalf("session.%s unexpectedly present in session.update payload", key) + } + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + if payload.Session["model"] != "gpt-realtime" { + t.Fatalf("session.model = %v, want gpt-realtime", payload.Session["model"]) + } +} diff --git a/core/providers/openai/text_test.go b/core/providers/openai/text_test.go index b2dd53ee35..71c2f195a0 100644 --- a/core/providers/openai/text_test.go +++ b/core/providers/openai/text_test.go @@ -51,7 +51,6 @@ func TestToOpenAITextCompletionRequest_FireworksUsesCacheIsolation(t *testing.T) func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAITextCompletionRequest(bifrostReq), nil }, - schemas.Fireworks, ) if bifrostErr != nil { t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message) diff --git a/core/providers/openai/transcription.go b/core/providers/openai/transcription.go index 8ab2305b05..8c2bf112a1 100644 --- a/core/providers/openai/transcription.go +++ b/core/providers/openai/transcription.go @@ -54,63 +54,63 @@ func ParseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiR } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return utils.NewBifrostOperationError("failed to create form file", err, providerName) + return utils.NewBifrostOperationError("failed 
to create form file", err) } if _, err := fileWriter.Write(openaiReq.File); err != nil { - return utils.NewBifrostOperationError("failed to write file data", err, providerName) + return utils.NewBifrostOperationError("failed to write file data", err) } // Add model field if err := writer.WriteField("model", openaiReq.Model); err != nil { - return utils.NewBifrostOperationError("failed to write model field", err, providerName) + return utils.NewBifrostOperationError("failed to write model field", err) } // Add optional fields if openaiReq.Language != nil { if err := writer.WriteField("language", *openaiReq.Language); err != nil { - return utils.NewBifrostOperationError("failed to write language field", err, providerName) + return utils.NewBifrostOperationError("failed to write language field", err) } } if openaiReq.Prompt != nil { if err := writer.WriteField("prompt", *openaiReq.Prompt); err != nil { - return utils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return utils.NewBifrostOperationError("failed to write prompt field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return utils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return utils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Temperature != nil { if err := writer.WriteField("temperature", fmt.Sprintf("%g", *openaiReq.Temperature)); err != nil { - return utils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return utils.NewBifrostOperationError("failed to write temperature field", err) } } for _, granularity := range openaiReq.TimestampGranularities { if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil { - return utils.NewBifrostOperationError("failed to write timestamp_granularities field", err, providerName) + return 
utils.NewBifrostOperationError("failed to write timestamp_granularities field", err) } } for _, include := range openaiReq.Include { if err := writer.WriteField("include[]", include); err != nil { - return utils.NewBifrostOperationError("failed to write include field", err, providerName) + return utils.NewBifrostOperationError("failed to write include field", err) } } if openaiReq.Stream != nil && *openaiReq.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return utils.NewBifrostOperationError("failed to write stream field", err, providerName) + return utils.NewBifrostOperationError("failed to write stream field", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return utils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return utils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index a559fbac06..e2eab5245a 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -83,7 +83,7 @@ type OpenAIChatRequest struct { // PromptCacheIsolationKey is the Fireworks chat-completions field for cache isolation. PromptCacheIsolationKey *string `json:"prompt_cache_isolation_key,omitempty"` - //NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. + // NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. // This Field is populated only for such providers and is NOT to be used externally. 
MaxTokens *int `json:"max_tokens,omitempty"` diff --git a/core/providers/openai/videos.go b/core/providers/openai/videos.go index aa4052e029..512306b7c7 100644 --- a/core/providers/openai/videos.go +++ b/core/providers/openai/videos.go @@ -132,30 +132,30 @@ func (req *OpenAIVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *sc func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIVideoGenerationRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add prompt field (required) if openaiReq.Prompt == "" { - return providerUtils.NewBifrostOperationError("prompt is required", nil, providerName) + return providerUtils.NewBifrostOperationError("prompt is required", nil) } if err := writer.WriteField("prompt", openaiReq.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } // Add optional model field if openaiReq.Model != "" { if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } } // Add optional seconds field if openaiReq.Seconds != nil { if err := writer.WriteField("seconds", *openaiReq.Seconds); err != nil { - return providerUtils.NewBifrostOperationError("failed to write seconds field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write seconds field", err) } } // Add optional size field if openaiReq.Size != "" { if err := writer.WriteField("size", openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } @@ -196,16 +196,16 @@ func 
parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, opena "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err) } if _, err := part.Write(openaiReq.InputReference); err != nil { - return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 799ea4b5ab..63ae8f48e4 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -4,7 +4,6 @@ package openrouter import ( "fmt" "net/http" - "slices" "strings" "time" @@ -95,12 +94,12 @@ func (provider *OpenRouterProvider) validateKey(ctx *schemas.BifrostContext, key // Check for auth errors (401, 403) statusCode := resp.StatusCode() if statusCode == fasthttp.StatusUnauthorized || statusCode == fasthttp.StatusForbidden { - return openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return openai.ParseOpenAIError(resp) } // Any 4xx/5xx error indicates the key might be invalid if statusCode >= 400 { - return openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return openai.ParseOpenAIError(resp) } return nil @@ -109,8 +108,6 @@ func (provider *OpenRouterProvider) validateKey(ctx *schemas.BifrostContext, key // listModelsByKey performs a list models request for a single 
key. // Returns the response and latency, or an error if the request fails. func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Validate the key first using /v1/auth/key (only during provider add/update). // OpenRouter's /v1/models doesn't require auth, so we need this extra check. shouldValidate := false @@ -158,7 +155,7 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, // Continue with empty response; allowed models will be backfilled below. modelsFetched = false } else { - bifrostErr := openai.ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := openai.ParseOpenAIError(resp) return nil, bifrostErr } } @@ -185,45 +182,62 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, } } - // Filter by key.Models - allowedModels := key.Models - blacklistedModels := key.BlacklistedModels + // OpenRouter model IDs in the API response do NOT include the "openrouter/" prefix + // (e.g. the API returns "openai/gpt-4", not "openrouter/openai/gpt-4"). + // Users may supply allowedModels / aliases with or without the prefix, so we + // normalize both by stripping it before feeding into the shared pipeline. 
providerPrefix := string(schemas.OpenRouter) + "/" + stripPrefix := func(s string) string { + if strings.HasPrefix(strings.ToLower(s), strings.ToLower(providerPrefix)) { + return s[len(providerPrefix):] + } + return s + } + + normalizedAllowed := make(schemas.WhiteList, 0, len(key.Models)) + for _, m := range key.Models { + normalizedAllowed = append(normalizedAllowed, stripPrefix(m)) + } + normalizedBlacklist := make(schemas.BlackList, 0, len(key.BlacklistedModels)) + for _, m := range key.BlacklistedModels { + normalizedBlacklist = append(normalizedBlacklist, stripPrefix(m)) + } + normalizedAliases := make(map[string]string, len(key.Aliases)) + for k, v := range key.Aliases { + normalizedAliases[stripPrefix(k)] = stripPrefix(v) + } + + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: normalizedAllowed, + BlacklistedModels: normalizedBlacklist, + Aliases: normalizedAliases, + Unfiltered: request.Unfiltered, + ProviderKey: schemas.OpenRouter, + MatchFns: providerUtils.DefaultMatchFns(), + } - if !request.Unfiltered && len(allowedModels) > 0 { + if pipeline.ShouldEarlyExit() { + openrouterResponse.Data = make([]schemas.Model, 0) + } else { + included := make(map[string]bool) filteredData := make([]schemas.Model, 0, len(openrouterResponse.Data)) - includedModels := make(map[string]bool) for i := range openrouterResponse.Data { + // rawID has no "openrouter/" prefix — e.g. 
"openai/gpt-4" rawID := openrouterResponse.Data[i].ID - if !(slices.Contains(allowedModels, rawID) || slices.Contains(allowedModels, providerPrefix+rawID)) { - continue - } - if slices.Contains(blacklistedModels, rawID) || slices.Contains(blacklistedModels, providerPrefix+rawID) { - continue - } - openrouterResponse.Data[i].ID = providerPrefix + rawID - filteredData = append(filteredData, openrouterResponse.Data[i]) - includedModels[rawID] = true - } - // Backfill allowed models not in the API response - for _, allowedModel := range allowedModels { - rawID := strings.TrimPrefix(allowedModel, providerPrefix) - if slices.Contains(blacklistedModels, rawID) || slices.Contains(blacklistedModels, providerPrefix+rawID) { - continue - } - if !includedModels[rawID] { - filteredData = append(filteredData, schemas.Model{ - ID: providerPrefix + rawID, - Name: schemas.Ptr(rawID), - }) - includedModels[rawID] = true // avoid duplicate backfill + for _, result := range pipeline.FilterModel(rawID) { + entry := openrouterResponse.Data[i] + entry.ID = providerPrefix + result.ResolvedID + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) + } else { + entry.Alias = nil + } + filteredData = append(filteredData, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + filteredData = append(filteredData, pipeline.BackfillModels(included)...) 
openrouterResponse.Data = filteredData - } else { - for i := range openrouterResponse.Data { - openrouterResponse.Data[i].ID = providerPrefix + openrouterResponse.Data[i].ID - } } openrouterResponse.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/openrouter/openrouter_test.go b/core/providers/openrouter/openrouter_test.go index d908d4d950..5a6a38ce3f 100644 --- a/core/providers/openrouter/openrouter_test.go +++ b/core/providers/openrouter/openrouter_test.go @@ -31,26 +31,26 @@ func TestOpenRouter(t *testing.T) { EmbeddingModel: "qwen/qwen3-embedding-4b", ReasoningModel: "openai/gpt-oss-120b", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: false, // OpenRouter's responses API is in Beta + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: false, // OpenRouter's responses API is in Beta MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, // OpenRouter's responses API is in Beta - ImageBase64: false, // OpenRouter's responses API is in Beta - MultipleImages: false, // OpenRouter's responses API is in Beta - FileBase64: true, - FileURL: true, - CompleteEnd2End: false, // OpenRouter's responses API is in Beta - Reasoning: true, - ListModels: true, - StructuredOutputs: true, // Structured outputs with nullable enum support - Embedding: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, // OpenRouter's responses API is in Beta + ImageBase64: false, // OpenRouter's responses API is in Beta + MultipleImages: false, // OpenRouter's responses API is in Beta + FileBase64: true, + FileURL: true, + CompleteEnd2End: false, // OpenRouter's responses API is in Beta + Reasoning: true, + ListModels: true, + StructuredOutputs: 
true, // Structured outputs with nullable enum support + Embedding: true, }, } diff --git a/core/providers/parasail/parasail.go b/core/providers/parasail/parasail.go index 0af0e4bba4..e03d891d38 100644 --- a/core/providers/parasail/parasail.go +++ b/core/providers/parasail/parasail.go @@ -145,9 +145,6 @@ func (provider *ParasailProvider) Responses(ctx *schemas.BifrostContext, key sch } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/perplexity/chat.go b/core/providers/perplexity/chat.go index f2bb5cb1b4..403e0d2dd9 100644 --- a/core/providers/perplexity/chat.go +++ b/core/providers/perplexity/chat.go @@ -284,8 +284,6 @@ func (response *PerplexityChatResponse) ToBifrostChatResponse(model string) *sch Object: response.Object, Created: response.Created, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Perplexity, }, SearchResults: response.SearchResults, Videos: response.Videos, diff --git a/core/providers/perplexity/perplexity.go b/core/providers/perplexity/perplexity.go index 8ceec1fd4c..f0b21ec21d 100644 --- a/core/providers/perplexity/perplexity.go +++ b/core/providers/perplexity/perplexity.go @@ -101,12 +101,12 @@ func (provider *PerplexityProvider) completeRequest(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))) - return nil, latency, providerResponseHeaders, openai.ParseOpenAIError(resp, schemas.ChatCompletionRequest, provider.GetProviderKey(), model) + return nil, latency, providerResponseHeaders, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return 
nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -141,8 +141,7 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToPerplexityChatCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -161,9 +160,6 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, bifrostResponse := response.ToBifrostChatResponse(request.Model) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -223,9 +219,6 @@ func (provider *PerplexityProvider) Responses(ctx *schemas.BifrostContext, key s } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/perplexity/types.go b/core/providers/perplexity/types.go index feef9e0ccb..d5ad5c65f6 100644 --- a/core/providers/perplexity/types.go +++ b/core/providers/perplexity/types.go @@ -4,45 +4,45 @@ import "github.com/maximhq/bifrost/core/schemas" // PerplexityChatRequest represents a Perplexity chat completion request type PerplexityChatRequest struct { - Model string `json:"model"` // Required: Model to use for chat completion - Messages 
[]schemas.ChatMessage `json:"messages"` // Required: Array of message objects - SearchMode *string `json:"search_mode"` // Required: Search mode - ReasoningEffort *string `json:"reasoning_effort"` // Required: Reasoning effort (low, medium, high) - MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate - Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature - TopP *float64 `json:"top_p,omitempty"` // Optional: Top-p sampling - LanguagePreference *string `json:"language_preference,omitempty"` // Optional: Language preference - SearchDomainFilter []string `json:"search_domain_filter,omitempty"` // Optional: Search domain filter - ReturnImages *bool `json:"return_images,omitempty"` // Optional: Return images - ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"` // Optional: Return related questions - SearchRecencyFilter *string `json:"search_recency_filter,omitempty"` // Optional: Search recency filter - SearchAfterDateFilter *string `json:"search_after_date_filter,omitempty"` // Optional: Search after date filter - SearchBeforeDateFilter *string `json:"search_before_date_filter,omitempty"` // Optional: Search before date filter - LastUpdatedAfterFilter *string `json:"last_updated_after_filter,omitempty"` // Optional: Last updated after filter - LastUpdatedBeforeFilter *string `json:"last_updated_before_filter,omitempty"` // Optional: Last updated before filter - TopK *int `json:"top_k,omitempty"` // Optional: Top-k sampling - Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty - ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response - DisableSearch *bool `json:"disable_search,omitempty"` // Optional: Disable search - EnableSearchClassifier *bool 
`json:"enable_search_classifier,omitempty"` // Optional: Enable search classifier - WebSearchOptions []WebSearchOption `json:"web_search_options,omitempty"` // Optional: Web search options - MediaResponse *MediaResponse `json:"media_response,omitempty"` // Optional: Media response - Tools []schemas.ChatTool `json:"tools,omitempty"` // Optional: Tools available for the model - ToolChoice *schemas.ChatToolChoice `json:"tool_choice,omitempty"` // Optional: Whether to call a tool - ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Optional: Enable parallel tool calls - Stop []string `json:"stop,omitempty"` // Optional: Stop sequences - LogProbs *bool `json:"logprobs,omitempty"` // Optional: Return log probabilities - TopLogProbs *int `json:"top_logprobs,omitempty"` // Optional: Number of top log probabilities - NumSearchResults *int `json:"num_search_results,omitempty"` // Optional: Number of search results - NumImages *int `json:"num_images,omitempty"` // Optional: Number of images - SearchLanguageFilter []string `json:"search_language_filter,omitempty"` // Optional: Search language filter - ImageFormatFilter []string `json:"image_format_filter,omitempty"` // Optional: Image format filter - ImageDomainFilter []string `json:"image_domain_filter,omitempty"` // Optional: Image domain filter - SafeSearch *bool `json:"safe_search,omitempty"` // Optional: Enable safe search - StreamMode *string `json:"stream_mode,omitempty"` // Optional: Stream mode - ExtraParams map[string]interface{} `json:"-"` + Model string `json:"model"` // Required: Model to use for chat completion + Messages []schemas.ChatMessage `json:"messages"` // Required: Array of message objects + SearchMode *string `json:"search_mode"` // Required: Search mode + ReasoningEffort *string `json:"reasoning_effort"` // Required: Reasoning effort (low, medium, high) + MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate + Temperature *float64 
`json:"temperature,omitempty"` // Optional: Sampling temperature + TopP *float64 `json:"top_p,omitempty"` // Optional: Top-p sampling + LanguagePreference *string `json:"language_preference,omitempty"` // Optional: Language preference + SearchDomainFilter []string `json:"search_domain_filter,omitempty"` // Optional: Search domain filter + ReturnImages *bool `json:"return_images,omitempty"` // Optional: Return images + ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"` // Optional: Return related questions + SearchRecencyFilter *string `json:"search_recency_filter,omitempty"` // Optional: Search recency filter + SearchAfterDateFilter *string `json:"search_after_date_filter,omitempty"` // Optional: Search after date filter + SearchBeforeDateFilter *string `json:"search_before_date_filter,omitempty"` // Optional: Search before date filter + LastUpdatedAfterFilter *string `json:"last_updated_after_filter,omitempty"` // Optional: Last updated after filter + LastUpdatedBeforeFilter *string `json:"last_updated_before_filter,omitempty"` // Optional: Last updated before filter + TopK *int `json:"top_k,omitempty"` // Optional: Top-k sampling + Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty + ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response + DisableSearch *bool `json:"disable_search,omitempty"` // Optional: Disable search + EnableSearchClassifier *bool `json:"enable_search_classifier,omitempty"` // Optional: Enable search classifier + WebSearchOptions []WebSearchOption `json:"web_search_options,omitempty"` // Optional: Web search options + MediaResponse *MediaResponse `json:"media_response,omitempty"` // Optional: Media response + Tools []schemas.ChatTool `json:"tools,omitempty"` // Optional: Tools 
available for the model + ToolChoice *schemas.ChatToolChoice `json:"tool_choice,omitempty"` // Optional: Whether to call a tool + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Optional: Enable parallel tool calls + Stop []string `json:"stop,omitempty"` // Optional: Stop sequences + LogProbs *bool `json:"logprobs,omitempty"` // Optional: Return log probabilities + TopLogProbs *int `json:"top_logprobs,omitempty"` // Optional: Number of top log probabilities + NumSearchResults *int `json:"num_search_results,omitempty"` // Optional: Number of search results + NumImages *int `json:"num_images,omitempty"` // Optional: Number of images + SearchLanguageFilter []string `json:"search_language_filter,omitempty"` // Optional: Search language filter + ImageFormatFilter []string `json:"image_format_filter,omitempty"` // Optional: Image format filter + ImageDomainFilter []string `json:"image_domain_filter,omitempty"` // Optional: Image domain filter + SafeSearch *bool `json:"safe_search,omitempty"` // Optional: Enable safe search + StreamMode *string `json:"stream_mode,omitempty"` // Optional: Stream mode + ExtraParams map[string]interface{} `json:"-"` } // GetExtraParams implements the RequestBodyWithExtraParams interface diff --git a/core/providers/replicate/errors.go b/core/providers/replicate/errors.go index e7fc2051d0..1575d9ca77 100644 --- a/core/providers/replicate/errors.go +++ b/core/providers/replicate/errors.go @@ -15,9 +15,6 @@ func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError { Error: &schemas.ErrorField{ Message: replicateErr.Detail, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } @@ -28,8 +25,5 @@ func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError { Error: &schemas.ErrorField{ Message: string(body), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } diff --git a/core/providers/replicate/files.go 
b/core/providers/replicate/files.go index cdd37a65c5..15ca0e13e8 100644 --- a/core/providers/replicate/files.go +++ b/core/providers/replicate/files.go @@ -30,8 +30,6 @@ func (r *ReplicateFileResponse) ToBifrostFileUploadResponse(providerName schemas Status: ToBifrostFileStatus(r), StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, Latency: latency.Milliseconds(), }, } @@ -67,8 +65,6 @@ func (r *ReplicateFileResponse) ToBifrostFileRetrieveResponse(providerName schem Status: ToBifrostFileStatus(r), StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, Latency: latency.Milliseconds(), }, } diff --git a/core/providers/replicate/images.go b/core/providers/replicate/images.go index 4fa7d0dd81..0327b8b5fa 100644 --- a/core/providers/replicate/images.go +++ b/core/providers/replicate/images.go @@ -1,7 +1,6 @@ package replicate import ( - "strconv" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" @@ -28,49 +27,6 @@ var modelInputImageFieldMap = map[string]string{ "black-forest-labs/flux-krea-dev": "image", } -// convertSizeToReplicateFormat converts standard size format (e.g., "1024x1024") to Replicate format. -// Returns (aspectRatio, imageSize) where imageSize is "1k", "2k", "4k" and aspectRatio is one of: -// "1:1", "3:4", "4:3", "9:16", or "16:9". Returns empty strings if unparseable or ratio unrecognised. 
-func convertSizeToReplicateFormat(size string) (aspectRatio, imageSize string) { - parts := strings.Split(size, "x") - if len(parts) != 2 { - return "", "" - } - - width, err1 := strconv.Atoi(parts[0]) - height, err2 := strconv.Atoi(parts[1]) - if err1 != nil || err2 != nil { - return "", "" - } - - if width <= 0 || height <= 0 { - return "", "" - } - - if width <= 1024 && height <= 1024 { - imageSize = "1K" - } else if width <= 2048 && height <= 2048 { - imageSize = "2K" - } else if width <= 4096 && height <= 4096 { - imageSize = "4K" - } - - ratio := float64(width) / float64(height) - if ratio >= 0.99 && ratio <= 1.01 { - aspectRatio = "1:1" - } else if ratio >= 0.74 && ratio <= 0.76 { - aspectRatio = "3:4" - } else if ratio >= 1.32 && ratio <= 1.34 { - aspectRatio = "4:3" - } else if ratio >= 0.56 && ratio <= 0.57 { - aspectRatio = "9:16" - } else if ratio >= 1.77 && ratio <= 1.78 { - aspectRatio = "16:9" - } - - return aspectRatio, imageSize -} - // ToReplicateImageGenerationInput converts a Bifrost image generation request to Replicate prediction input func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationRequest) *ReplicatePredictionRequest { if bifrostReq == nil || bifrostReq.Input == nil { @@ -85,29 +41,6 @@ func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationR if bifrostReq.Params != nil { params := bifrostReq.Params - // Map InputImages to the appropriate field based on model - if len(params.InputImages) > 0 { - fieldName := getInputImageFieldName(bifrostReq.Model) - - switch fieldName { - case "image_prompt": - // For flux-1.1-pro variants: use first image as image_prompt - input.ImagePrompt = ¶ms.InputImages[0] - - case "input_image": - // For flux-kontext variants: add to ExtraParams as input_image - input.InputImage = ¶ms.InputImages[0] - - case "image": - // For flux-dev variants: use first image as image field - input.Image = ¶ms.InputImages[0] - - case "input_images": - // For all other models: 
use input_images array - input.InputImages = params.InputImages - } - } - if bifrostReq.Params.N != nil { input.NumberOfImages = bifrostReq.Params.N } @@ -117,7 +50,7 @@ func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationR } if params.Size != nil { - aspectRatio, imageSize := convertSizeToReplicateFormat(*params.Size) + aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size) _, hasExplicitResolution := params.ExtraParams["resolution"] if params.AspectRatio == nil && aspectRatio != "" { input.AspectRatio = &aspectRatio @@ -191,9 +124,6 @@ func ToBifrostImageGenerationResponse( Error: &schemas.ErrorField{ Message: "prediction response is nil", }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } @@ -294,7 +224,7 @@ func ToReplicateImageEditInput(bifrostReq *schemas.BifrostImageEditRequest) *Rep input.Image = &images[0] case "input_images": - // For all other models: use input_images array + // For all other models: use input_images array (preserves multi-image support) input.InputImages = images } } @@ -309,7 +239,7 @@ func ToReplicateImageEditInput(bifrostReq *schemas.BifrostImageEditRequest) *Rep } if params.Size != nil { - aspectRatio, imageSize := convertSizeToReplicateFormat(*params.Size) + aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size) _, hasExplicitAspectRatio := params.ExtraParams["aspect_ratio"] _, hasExplicitResolution := params.ExtraParams["resolution"] if aspectRatio != "" && !hasExplicitAspectRatio { diff --git a/core/providers/replicate/models.go b/core/providers/replicate/models.go index 3989628db1..6c0c14dbf7 100644 --- a/core/providers/replicate/models.go +++ b/core/providers/replicate/models.go @@ -1,61 +1,66 @@ package replicate import ( - "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// ToBifrostListModelsResponse 
converts Replicate models and deployments to a Bifrost list models response +// ToBifrostListModelsResponse converts Replicate deployments to a Bifrost list models response. +// Replicate model IDs are composite: "{owner}/{name}" (e.g. "stability-ai/stable-diffusion"). func ToBifrostListModelsResponse( deploymentsResponse *ReplicateDeploymentListResponse, providerKey schemas.ModelProvider, - allowedModels []string, - blacklistedModels []string, + allowedModels schemas.WhiteList, + blacklistedModels schemas.BlackList, + aliases map[string]string, unfiltered bool, ) *schemas.BifrostListModelsResponse { bifrostResponse := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } - includedModels := make(map[string]bool) - // Add deployments from /v1/deployments endpoint + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + + included := make(map[string]bool) + if deploymentsResponse != nil { for _, deployment := range deploymentsResponse.Results { + // Replicate model IDs are composite owner/name deploymentID := deployment.Owner + "/" + deployment.Name - modelName := schemas.Ptr(deployment.Name) var created *int64 - - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, deploymentID) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, deploymentID) { - continue - } - - // Extract information from current release if available - if deployment.CurrentRelease != nil { - // Parse created timestamp - if deployment.CurrentRelease.CreatedAt != "" { - createdTimestamp := ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt) - if createdTimestamp > 0 { - created = schemas.Ptr(createdTimestamp) - } + if deployment.CurrentRelease != nil && 
deployment.CurrentRelease.CreatedAt != "" { + createdTimestamp := ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt) + if createdTimestamp > 0 { + created = schemas.Ptr(createdTimestamp) } } - bifrostModel := schemas.Model{ - ID: string(providerKey) + "/" + deploymentID, - Name: modelName, - Deployment: modelName, - OwnedBy: schemas.Ptr(deployment.Owner), - Created: created, + for _, result := range pipeline.FilterModel(deploymentID) { + bifrostModel := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(deployment.Name), + OwnedBy: schemas.Ptr(deployment.Owner), + Created: created, + } + if result.AliasValue != "" { + bifrostModel.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) + included[strings.ToLower(result.ResolvedID)] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) - includedModels[deploymentID] = true } if deploymentsResponse.Next != nil { @@ -63,58 +68,8 @@ func ToBifrostListModelsResponse( } } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue - } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - } - } - } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
return bifrostResponse } - -// ToReplicateListModelsResponse converts a Bifrost list models response to a Replicate list models response -// This is mainly used for testing and compatibility -func ToReplicateListModelsResponse(response *schemas.BifrostListModelsResponse) *ReplicateModelListResponse { - if response == nil { - return nil - } - - replicateResponse := &ReplicateModelListResponse{ - Results: make([]ReplicateModelResponse, 0, len(response.Data)), - } - - for _, model := range response.Data { - modelID := strings.TrimPrefix(model.ID, string(schemas.Replicate)+"/") - replicateModel := ReplicateModelResponse{ - URL: "https://replicate.com/" + modelID, - Name: modelID, - } - - if model.Description != nil { - replicateModel.Description = model.Description - } - - if model.OwnedBy != nil { - replicateModel.Owner = *model.OwnedBy - } - - replicateResponse.Results = append(replicateResponse.Results, replicateModel) - } - - // Set next page token if available - if response.NextPageToken != "" { - next := response.NextPageToken - replicateResponse.Next = &next - } - - return replicateResponse -} diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index cea010c6b9..aedf38471c 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -87,6 +87,12 @@ const ( pollingInterval = 2 * time.Second ) +// useDeploymentsEndpoint returns whether the key uses the deployments endpoint. +// Nil ReplicateKeyConfig is treated as false (default models/predictions behavior). 
+func useDeploymentsEndpoint(key schemas.Key) bool { + return key.ReplicateKeyConfig != nil && key.ReplicateKeyConfig.UseDeploymentsEndpoint +} + // createPrediction creates a new prediction on Replicate API // Supports both sync (with Prefer: wait header) and async modes // stripPrefer should be true for streaming requests to exclude the Prefer header @@ -149,7 +155,7 @@ func createPrediction( // Parse response body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, schemas.Replicate) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var prediction ReplicatePredictionResponse @@ -204,7 +210,7 @@ func getPrediction( // Parse response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, schemas.Replicate) + return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } prediction := &ReplicatePredictionResponse{} @@ -252,9 +258,7 @@ func pollPrediction( case <-pollCtx.Done(): return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError( schemas.ErrProviderRequestTimedOut, - fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds), - schemas.Replicate, - ) + fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds)) case <-ticker.C: prediction, rawResponse, providerResponseHeaders, err = getPrediction(pollCtx, client, predictionURL, key, logger, sendBackRawResponse) if err != nil { @@ -277,6 +281,17 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont client := provider.client extraHeaders := provider.networkConfig.ExtraHeaders + if 
!useDeploymentsEndpoint(key) { + return ToBifrostListModelsResponse( + &ReplicateDeploymentListResponse{}, + providerName, + key.Models, + key.BlacklistedModels, + key.Aliases, + request.Unfiltered, + ), nil + } + // Build deployments URL deploymentsURL := provider.buildRequestURL(ctx, "/v1/deployments", schemas.ListModelsRequest) @@ -335,9 +350,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont if err := sonic.Unmarshal(bodyCopy, &pageResponse); err != nil { return nil, providerUtils.NewBifrostOperationError( "failed to parse deployments response", - err, - schemas.Replicate, - ) + err) } // Append results from this page @@ -362,6 +375,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont providerName, key.Models, key.BlacklistedModels, + key.Aliases, request.Unfiltered, ) @@ -375,11 +389,10 @@ func (provider *ReplicateProvider) ListModels(ctx *schemas.BifrostContext, keys } if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("base_url is not set") } startTime := time.Now() - providerName := provider.GetProviderKey() response, err := providerUtils.HandleMultipleListModelsRequests( ctx, @@ -393,8 +406,6 @@ func (provider *ReplicateProvider) ListModels(ctx *schemas.BifrostContext, keys // Update metadata with total latency latency := time.Since(startTime) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() return response, nil @@ -406,17 +417,11 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := 
providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -431,7 +436,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k request.Model, provider.customProviderConfig, schemas.TextCompletionRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // create prediction @@ -480,10 +485,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k bifrostResponse := prediction.ToBifrostTextCompletionResponse() // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -503,11 +505,6 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -519,8 +516,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -532,7 +528,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx 
*schemas.BifrostCont request.Model, provider.customProviderConfig, schemas.TextCompletionStreamRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // Create prediction @@ -556,9 +552,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -590,9 +584,9 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -637,7 +631,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, 
provider.sendBackRawResponse) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break @@ -668,11 +662,8 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -706,14 +697,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -728,14 +712,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream ended with error"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - 
RequestType: schemas.TextCompletionStreamRequest, - } + fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -751,10 +728,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont nil, // usage - not available in done event finishReason, chunkIndex, - schemas.TextCompletionStreamRequest, - provider.GetProviderKey(), - request.Model, - ) + schemas.TextCompletionStreamRequest) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -782,17 +756,11 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -807,7 +775,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k request.Model, provider.customProviderConfig, schemas.ChatCompletionRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // create prediction @@ -856,10 +824,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k bifrostResponse := prediction.ToBifrostChatResponse() // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = 
schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -879,11 +844,6 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -895,8 +855,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -908,7 +867,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont request.Model, provider.customProviderConfig, schemas.ChatCompletionStreamRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // Create prediction @@ -932,9 +891,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -966,9 +923,9 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont defer 
providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1013,7 +970,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break @@ -1051,11 +1008,8 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1089,14 +1043,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx 
*schemas.BifrostCont case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -1111,14 +1058,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream ended with error"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - } + fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -1144,11 +1084,8 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -1176,17 +1113,11 @@ func (provider 
*ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -1201,7 +1132,7 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc request.Model, provider.customProviderConfig, schemas.ResponsesRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // create prediction @@ -1248,9 +1179,6 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc // Convert to Bifrost response response := prediction.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1268,24 +1196,18 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return 
ToReplicateResponsesRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } // Enable streaming (using sjson to set field directly, preserving key order) if updatedData, err := providerUtils.SetJSONField(jsonData, "stream", true); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err) } else { jsonData = updatedData } @@ -1297,7 +1219,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ResponsesStreamRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // Create prediction @@ -1321,9 +1243,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1362,9 +1282,9 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(streamErr, fasthttp.ErrTimeout) || errors.Is(streamErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr), jsonData, nil, 
provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1397,9 +1317,9 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1411,10 +1331,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, if reader == nil { bifrostErr := providerUtils.NewBifrostOperationError( - "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - provider.GetProviderKey(), - ) + "provider returned an empty response", + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, 
provider.sendBackRawResponse), responseChan, provider.logger) return @@ -1461,7 +1379,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()) + bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { @@ -1503,11 +1421,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CreatedAt: int(startTime.Unix()), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - Latency: time.Since(startTime).Milliseconds(), - ChunkIndex: sequenceNumber, + Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: sequenceNumber, }, } if sendBackRawRequest { @@ -1530,10 +1445,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CreatedAt: int(startTime.Unix()), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1562,10 +1474,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1593,10 +1502,7 @@ func (provider *ReplicateProvider) 
ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1616,10 +1522,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, Delta: schemas.Ptr(currentEvent.Data), LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1645,10 +1548,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, ItemID: schemas.Ptr(itemID), LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1671,10 +1571,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1708,10 +1605,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: 
request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1731,11 +1625,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CompletedAt: schemas.Ptr(int(time.Now().Unix())), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - Latency: time.Since(startTime).Milliseconds(), - ChunkIndex: sequenceNumber, + Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: sequenceNumber, }, } @@ -1768,14 +1659,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream error: %s", errorMsg), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesStreamRequest, - } + fmt.Errorf("stream error: %s", errorMsg)) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { @@ -1836,19 +1720,13 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateImageGenerationInput(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1863,7 +1741,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ImageGenerationRequest, - isDeployment, + 
useDeploymentsEndpoint(key), ) // Create prediction with appropriate mode @@ -1915,10 +1793,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, } // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -1937,15 +1812,9 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon return nil, err } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -1954,8 +1823,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon replicateReq := ToReplicateImageGenerationInput(request) replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1967,7 +1835,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon request.Model, provider.customProviderConfig, schemas.ImageGenerationStreamRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // Create prediction prediction, _, _, _, err := createPrediction( @@ -1988,10 +1856,16 @@ func (provider *ReplicateProvider) 
ImageGenerationStream(ctx *schemas.BifrostCon // Verify stream URL is available if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { - return nil, providerUtils.NewBifrostOperationError( - "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - providerName, + return nil, providerUtils.EnrichError( + ctx, + providerUtils.NewBifrostOperationError( + "stream URL not available in prediction response", + fmt.Errorf("prediction response missing stream URL"), + ), + jsonData, + nil, + sendBackRawRequest, + sendBackRawResponse, ) } @@ -2023,9 +1897,9 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -2072,7 +1946,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading SSE stream: %v", readErr)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, provider.logger) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + 
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break } @@ -2117,11 +1992,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon CreatedAt: time.Now().Unix(), OutputFormat: outputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -2155,36 +2027,24 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return case "error": bifrostErr := providerUtils.NewBifrostOperationError( "prediction failed", - fmt.Errorf("stream ended with error"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + fmt.Errorf("stream ended with error")) // Include accumulated raw responses in error if sendBackRawResponse && 
len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2199,11 +2059,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon OutputFormat: lastOutputFormat, // Include output format CreatedAt: time.Now().Unix(), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -2245,17 +2102,13 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon Error: &schemas.ErrorField{ Message: errorMsg, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } // Include accumulated raw responses in error if sendBackRawResponse { rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData}) bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2272,19 +2125,13 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost 
request to Replicate format jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateImageEditInput(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2299,7 +2146,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc request.Model, provider.customProviderConfig, schemas.ImageEditRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // Create prediction with appropriate mode @@ -2351,10 +2198,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc } // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -2373,15 +2217,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, return nil, err } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -2390,8 +2228,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, replicateReq := ToReplicateImageEditInput(request) replicateReq.Stream = 
schemas.Ptr(true) return replicateReq, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2403,7 +2240,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ImageEditStreamRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // Create prediction @@ -2425,10 +2262,16 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, // Verify stream URL is available if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { - return nil, providerUtils.NewBifrostOperationError( - "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - providerName, + return nil, providerUtils.EnrichError( + ctx, + providerUtils.NewBifrostOperationError( + "stream URL not available in prediction response", + fmt.Errorf("prediction response missing stream URL"), + ), + jsonData, + nil, + sendBackRawRequest, + sendBackRawResponse, ) } @@ -2460,9 +2303,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -2507,18 +2350,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, if errors.Is(readErr, context.Canceled) { return } - bifrostErr := 
providerUtils.NewBifrostOperationError( - "stream read error", - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("stream read error", readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break } @@ -2561,11 +2395,8 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, CreatedAt: time.Now().Unix(), OutputFormat: outputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -2599,34 +2430,22 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) 
 					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 					return
 				case "error":
 					bifrostErr := providerUtils.NewBifrostOperationError(
 						"prediction failed",
-						fmt.Errorf("stream ended with error"),
-						providerName,
-					)
-					bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-						Provider:       providerName,
-						ModelRequested: request.Model,
-						RequestType:    schemas.ImageEditStreamRequest,
-					}
+						fmt.Errorf("stream ended with error"))
 					if sendBackRawResponse && len(rawResponseChunks) > 0 {
 						bifrostErr.ExtraFields.RawResponse = rawResponseChunks
 					}
+					bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 					return
@@ -2641,11 +2460,8 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 					CreatedAt:    time.Now().Unix(),
 					OutputFormat: lastOutputFormat,
 					ExtraFields: schemas.BifrostResponseExtraFields{
-						RequestType:    schemas.ImageEditStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-						ChunkIndex:     chunkIndex,
-						Latency:        time.Since(startTime).Milliseconds(),
+						ChunkIndex: chunkIndex,
+						Latency:    time.Since(startTime).Milliseconds(),
 					},
 				}
@@ -2673,18 +2489,12 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 				bifrostErr := providerUtils.NewBifrostOperationError(
 					"stream error",
-					fmt.Errorf("%s", errorData.Detail),
-					providerName,
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.ImageEditStreamRequest,
-				}
+					fmt.Errorf("%s", errorData.Detail))
 				if sendBackRawResponse {
 					rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
 					bifrostErr.ExtraFields.RawResponse = rawResponseChunks
 				}
+				bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 				return
@@ -2706,21 +2516,13 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext,
 		return nil, err
 	}

-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
-	providerName := provider.GetProviderKey()
-
 	// Convert Bifrost request to Replicate format
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToReplicateVideoGenerationInput(request)
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2732,7 +2534,7 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext,
 		request.Model,
 		provider.customProviderConfig,
 		schemas.VideoGenerationRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)

 	// Create prediction with appropriate mode
@@ -2761,13 +2563,10 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext,
 	if err != nil {
 		return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
-	bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName)
+	bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, schemas.Replicate)

 	// Set extra fields
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.RequestType = schemas.VideoGenerationRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders

 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
@@ -2787,7 +2586,7 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke
 	providerName := provider.GetProviderKey()

 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}

 	videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
@@ -2829,7 +2628,7 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
@@ -2841,12 +2640,10 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke
 	bifrostResponse, convertErr := ToBifrostVideoGenerationResponse(&prediction)
 	if convertErr != nil {
-		return nil, providerUtils.EnrichError(ctx, convertErr, nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, convertErr, nil, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}

 	bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName)
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.RequestType = schemas.VideoRetrieveRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 	if sendBackRawResponse {
@@ -2861,9 +2658,8 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil {
 		return nil, err
 	}
-	providerName := provider.GetProviderKey()
 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 	// Retrieve latest status/output first.
 	bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
@@ -2877,19 +2673,17 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	if videoResp.Status != schemas.VideoStatusCompleted {
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("video not ready, current status: %s", videoResp.Status),
-			nil,
-			providerName,
-		)
+			nil)
 	}
 	if len(videoResp.Videos) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
 	}
 	var videoUrl string
 	if videoResp.Videos[0].URL != nil {
 		videoUrl = *videoResp.Videos[0].URL
 	}
 	if videoUrl == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
 	}
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -2909,9 +2703,7 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	if resp.StatusCode() != fasthttp.StatusOK {
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
-			nil,
-			providerName,
-		)
+			nil)
 	}

 	providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
@@ -2919,7 +2711,7 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 	contentType := string(resp.Header.ContentType())
 	if contentType == "" {
@@ -2933,8 +2725,6 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	}

 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest
 	bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders

 	return bifrostResp, nil
@@ -2990,7 +2780,7 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s
 	providerName := provider.GetProviderKey()

 	if len(request.File) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file content is required", nil)
 	}

 	// Create multipart form data
@@ -3027,22 +2817,22 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s
 	part, err := writer.CreatePart(h)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to create form file", err)
 	}
 	if _, err := part.Write(request.File); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to write file content", err)
 	}

 	// Add filename field if provided
 	if filename != "" {
 		if err := writer.WriteField("filename", filename); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err)
 		}
 	}

 	// Add type field (content type)
 	if err := writer.WriteField("type", contentType); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to write type field", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to write type field", err)
 	}

 	// Add metadata field if provided
@@ -3051,24 +2841,24 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s
 			if len(metadata) > 0 {
 				metadataJSON, err := providerUtils.MarshalSorted(metadata)
 				if err != nil {
-					return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err, providerName)
+					return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err)
 				}
 				h := make(textproto.MIMEHeader)
 				h.Set("Content-Disposition", `form-data; name="metadata"`)
 				h.Set("Content-Type", "application/json")
 				metadataPart, err := writer.CreatePart(h)
 				if err != nil {
-					return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err, providerName)
+					return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err)
 				}
 				if _, err := metadataPart.Write(metadataJSON); err != nil {
-					return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err, providerName)
+					return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err)
 				}
 			}
 		}
 	}

 	if err := writer.Close(); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
 	}

 	// Create request
@@ -3104,7 +2894,7 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	var replicateResp ReplicateFileResponse
@@ -3132,7 +2922,7 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys []
 	// Initialize serial pagination helper (Replicate uses cursor-based pagination)
 	helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}

 	// Get current key to query
@@ -3143,10 +2933,6 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys []
 			Object:  "list",
 			Data:    []schemas.FileObject{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.FileListRequest,
-				Provider:    providerName,
-			},
 		}, nil
 	}

@@ -3195,7 +2981,7 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys []
 	body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
 	if decodeErr != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
 	}

 	var replicateResp ReplicateFileListResponse
@@ -3239,8 +3025,6 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys []
 		Data:    files,
 		HasMore: finalHasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:             schemas.FileListRequest,
-			Provider:                providerName,
 			Latency:                 latency.Milliseconds(),
 			ProviderResponseHeaders: providerResponseHeaders,
 		},
@@ -3257,7 +3041,7 @@ func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, key
 	providerName := provider.GetProviderKey()

 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
 	}

 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
@@ -3302,7 +3086,7 @@ func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, key
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}

@@ -3334,7 +3118,7 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys
 	providerName := provider.GetProviderKey()

 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
 	}

 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
@@ -3377,8 +3161,6 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys
 				Object:  "file",
 				Deleted: true,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:             schemas.FileDeleteRequest,
-					Provider:                providerName,
 					Latency:                 latency.Milliseconds(),
 					ProviderResponseHeaders: providerResponseHeaders,
 				},
@@ -3399,7 +3181,7 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}

@@ -3424,8 +3206,6 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys
 		Object:  "file",
 		Deleted: true,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:             schemas.FileDeleteRequest,
-			Provider:                providerName,
 			Latency:                 latency.Milliseconds(),
 			ProviderResponseHeaders: providerResponseHeaders,
 		},
diff --git a/core/providers/replicate/replicate_test.go b/core/providers/replicate/replicate_test.go
index c6f72cfded..a9179f963c 100644
--- a/core/providers/replicate/replicate_test.go
+++ b/core/providers/replicate/replicate_test.go
@@ -459,187 +459,6 @@ func TestBifrostToReplicateImageGenerationConversion(t *testing.T) {
 		validate func(t *testing.T, result *replicate.ReplicatePredictionRequest)
 		wantErr  bool
 	}{
-		{
-			name: "Flux_1_1_Pro_ImagePrompt",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "black-forest-labs/flux-1.1-pro",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				// Flux 1.1 Pro should use ImagePrompt field
-				assert.NotNil(t, result.Input.ImagePrompt)
-				assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt)
-				assert.Nil(t, result.Input.InputImage)
-				assert.Nil(t, result.Input.Image)
-				assert.Nil(t, result.Input.InputImages)
-			},
-		},
-		{
-			name: "Flux_1_1_Pro_Ultra_ImagePrompt",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "black-forest-labs/flux-1.1-pro-ultra",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				assert.NotNil(t, result.Input.ImagePrompt)
-				assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt)
-			},
-		},
-		{
-			name: "Flux_Pro_ImagePrompt",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "black-forest-labs/flux-pro",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				assert.NotNil(t, result.Input.ImagePrompt)
-				assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt)
-			},
-		},
-		{
-			name: "Flux_Kontext_Pro_InputImage",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "black-forest-labs/flux-kontext-pro",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				// Kontext models should use InputImage field
-				assert.NotNil(t, result.Input.InputImage)
-				assert.Equal(t, "https://example.com/input.jpg", *result.Input.InputImage)
-				assert.Nil(t, result.Input.ImagePrompt)
-				assert.Nil(t, result.Input.Image)
-				assert.Nil(t, result.Input.InputImages)
-			},
-		},
-		{
-			name: "Flux_Kontext_Max_InputImage",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "black-forest-labs/flux-kontext-max",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				assert.NotNil(t, result.Input.InputImage)
-				assert.Equal(t, "https://example.com/input.jpg", *result.Input.InputImage)
-			},
-		},
-		{
-			name: "Flux_Dev_Image",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "black-forest-labs/flux-dev",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				// Flux Dev should use Image field
-				assert.NotNil(t, result.Input.Image)
-				assert.Equal(t, "https://example.com/input.jpg", *result.Input.Image)
-				assert.Nil(t, result.Input.ImagePrompt)
-				assert.Nil(t, result.Input.InputImage)
-				assert.Nil(t, result.Input.InputImages)
-			},
-		},
-		{
-			name: "Flux_Fill_Pro_Image",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "black-forest-labs/flux-fill-pro",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				assert.NotNil(t, result.Input.Image)
-				assert.Equal(t, "https://example.com/input.jpg", *result.Input.Image)
-			},
-		},
-		{
-			name: "Other_Model_InputImages",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "stability-ai/sdxl",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input1.jpg", "https://example.com/input2.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				// Other models should use InputImages array
-				assert.NotNil(t, result.Input.InputImages)
-				assert.Len(t, result.Input.InputImages, 2)
-				assert.Equal(t, "https://example.com/input1.jpg", result.Input.InputImages[0])
-				assert.Equal(t, "https://example.com/input2.jpg", result.Input.InputImages[1])
-				assert.Nil(t, result.Input.ImagePrompt)
-				assert.Nil(t, result.Input.InputImage)
-				assert.Nil(t, result.Input.Image)
-			},
-		},
-		{
-			name: "Model_With_Version",
-			input: &schemas.BifrostImageGenerationRequest{
-				Model: "black-forest-labs/flux-1.1-pro:v1.0",
-				Input: &schemas.ImageGenerationInput{
-					Prompt: prompt,
-				},
-				Params: &schemas.ImageGenerationParameters{
-					InputImages: []string{"https://example.com/input.jpg"},
-				},
-			},
-			validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) {
-				require.NotNil(t, result)
-				require.NotNil(t, result.Input)
-				// Should still match flux-1.1-pro and use ImagePrompt
-				assert.NotNil(t, result.Input.ImagePrompt)
-				assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt)
-			},
-		},
 		{
 			name: "AllParameters",
 			input: &schemas.BifrostImageGenerationRequest{
diff --git a/core/providers/replicate/types.go b/core/providers/replicate/types.go
index 98f84e613e..3ae88c0095 100644
--- a/core/providers/replicate/types.go
+++ b/core/providers/replicate/types.go
@@ -313,28 +313,28 @@ type ReplicatePredictionListResponse struct {

 // ReplicateModelResponse represents a model response
 type ReplicateModelResponse struct {
-	URL             string                 `json:"url"`                        // Model API URL
-	Owner           string                 `json:"owner"`                      // Owner username or org name
-	Name            string                 `json:"name"`                       // Model name
-	Description     *string                `json:"description,omitempty"`      // Model description
-	Visibility      string                 `json:"visibility"`                 // "public" or "private"
-	GithubURL       *string                `json:"github_url,omitempty"`       // GitHub repository URL
-	PaperURL        *string                `json:"paper_url,omitempty"`        // Research paper URL
-	LicenseURL      *string                `json:"license_url,omitempty"`      // License URL
-	RunCount        *int                   `json:"run_count,omitempty"`        // Number of times run
-	CoverImageURL   *string                `json:"cover_image_url,omitempty"`  // Cover image URL
-	DefaultExample  *json.RawMessage       `json:"default_example,omitempty"`  // Default example prediction (json.RawMessage preserves key ordering)
-	LatestVersion   *ReplicateModelVersion `json:"latest_version,omitempty"`   // Latest version details
-	FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details
+	URL             string                 `json:"url"`                        // Model API URL
+	Owner           string                 `json:"owner"`                      // Owner username or org name
+	Name            string                 `json:"name"`                       // Model name
+	Description     *string                `json:"description,omitempty"`      // Model description
+	Visibility      string                 `json:"visibility"`                 // "public" or "private"
+	GithubURL       *string                `json:"github_url,omitempty"`       // GitHub repository URL
+	PaperURL        *string                `json:"paper_url,omitempty"`        // Research paper URL
+	LicenseURL      *string                `json:"license_url,omitempty"`      // License URL
+	RunCount        *int                   `json:"run_count,omitempty"`        // Number of times run
+	CoverImageURL   *string                `json:"cover_image_url,omitempty"`  // Cover image URL
+	DefaultExample  *json.RawMessage       `json:"default_example,omitempty"`  // Default example prediction (json.RawMessage preserves key ordering)
+	LatestVersion   *ReplicateModelVersion `json:"latest_version,omitempty"`   // Latest version details
+	FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details
 }

 // ReplicateModelVersion represents a model version
 type ReplicateModelVersion struct {
-	ID            string          `json:"id"`                        // Version ID
-	CreatedAt     string          `json:"created_at"`                // ISO 8601 timestamp
-	CogVersion    *string         `json:"cog_version,omitempty"`     // Cog version used
-	OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"`  // OpenAPI schema for the model (json.RawMessage preserves key ordering)
-	DockerImageID *string         `json:"docker_image_id,omitempty"` // Docker image ID
+	ID            string          `json:"id"`                        // Version ID
+	CreatedAt     string          `json:"created_at"`                // ISO 8601 timestamp
+	CogVersion    *string         `json:"cog_version,omitempty"`     // Cog version used
+	OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"`  // OpenAPI schema for the model (json.RawMessage preserves key ordering)
+	DockerImageID *string         `json:"docker_image_id,omitempty"` // Docker image ID
 }

 // ReplicateModelListResponse represents a paginated list of models
diff --git a/core/providers/replicate/utils.go b/core/providers/replicate/utils.go
index 3279b0a847..1d88337539 100644
--- a/core/providers/replicate/utils.go
+++ b/core/providers/replicate/utils.go
@@ -31,17 +31,13 @@ func checkForErrorStatus(prediction *ReplicatePredictionResponse) *schemas.Bifro
 		}
 		return providerUtils.NewBifrostOperationError(
 			"prediction failed",
-			fmt.Errorf("%s", errorMsg),
-			schemas.Replicate,
-		)
+			fmt.Errorf("%s", errorMsg))
 	}

 	if prediction.Status == ReplicatePredictionStatusCanceled {
 		return providerUtils.NewBifrostOperationError(
 			"prediction was canceled",
-			fmt.Errorf("prediction was canceled"),
-			schemas.Replicate,
-		)
+			fmt.Errorf("prediction was canceled"))
 	}

 	return nil
@@ -126,9 +122,9 @@ func listenToReplicateStreamURL(
 		}
 	}
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, schemas.Replicate)
+		return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
-	return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, schemas.Replicate)
+	return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
 }

 // Extract provider response headers before status check so error responses also forward them
@@ -178,24 +174,12 @@ func isVersionID(s string) bool {
 	return versionIDPattern.MatchString(s)
 }

-// resolveDeploymentModel checks if the model maps to a deployment.
-// Returns the resolved model and whether it is a deployment.
-func resolveDeploymentModel(model string, key schemas.Key) (string, bool) {
-	if key.ReplicateKeyConfig == nil || key.ReplicateKeyConfig.Deployments == nil {
-		return model, false
-	}
-	if deployment, ok := key.ReplicateKeyConfig.Deployments[model]; ok && deployment != "" {
-		return deployment, true
-	}
-	return model, false
-}
-
 // buildPredictionURL builds the appropriate URL for creating a prediction
 // Returns the URL for the appropriate prediction endpoint.
-func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, isDeployment bool) string {
+func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, useDeploymentsEndpoint bool) string {
 	var defaultPath string
-	if isDeployment {
+	if useDeploymentsEndpoint {
 		defaultPath = "/v1/deployments/" + model + "/predictions"
 	} else if isVersionID(model) {
 		// If model is a version ID, use base predictions endpoint
diff --git a/core/providers/replicate/videos.go b/core/providers/replicate/videos.go
index b6dadaab55..3a277d067d 100644
--- a/core/providers/replicate/videos.go
+++ b/core/providers/replicate/videos.go
@@ -87,9 +87,6 @@ func ToBifrostVideoGenerationResponse(prediction *ReplicatePredictionResponse) (
 			Error: &schemas.ErrorField{
 				Message: "prediction response is nil",
 			},
-			ExtraFields: schemas.BifrostErrorExtraFields{
-				Provider: schemas.Replicate,
-			},
 		}
 	}

diff --git a/core/providers/runway/errors.go b/core/providers/runway/errors.go
index a64f8ffc60..d9259e825f 100644
--- a/core/providers/runway/errors.go
+++ b/core/providers/runway/errors.go
@@ -9,7 +9,7 @@ import (
 )

 // parseRunwayError parses Runway API error responses and converts them to BifrostError.
-func parseRunwayError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
+func parseRunwayError(resp *fasthttp.Response) *schemas.BifrostError {
 	// Parse as RunwayAPIError
 	var errorResp RunwayAPIError
 	bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
@@ -34,12 +34,5 @@ func parseRunwayError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 		bifrostErr.Error.Message = strings.TrimRight(bifrostErr.Error.Message, "\n")
 	}

-	// Set metadata
-	if meta != nil {
-		bifrostErr.ExtraFields.Provider = meta.Provider
-		bifrostErr.ExtraFields.ModelRequested = meta.Model
-		bifrostErr.ExtraFields.RequestType = meta.RequestType
-	}
-
 	return bifrostErr
 }
diff --git a/core/providers/runway/runway.go b/core/providers/runway/runway.go
index 7bcce918c9..6bb5e31e32 100644
--- a/core/providers/runway/runway.go
+++ b/core/providers/runway/runway.go
@@ -170,8 +170,7 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		bifrostReq,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToRunwayVideoGenerationRequest(bifrostReq)
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -210,17 +209,14 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       model,
-			RequestType: schemas.VideoGenerationRequest,
-		}), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Decode response body
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		rawErrBody := append([]byte(nil), resp.Body()...)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, rawErrBody, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Parse response
@@ -237,10 +233,7 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		Object: "video",
 		Status: schemas.VideoStatusQueued,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Latency:        latency.Milliseconds(),
-			Provider:       providerName,
-			ModelRequested: model,
-			RequestType:    schemas.VideoGenerationRequest,
+			Latency: latency.Milliseconds(),
 		},
 	}

@@ -287,16 +280,14 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			RequestType: schemas.VideoRetrieveRequest,
-		}), nil, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Decode response body
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		rawErrBody := append([]byte(nil), resp.Body()...)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Parse response
@@ -314,8 +305,6 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName)
 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest

 	if sendBackRawRequest {
 		bifrostResp.ExtraFields.RawRequest = rawRequest
@@ -329,7 +318,6 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s

 // VideoDownload retrieves a video from Runway's API.
 func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
 	// Retrieve task status to get the video URL
 	bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
 		Provider: request.Provider,
@@ -343,20 +331,21 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	if taskDetails.Status != schemas.VideoStatusCompleted {
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("video not ready, current status: %s", taskDetails.Status),
-			nil,
-			providerName,
-		)
+			nil)
 	}
 	if len(taskDetails.Videos) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
 	}
 	var videoUrl string
 	if taskDetails.Videos[0].URL != nil {
 		videoUrl = *taskDetails.Videos[0].URL
 	}
 	if videoUrl == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
 	}

+	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
+	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
+
 	// Download video from Runway's URL
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -372,14 +361,13 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	if resp.StatusCode() != fasthttp.StatusOK {
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
-			nil,
-			providerName,
-		)
+			nil)
 	}

 	// Get content and content type
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		rawErrBody := append([]byte(nil), resp.Body()...)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse)
 	}
 	contentType := string(resp.Header.ContentType())
 	if contentType == "" {
@@ -394,8 +382,6 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	}

 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest

 	return bifrostResp, nil
 }
@@ -407,7 +393,7 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch
 	providerName := provider.GetProviderKey()

 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("task_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("task_id is required", nil)
 	}

 	taskID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
@@ -439,10 +425,7 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch
 	// Handle error response - Runway returns 204 No Content on success
 	if resp.StatusCode() != fasthttp.StatusNoContent {
-		return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			RequestType: schemas.VideoDeleteRequest,
-		}), nil, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Build response - Runway returns empty body on 204
@@ -453,8 +436,6 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch
 	}

 	response.ExtraFields.Latency = latency.Milliseconds()
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.RequestType = schemas.VideoDeleteRequest

 	return response, nil
 }
diff --git a/core/providers/runway/videos.go b/core/providers/runway/videos.go
index 49b0cec237..809a8a1038 100644
--- a/core/providers/runway/videos.go
+++ b/core/providers/runway/videos.go
@@ -121,7 +121,7 @@ func ToRunwayVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe

 // ToBifrostVideoGenerationResponse converts Runway task details to Bifrost video generation response format.
func ToBifrostVideoGenerationResponse(taskDetails *RunwayTaskDetailsResponse) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if taskDetails == nil { - return nil, providerUtils.NewBifrostOperationError("task details is nil", nil, schemas.Runway) + return nil, providerUtils.NewBifrostOperationError("task details is nil", nil) } response := &schemas.BifrostVideoGenerationResponse{ diff --git a/core/providers/sgl/sgl.go b/core/providers/sgl/sgl.go index ce5d3d936a..5b07356851 100644 --- a/core/providers/sgl/sgl.go +++ b/core/providers/sgl/sgl.go @@ -3,7 +3,6 @@ package sgl import ( - "fmt" "strings" "time" @@ -50,11 +49,7 @@ func NewSGLProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*SGL client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") - // BaseURL is required for SGLang - if config.NetworkConfig.BaseURL == "" { - return nil, fmt.Errorf("base_url is required for sgl provider") - } - + // BaseURL is optional when keys have sgl_key_config with per-key URLs return &SGLProvider{ logger: logger, client: client, @@ -69,27 +64,40 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { return schemas.SGL } -// ListModels performs a list models request to SGL's API. -func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - return openai.HandleOpenAIListModelsRequest( +// listModelsByKey performs a list models request for a single SGL key, +// resolving the per-key URL so each backend is queried individually. 
+func (provider *SGLProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return openai.ListModelsByKey( ctx, provider.client, - request, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"), - keys, + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/models"), + key, + request.Unfiltered, provider.networkConfig.ExtraHeaders, - schemas.SGL, + provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), ) } -// TextCompletion is not supported by the SGL provider. +// ListModels performs a list models request to SGL's API. +// Requests are made concurrently per key so that each backend is queried +// with its own URL (from sgl_key_config). +func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return providerUtils.HandleMultipleListModelsRequests( + ctx, + keys, + request, + provider.listModelsByKey, + ) +} + +// TextCompletion performs a text completion request to the SGL API. 
func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -109,7 +117,7 @@ func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, p return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/completions", + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -129,7 +137,7 @@ func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key sch return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -151,7 +159,7 @@ func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, p return openai.HandleOpenAIChatCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/chat/completions", + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -176,9 +184,6 @@ func (provider *SGLProvider) Responses(ctx *schemas.BifrostContext, key schemas. 
} response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -194,12 +199,12 @@ func (provider *SGLProvider) ResponsesStream(ctx *schemas.BifrostContext, postHo ) } -// Embedding is not supported by the SGL provider. +// Embedding performs an embedding request to the SGL API. func (provider *SGLProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), request, key, provider.networkConfig.ExtraHeaders, diff --git a/core/providers/sgl/sgl_test.go b/core/providers/sgl/sgl_test.go index 11447f58b4..20236182fc 100644 --- a/core/providers/sgl/sgl_test.go +++ b/core/providers/sgl/sgl_test.go @@ -29,22 +29,22 @@ func TestSGL(t *testing.T) { TextModel: "qwen/qwen2.5-0.5b-instruct", EmbeddingModel: "Alibaba-NLP/gte-Qwen2-1.5B-instruct", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - Embedding: true, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + 
CompleteEnd2End: true, + Embedding: true, + ListModels: true, }, } diff --git a/core/providers/utils/images.go b/core/providers/utils/images.go new file mode 100644 index 0000000000..12bc01c8ba --- /dev/null +++ b/core/providers/utils/images.go @@ -0,0 +1,50 @@ +package utils + +import ( + "strconv" + "strings" +) + +// ConvertSizeToAspectRatioAndResolution converts a standard size string (e.g., "1024x1024") +// to an aspect ratio and image size tier. +// aspectRatio is one of "1:1", "3:4", "4:3", "9:16", "16:9" (empty if unrecognised). +// imageSize is one of "1K", "2K", "4K" (empty if out of range). +func ConvertSizeToAspectRatioAndResolution(size string) (aspectRatio, imageSize string) { + parts := strings.Split(size, "x") + if len(parts) != 2 { + return "", "" + } + + width, err1 := strconv.Atoi(parts[0]) + height, err2 := strconv.Atoi(parts[1]) + if err1 != nil || err2 != nil { + return "", "" + } + + if width <= 0 || height <= 0 { + return "", "" + } + + if width <= 1024 && height <= 1024 { + imageSize = "1K" + } else if width <= 2048 && height <= 2048 { + imageSize = "2K" + } else if width <= 4096 && height <= 4096 { + imageSize = "4K" + } + + ratio := float64(width) / float64(height) + if ratio >= 0.99 && ratio <= 1.01 { + aspectRatio = "1:1" + } else if ratio >= 0.74 && ratio <= 0.76 { + aspectRatio = "3:4" + } else if ratio >= 1.32 && ratio <= 1.34 { + aspectRatio = "4:3" + } else if ratio >= 0.56 && ratio <= 0.57 { + aspectRatio = "9:16" + } else if ratio >= 1.77 && ratio <= 1.78 { + aspectRatio = "16:9" + } + + return aspectRatio, imageSize +} diff --git a/core/providers/utils/large_response.go b/core/providers/utils/large_response.go index a7e0e7bf36..e62d375c9a 100644 --- a/core/providers/utils/large_response.go +++ b/core/providers/utils/large_response.go @@ -116,7 +116,6 @@ func MaterializeStreamErrorBody(ctx *schemas.BifrostContext, resp *fasthttp.Resp func FinalizeResponseWithLargeDetection( ctx *schemas.BifrostContext, resp *fasthttp.Response, - 
providerName schemas.ModelProvider, logger schemas.Logger, ) ([]byte, bool, *schemas.BifrostError) { responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64) @@ -125,7 +124,7 @@ func FinalizeResponseWithLargeDetection( if responseThreshold <= 0 { body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Copy body before caller releases resp return append([]byte(nil), body...), false, nil @@ -142,14 +141,14 @@ func FinalizeResponseWithLargeDetection( } bodyBytes, readErr := io.ReadAll(reader) if readErr != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } return bodyBytes, false, nil } // No stream — buffered fallback body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -169,7 +168,7 @@ func FinalizeResponseWithLargeDetection( bodyBytes, readErr := io.ReadAll(io.LimitReader(reader, responseThreshold+1)) if readErr != nil { releaseGzip() - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } if int64(len(bodyBytes)) <= responseThreshold { releaseGzip() @@ -195,7 +194,7 @@ func FinalizeResponseWithLargeDetection( // No stream — buffered fallback body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, 
NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -206,11 +205,11 @@ func FinalizeResponseWithLargeDetection( if bodyStream == nil { // No stream available — fall back to buffered read if logger != nil { - logger.Warn("large-response fallback to buffered path: provider=%s content_length=%d threshold=%d body_stream_nil=true", providerName, contentLength, responseThreshold) + logger.Warn("large-response fallback to buffered path: content_length=%d threshold=%d body_stream_nil=true", contentLength, responseThreshold) } body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -232,7 +231,7 @@ func FinalizeResponseWithLargeDetection( if wasGzip { ReleaseGzipReader(gz) } - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } prefetchBuf = prefetchBuf[:n] diff --git a/core/providers/utils/make_request_test.go b/core/providers/utils/make_request_test.go index ce1610d7bb..ec2bf771bc 100644 --- a/core/providers/utils/make_request_test.go +++ b/core/providers/utils/make_request_test.go @@ -295,7 +295,7 @@ func TestMakeRequestWithContext_ConcurrentRequestsWithCancellation(t *testing.T) } func TestNewBifrostTimeoutError(t *testing.T) { - err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded, "openai") + err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded) if !err.IsBifrostError { t.Fatal("expected IsBifrostError to be true") diff --git a/core/providers/utils/models.go b/core/providers/utils/models.go new file mode 100644 index 0000000000..f1b3d0351b --- /dev/null +++ b/core/providers/utils/models.go @@ 
-0,0 +1,356 @@ +// Package utils — models.go +// Centralised pipeline for filtering and backfilling models in ListModels responses. +// +// Every provider's ToBifrostListModelsResponse follows the same logical steps: +// 1. Resolve each API model's name (alias lookup → alias key; else raw model ID) +// 2. Filter (allowlist + blacklist check on the resolved name) +// 3. Backfill entries that were not returned by the API but should appear in output +// +// Providers plug in custom MatchFns to extend the default matching behaviour. +// Example: Bedrock adds region-prefix-aware matching on top of DefaultMatchFns. +package utils + +import ( + "sort" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// ToDisplayName converts a raw model ID or alias key into a human-readable display name. +// Splits on "-" or "_", title-cases each word, and joins with spaces. +// +// "gemini-pro" → "Gemini Pro" +// "claude_3_opus" → "Claude 3 Opus" +// "gpt-4-turbo" → "Gpt 4 Turbo" +func ToDisplayName(id string) string { + caser := cases.Title(language.English) + parts := strings.FieldsFunc(id, func(r rune) bool { + return r == '-' || r == '_' + }) + if len(parts) == 0 { + return "" + } + for i, part := range parts { + if part != "" { + parts[i] = caser.String(strings.ToLower(part)) + } + } + return strings.Join(parts, " ") +} + +// MatchFn reports whether two model ID strings should be treated as equivalent. +// Functions are applied in order during every comparison — the first one that +// returns true short-circuits the rest. +// +// Example built-in fns (see DefaultMatchFns): +// +// exactMatch("gpt-4", "gpt-4") → true +// sameBaseModel("claude-3-5-sonnet-20241022", "claude-3-5-sonnet") → true +type MatchFn func(a, b string) bool + +// DefaultMatchFns returns the standard matching functions used by most providers. +// Currently only performs case-insensitive exact matching.
+// +// SameBaseModel (strips version suffixes, e.g. "claude-3-5-sonnet-20241022" ≈ "claude-3-5-sonnet") +// is intentionally excluded — users should use aliases for explicit version-to-base-name mapping. +// It can be appended here if fuzzy base-model matching is ever needed globally. +func DefaultMatchFns() []MatchFn { + return []MatchFn{ + func(a, b string) bool { return strings.EqualFold(a, b) }, + } +} + +// matches reports whether a and b are considered equal by any of the provided fns. +// Returns true on the first fn that returns true. +func matches(a, b string, fns []MatchFn) bool { + for _, fn := range fns { + if fn(a, b) { + return true + } + } + return false +} + +// FilterResult is the outcome of running ListModelsPipeline.FilterModel for a single model +// from the provider's API response. Each returned result represents one alias +// entry (or the raw model ID when no alias matched) that passed all filters. +type FilterResult struct { + // ResolvedID is the user-facing model name to use as the ID suffix. + // If the model matched an alias VALUE, this is the alias KEY. + // Otherwise this is the original model ID from the API response. + // + // Example: API returns "gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo"} + // → ResolvedID = "my-gpt4" + // Example: API returns "gpt-3.5-turbo", no alias match + // → ResolvedID = "gpt-3.5-turbo" + ResolvedID string + + // AliasValue is the provider-specific model ID when the model was matched + // via an alias. Set as the model.Alias field so callers know the underlying ID. + // Empty when the model was matched directly (no alias involved). + // + // Example: API returns "gpt-4-turbo", alias key "my-gpt4" matched + // → AliasValue = "gpt-4-turbo" + AliasValue string +} + +// ListModelsPipeline holds all the context needed to filter and backfill models in a +// single ListModels response. Construct one per ToBifrostListModelsResponse call +// and use its methods instead of passing params + matchFns to every function.
+// +// pipeline := &providerUtils.ListModelsPipeline{ +// AllowedModels: key.Models, +// BlacklistedModels: key.BlacklistedModels, +// Aliases: key.Aliases, +// Unfiltered: request.Unfiltered, +// ProviderKey: schemas.OpenAI, +// MatchFns: providerUtils.DefaultMatchFns(), +// } +// if pipeline.ShouldEarlyExit() { return empty } +// result := pipeline.FilterModel(model.ID) +// pipeline.BackfillModels(included) +type ListModelsPipeline struct { + AllowedModels schemas.WhiteList + BlacklistedModels schemas.BlackList + // Aliases maps user-facing alias keys to provider-specific model IDs. + // e.g. {"my-gpt4": "gpt-4-turbo-2024-04-09"} + Aliases map[string]string + Unfiltered bool + ProviderKey schemas.ModelProvider + // MatchFns is the ordered list of equivalence functions used for every + // model ID comparison. Use DefaultMatchFns() for standard behaviour; + // providers may append additional fns (e.g. Bedrock's region-prefix remover). + MatchFns []MatchFn +} + +// ShouldEarlyExit reports whether ToBifrostListModelsResponse should immediately +// return an empty response without processing any models. +// +// Returns true when: +// - not unfiltered AND allowlist is empty AND no aliases configured +// (there is nothing to match against — all models would be filtered out anyway) +// - not unfiltered AND blacklist blocks everything +// +// Note: allowlist empty + aliases present → do NOT early exit. +// The aliases drive backfill in the wildcard-allowlist case (Case B of BackfillModels). +func (p *ListModelsPipeline) ShouldEarlyExit() bool { + if p.Unfiltered { + return false + } + if p.BlacklistedModels.IsBlockAll() { + return true + } + if p.AllowedModels.IsEmpty() && len(p.Aliases) == 0 { + return true + } + return false +} + +// aliasMatch holds a single alias key/value pair returned by resolveModelID. 
+type aliasMatch struct { + key string + value string +} + +// resolveModelID returns all alias entries whose VALUE matches modelID using the pipeline's MatchFns, +// plus the raw model ID itself as an additional entry so both the alias key and the original model +// name appear in the list-models output. +// Results are sorted by alias key (case-insensitive) for deterministic ordering. +// +// If one or more aliases match → returns one aliasMatch per matching alias key, plus the raw ID. +// +// Example: modelID="gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"} +// → [{key:"gpt-4-turbo", value:""}, {key:"gpt4-alias", value:"gpt-4-turbo"}, {key:"my-gpt4", value:"gpt-4-turbo"}] +// +// If no alias matches → returns a single entry with the original model ID and no alias value. +// +// Example: modelID="gpt-3.5-turbo", no alias match +// → [{key:"gpt-3.5-turbo", value:""}] +func (p *ListModelsPipeline) resolveModelID(modelID string) []aliasMatch { + var candidates []aliasMatch + for aliasKey, providerID := range p.Aliases { + if matches(modelID, providerID, p.MatchFns) { + candidates = append(candidates, aliasMatch{key: aliasKey, value: providerID}) + } + } + if len(candidates) == 0 { + return []aliasMatch{{key: modelID, value: ""}} + } + // Also include the raw model ID so both the alias key and the original name appear in output. + candidates = append(candidates, aliasMatch{key: modelID, value: ""}) + sort.Slice(candidates, func(i, j int) bool { + return strings.ToLower(candidates[i].key) < strings.ToLower(candidates[j].key) + }) + return candidates +} + +// FilterModel applies the full filter pipeline for a single model from the API response. +// +// Steps: +// 1. Resolve name — check alias VALUES for a match (uses MatchFns). +// If matched: resolvedName = alias KEY, aliasValue = provider ID. +// If not matched: resolvedName = original modelID, aliasValue = "". +// 2. Allowlist check (only when allowlist is restricted, i.e. 
not wildcard): +// Skip if resolvedName is not in AllowedModels. +// 3. Blacklist check (always): +// Skip if resolvedName is blacklisted. Blacklist takes precedence over everything. +// 4. Return one FilterResult per passing candidate. +// +// An empty slice means the model should be skipped entirely. +// When multiple aliases map to the same provider model ID, each alias that passes +// the filters produces its own FilterResult entry. +// +// Examples: +// +// allowedModels=["my-gpt4"], aliases={"my-gpt4":"gpt-4-turbo"}, blacklist=[] +// FilterModel("gpt-4-turbo") → [{ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}] +// FilterModel("gpt-3.5") → [] (not in allowlist) +// +// allowedModels=*, aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}, blacklist=[] +// FilterModel("gpt-4-turbo") → [{ResolvedID:"gpt-4-turbo", AliasValue:""}, +// {ResolvedID:"gpt4-alias", AliasValue:"gpt-4-turbo"}, +// {ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}] +// +// allowedModels=["gpt-3.5"], aliases={}, blacklist=[] +// FilterModel("gpt-3.5") → [{ResolvedID:"gpt-3.5", AliasValue:""}] +// FilterModel("gpt-4") → [] +func (p *ListModelsPipeline) FilterModel(modelID string) []FilterResult { + // Step 1: resolve name — collect all alias matches (or the raw ID if none match). + candidates := p.resolveModelID(modelID) + + var results []FilterResult + for _, candidate := range candidates { + resolvedName := candidate.key + + // Step 2: allowlist check. + // IsRestricted() is true for both an explicit list AND an empty list (deny-all). + // Only a wildcard allowlist marker bypasses this check (pass-through). + if !p.Unfiltered && p.AllowedModels.IsRestricted() { + allowed := false + for _, entry := range p.AllowedModels { + if matches(resolvedName, entry, p.MatchFns) { + allowed = true + break + } + } + if !allowed { + continue + } + } + + // Step 3: blacklist check — blacklist always wins regardless of allowlist or aliases. 
+ if !p.Unfiltered { + blacklisted := false + for _, entry := range p.BlacklistedModels { + if matches(resolvedName, entry, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + } + + results = append(results, FilterResult{ + ResolvedID: resolvedName, + AliasValue: candidate.value, + }) + } + return results +} + +// BackfillModels adds model entries that were configured by the caller but not +// returned by the provider's API response (or not matched during filtering). +// +// The `included` map tracks model IDs (lowercased) already added during the +// filter pass, used to avoid duplicates. +// +// Two cases depending on whether the allowlist is restricted: +// +// Case A — allowlist restricted (caller specified explicit model names): +// +// Add each allowlist entry that is not yet in `included`, skip if blacklisted. +// If the entry has an alias mapping (aliases[entry] exists), set Alias to the +// provider-specific ID so callers can route to the right model. +// +// Example: allowedModels=["my-gpt4","gpt-3.5"], aliases={"my-gpt4":"gpt-4-turbo"} +// "my-gpt4" not in included → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"} +// "gpt-3.5" not in included → add {ID:"openai/gpt-3.5"} +// +// Case B — allowlist wildcard (*) only: +// +// We don't know all model names (no explicit list), so we only backfill entries +// that were explicitly configured via aliases and not yet matched from the API. +// Note: an empty allowlist is deny-all (IsRestricted()==true), not wildcard. +// +// Example: aliases={"my-gpt4":"gpt-4-turbo"}, "my-gpt4" not in included +// → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"} +// +// Blacklist always wins — nothing blacklisted is added in either case. +func (p *ListModelsPipeline) BackfillModels(included map[string]bool) []schemas.Model { + var result []schemas.Model + + if !p.Unfiltered && p.AllowedModels.IsRestricted() { + // Case A: backfill explicit allowlist entries not yet matched. 
+ for _, entry := range p.AllowedModels { + if included[strings.ToLower(entry)] { + continue + } + // Blacklist check. + blacklisted := false + for _, bl := range p.BlacklistedModels { + if matches(entry, bl, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + m := schemas.Model{ + ID: string(p.ProviderKey) + "/" + entry, + Name: schemas.Ptr(ToDisplayName(entry)), + } + // If this allowlist entry has an alias, surface the provider-specific ID. + for aliasKey, providerID := range p.Aliases { + if matches(entry, aliasKey, p.MatchFns) { + m.Alias = schemas.Ptr(providerID) + break + } + } + result = append(result, m) + } + return result + } + + // Case B: wildcard allowlist — backfill only explicitly configured aliases. + if !p.Unfiltered && len(p.Aliases) > 0 { + for aliasKey, providerID := range p.Aliases { + if included[strings.ToLower(aliasKey)] { + continue + } + // Blacklist check. + blacklisted := false + for _, bl := range p.BlacklistedModels { + if matches(aliasKey, bl, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + result = append(result, schemas.Model{ + ID: string(p.ProviderKey) + "/" + aliasKey, + Name: schemas.Ptr(ToDisplayName(aliasKey)), + Alias: schemas.Ptr(providerID), + }) + } + } + + return result +} diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index efd1ea992b..748c83210b 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -178,12 +178,12 @@ func MakeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *f } // Check for timeout errors first before checking net.OpError to avoid misclassification if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, ""), noop + return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop } // Check if error implements net.Error and 
has Timeout() == true
 	var netErr net.Error
 	if errors.As(err, &netErr) && netErr.Timeout() {
-		return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, ""), noop
+		return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop
 	}
 	// Check for DNS lookup and network errors after timeout checks
 	var opErr *net.OpError
@@ -1043,7 +1043,7 @@ func MergeExtraParamsIntoJSON(jsonBody []byte, extraParams map[string]interface{
 }
 
 // CheckContextAndGetRequestBody checks if the raw request body should be used, and returns it if it exists.
-func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter, providerType schemas.ModelProvider) ([]byte, *schemas.BifrostError) {
+func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter) ([]byte, *schemas.BifrostError) {
 	if IsLargePayloadPassthroughEnabled(ctx) {
 		return nil, nil
 	}
@@ -1052,15 +1052,15 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette
 	if !ok {
 		convertedBody, err := requestConverter()
 		if err != nil {
-			return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerType)
+			return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err)
 		}
 		if convertedBody == nil {
-			return nil, NewBifrostOperationError("request body is not provided", nil, providerType)
+			return nil, NewBifrostOperationError("request body is not provided", nil)
 		}
 		jsonBody, err := MarshalSortedIndent(convertedBody, "", " ")
 		if err != nil {
-			return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType)
+			return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 		// Merge ExtraParams into the JSON if passthrough is enabled
 		if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
@@ -1070,7 +1070,7 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette
 			// tool schemas and other order-sensitive JSON structures.
 			jsonBody, err = MergeExtraParamsIntoJSON(jsonBody, extraParams)
 			if err != nil {
-				return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType)
+				return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		}
 	}
@@ -1367,10 +1367,6 @@ func NewUnsupportedOperationError(requestType schemas.RequestType, providerName
 			Message: fmt.Sprintf("%s is not supported by %s provider", requestType, providerName),
 			Code:    schemas.Ptr("unsupported_operation"),
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider:    providerName,
-			RequestType: requestType,
-		},
 	}
 }
@@ -1593,37 +1589,31 @@ func ParseJSONL(data []byte, parseLine func(line []byte) error) JSONLParseResult
 
 // NewConfigurationError creates a standardized error for configuration errors.
 // This helper reduces code duplication across providers that have configuration errors.
-func NewConfigurationError(message string, providerType schemas.ModelProvider) *schemas.BifrostError {
+func NewConfigurationError(message string) *schemas.BifrostError {
 	return &schemas.BifrostError{
 		IsBifrostError: false,
 		Error: &schemas.ErrorField{
 			Message: message,
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider: providerType,
-		},
 	}
 }
 
 // NewBifrostOperationError creates a standardized error for bifrost operation errors.
 // This helper reduces code duplication across providers that have bifrost operation errors.
-func NewBifrostOperationError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError {
+func NewBifrostOperationError(message string, err error) *schemas.BifrostError {
 	return &schemas.BifrostError{
 		IsBifrostError: true,
 		Error: &schemas.ErrorField{
 			Message: message,
 			Error:   err,
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider: providerType,
-		},
 	}
 }
 
 // NewBifrostTimeoutError creates a standardized error for provider request timeout errors.
 // Sets StatusCode to 504 (Gateway Timeout) and Error.Type to RequestTimedOut,
 // consistent with HandleStreamTimeout for streaming requests.
-func NewBifrostTimeoutError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError {
+func NewBifrostTimeoutError(message string, err error) *schemas.BifrostError {
 	statusCode := 504
 	errorType := schemas.RequestTimedOut
 	return &schemas.BifrostError{
@@ -1634,15 +1624,12 @@ func NewBifrostTimeoutError(message string, err error, providerType schemas.Mode
 			Type:  &errorType,
 			Error: err,
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider: providerType,
-		},
 	}
 }
 
 // NewProviderAPIError creates a standardized error for provider API errors.
 // This helper reduces code duplication across providers that have provider API errors.
-func NewProviderAPIError(message string, err error, statusCode int, providerType schemas.ModelProvider, errorType *string, eventID *string) *schemas.BifrostError {
+func NewProviderAPIError(message string, err error, statusCode int, errorType *string, eventID *string) *schemas.BifrostError {
 	return &schemas.BifrostError{
 		IsBifrostError: false,
 		StatusCode:     &statusCode,
@@ -1653,61 +1640,43 @@ func NewProviderAPIError(message string, err error, statusCode int, providerType
 			Error: err,
 			Type:  errorType,
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider: providerType,
-		},
 	}
 }
 
-// RequestMetadata contains metadata about a request for error reporting.
-// This struct is used to pass request context to parseError functions.
-type RequestMetadata struct {
-	Provider    schemas.ModelProvider
-	Model       string
-	RequestType schemas.RequestType
-}
-
-// ShouldSendBackRawRequest checks if the raw request should be captured.
-// Context overrides are intentionally restricted to asymmetric behavior: a context value can only
-// promote false→true and will not override a true config to false, avoiding accidental suppression.
-// Both full send-back mode and logging-only mode (store_raw_request_response) set
-// BifrostContextKeySendBackRawRequest=true in the request context so a single flag is checked here.
-// In logging-only mode the payload is stripped before the response reaches the client.
+// ShouldSendBackRawRequest checks if raw request bytes should be captured.
+// bifrost.go always writes BifrostContextKeyCaptureRawRequest before provider dispatch,
+// combining provider config, per-request overrides, and store_raw_request_response.
+// The default parameter is a fallback for callers outside the normal bifrost dispatch path.
 func ShouldSendBackRawRequest(ctx context.Context, defaultSendBackRawRequest bool) bool {
-	if sendBackRawRequest, ok := ctx.Value(schemas.BifrostContextKeySendBackRawRequest).(bool); ok && sendBackRawRequest {
-		return sendBackRawRequest
+	if capture, ok := ctx.Value(schemas.BifrostContextKeyCaptureRawRequest).(bool); ok {
+		return capture
 	}
 	return defaultSendBackRawRequest
 }
 
-// ShouldSendBackRawResponse checks if the raw response should be captured.
-// Context overrides are intentionally restricted to asymmetric behavior: a context value can only
-// promote false→true and will not override a true config to false, avoiding accidental suppression.
-// Both full send-back mode and logging-only mode (store_raw_request_response) set
-// BifrostContextKeySendBackRawResponse=true in the request context so a single flag is checked here.
-// In logging-only mode the payload is stripped before the response reaches the client.
+// ShouldSendBackRawResponse checks if raw response bytes should be captured.
+// bifrost.go always writes BifrostContextKeyCaptureRawResponse before provider dispatch,
+// combining provider config, per-request overrides, and store_raw_request_response.
+// The default parameter is a fallback for callers outside the normal bifrost dispatch path.
 func ShouldSendBackRawResponse(ctx context.Context, defaultSendBackRawResponse bool) bool {
-	if sendBackRawResponse, ok := ctx.Value(schemas.BifrostContextKeySendBackRawResponse).(bool); ok && sendBackRawResponse {
-		return sendBackRawResponse
+	if capture, ok := ctx.Value(schemas.BifrostContextKeyCaptureRawResponse).(bool); ok {
+		return capture
 	}
 	return defaultSendBackRawResponse
 }
 
 // SendCreatedEventResponsesChunk sends a ResponsesStreamResponseTypeCreated event.
-func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) {
+func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) {
 	firstChunk := &schemas.BifrostResponsesStreamResponse{
 		Type:           schemas.ResponsesStreamResponseTypeCreated,
 		SequenceNumber: 0,
 		Response:       &schemas.BifrostResponsesResponse{},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:    schemas.ResponsesStreamRequest,
-			Provider:       provider,
-			ModelRequested: model,
-			ChunkIndex:     0,
-			Latency:        time.Since(startTime).Milliseconds(),
+			ChunkIndex: 0,
+			Latency:    time.Since(startTime).Milliseconds(),
 		},
 	}
-	//TODO add bifrost response pooling here
+	// TODO add bifrost response pooling here
 	bifrostResponse := &schemas.BifrostResponse{
 		ResponsesStreamResponse: firstChunk,
 	}
@@ -1715,20 +1684,17 @@ func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner
 }
 
 // SendInProgressEventResponsesChunk sends a ResponsesStreamResponseTypeInProgress event
-func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) {
+func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) {
 	chunk := &schemas.BifrostResponsesStreamResponse{
 		Type:           schemas.ResponsesStreamResponseTypeInProgress,
 		SequenceNumber: 1,
 		Response:       &schemas.BifrostResponsesResponse{},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:    schemas.ResponsesStreamRequest,
-			Provider:       provider,
-			ModelRequested: model,
-			ChunkIndex:     1,
-			Latency:        time.Since(startTime).Milliseconds(),
+			ChunkIndex: 1,
+			Latency:    time.Since(startTime).Milliseconds(),
 		},
 	}
-	//TODO add bifrost response pooling here
+	// TODO add bifrost response pooling here
 	bifrostResponse := &schemas.BifrostResponse{
 		ResponsesStreamResponse: chunk,
 	}
@@ -1736,13 +1702,14 @@ func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunn
 }
 
 // BuildClientStreamChunk constructs a BifrostStreamChunk from post-hook results.
-// It never mutates the shared processedResponse or processedError objects — when in
-// logging-only mode (BifrostContextKeyRawRequestResponseForLogging) it shallow-copies
-// each inner response struct and the BifrostError, nils only the raw fields on those
-// copies, and returns them as the outgoing chunk. This is safe for concurrent PostLLMHook
-// goroutines that still hold references to the originals.
+// It never mutates the shared processedResponse or processedError objects — when raw fields
+// need to be stripped (captured for storage but not for send-back), it shallow-copies each
+// inner response struct and nils only the appropriate per-side field on those copies.
+// This is safe for concurrent PostLLMHook goroutines that still hold references to the originals.
 func BuildClientStreamChunk(ctx context.Context, processedResponse *schemas.BifrostResponse, processedError *schemas.BifrostError) *schemas.BifrostStreamChunk {
-	dropRaw, _ := ctx.Value(schemas.BifrostContextKeyRawRequestResponseForLogging).(bool)
+	dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool)
+	dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool)
+	drop := dropReq || dropResp
 	streamResponse := &schemas.BifrostStreamChunk{}
 	if processedResponse != nil {
 		streamResponse.BifrostTextCompletionResponse = processedResponse.TextCompletionResponse
@@ -1753,51 +1720,79 @@ func BuildClientStreamChunk(ctx context.Context, processedResponse *schemas.Bifr
 		streamResponse.BifrostImageGenerationStreamResponse = processedResponse.ImageGenerationStreamResponse
 		// Strip raw fields from client-facing copies without mutating the shared objects
 		// that PostLLMHook goroutines may still be reading.
-		if dropRaw {
+		if drop {
 			if streamResponse.BifrostTextCompletionResponse != nil {
 				cp := *streamResponse.BifrostTextCompletionResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostTextCompletionResponse = &cp
 			}
 			if streamResponse.BifrostChatResponse != nil {
 				cp := *streamResponse.BifrostChatResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostChatResponse = &cp
 			}
 			if streamResponse.BifrostResponsesStreamResponse != nil {
 				cp := *streamResponse.BifrostResponsesStreamResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostResponsesStreamResponse = &cp
 			}
 			if streamResponse.BifrostSpeechStreamResponse != nil {
 				cp := *streamResponse.BifrostSpeechStreamResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostSpeechStreamResponse = &cp
 			}
 			if streamResponse.BifrostTranscriptionStreamResponse != nil {
 				cp := *streamResponse.BifrostTranscriptionStreamResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostTranscriptionStreamResponse = &cp
 			}
 			if streamResponse.BifrostImageGenerationStreamResponse != nil {
 				cp := *streamResponse.BifrostImageGenerationStreamResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostImageGenerationStreamResponse = &cp
 			}
 		}
 	}
 	if processedError != nil {
-		if dropRaw {
+		if drop {
 			// Strip raw fields from a client-facing copy without mutating the shared error object.
 			errCopy := *processedError
-			errCopy.ExtraFields.RawRequest = nil
-			errCopy.ExtraFields.RawResponse = nil
+			if dropReq {
+				errCopy.ExtraFields.RawRequest = nil
+			}
+			if dropResp {
+				errCopy.ExtraFields.RawResponse = nil
+			}
			streamResponse.BifrostError = &errCopy
 		} else {
 			streamResponse.BifrostError = processedError
@@ -2050,9 +2045,6 @@ func HandleStreamCancellation(
 	ctx *schemas.BifrostContext,
 	postHookRunner schemas.PostHookRunner,
 	responseChan chan *schemas.BifrostStreamChunk,
-	provider schemas.ModelProvider,
-	model string,
-	requestType schemas.RequestType,
 	logger schemas.Logger,
 ) {
 	// Check if already handled (StreamEndIndicator already set)
@@ -2068,11 +2060,6 @@ func HandleStreamCancellation(
 			Message: "Request cancelled: client disconnected",
 			Type:    schemas.Ptr(schemas.RequestCancelled),
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider:       provider,
-			ModelRequested: model,
-			RequestType:    requestType,
-		},
 	}
 
 	// Send through PostHook chain - this updates the log to "error" status
@@ -2091,9 +2078,6 @@ func HandleStreamTimeout(
 	ctx *schemas.BifrostContext,
 	postHookRunner schemas.PostHookRunner,
 	responseChan chan *schemas.BifrostStreamChunk,
-	provider schemas.ModelProvider,
-	model string,
-	requestType schemas.RequestType,
 	logger schemas.Logger,
 ) {
 	// Check if already handled (StreamEndIndicator already set)
@@ -2109,11 +2093,6 @@ func HandleStreamTimeout(
 			Message: "Request timed out: deadline exceeded",
 			Type:    schemas.Ptr(schemas.RequestTimedOut),
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider:       provider,
-			ModelRequested: model,
-			RequestType:    requestType,
-		},
 	}
 
 	// Send through PostHook chain - this updates the log to "error" status
@@ -2129,25 +2108,16 @@ func ProcessAndSendError(
 	postHookRunner schemas.PostHookRunner,
 	err error,
 	responseChan chan *schemas.BifrostStreamChunk,
-	requestType schemas.RequestType,
-	providerName schemas.ModelProvider,
-	model string,
 	logger schemas.Logger,
 ) {
 	// Send scanner error through channel
-	bifrostError :=
-		&schemas.BifrostError{
-			IsBifrostError: true,
-			Error: &schemas.ErrorField{
-				Message: fmt.Sprintf("Error reading stream: %v", err),
-				Error:   err,
-			},
-			ExtraFields: schemas.BifrostErrorExtraFields{
-				RequestType:    requestType,
-				Provider:       providerName,
-				ModelRequested: model,
-			},
-		}
+	bifrostError := &schemas.BifrostError{
+		IsBifrostError: true,
+		Error: &schemas.ErrorField{
+			Message: fmt.Sprintf("Error reading stream: %v", err),
+			Error:   err,
+		},
+	}
 
 	processedResponse, processedError := postHookRunner(ctx, nil, bifrostError)
 	if HandleStreamControlSkip(processedError) {
@@ -2179,8 +2149,6 @@ func CreateBifrostTextCompletionChunkResponse(
 	finishReason *string,
 	currentChunkIndex int,
 	requestType schemas.RequestType,
-	providerName schemas.ModelProvider,
-	model string,
 ) *schemas.BifrostTextCompletionResponse {
 	response := &schemas.BifrostTextCompletionResponse{
 		ID: id,
@@ -2193,10 +2161,7 @@ func CreateBifrostTextCompletionChunkResponse(
 			},
 		},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:    requestType,
-			Provider:       providerName,
-			ModelRequested: model,
-			ChunkIndex:     currentChunkIndex + 1,
+			ChunkIndex: currentChunkIndex + 1,
 		},
 	}
 	return response
@@ -2208,14 +2173,15 @@ func CreateBifrostChatCompletionChunkResponse(
 	usage *schemas.BifrostLLMUsage,
 	finishReason *string,
 	currentChunkIndex int,
-	requestType schemas.RequestType,
-	providerName schemas.ModelProvider,
 	model string,
+	created int,
 ) *schemas.BifrostChatResponse {
 	response := &schemas.BifrostChatResponse{
-		ID:     id,
-		Object: "chat.completion.chunk",
-		Usage:  usage,
+		ID:      id,
+		Model:   model,
+		Created: created,
+		Object:  "chat.completion.chunk",
+		Usage:   usage,
 		Choices: []schemas.BifrostResponseChoice{
 			{
 				FinishReason: finishReason,
@@ -2225,10 +2191,7 @@ func CreateBifrostChatCompletionChunkResponse(
 			},
 		},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:    requestType,
-			Provider:       providerName,
-			ModelRequested: model,
-			ChunkIndex:     currentChunkIndex + 1,
+			ChunkIndex: currentChunkIndex + 1,
 		},
 	}
 	return response
@@ -2316,7 +2279,7 @@ func GetBifrostResponseForStreamResponse(
 	transcriptionStreamResponse *schemas.BifrostTranscriptionStreamResponse,
 	imageGenerationStreamResponse *schemas.BifrostImageGenerationStreamResponse,
 ) *schemas.BifrostResponse {
-	//TODO add bifrost response pooling here
+	// TODO add bifrost response pooling here
 	bifrostResponse := &schemas.BifrostResponse{}
 
 	switch {
@@ -2398,10 +2361,7 @@ func aggregateListModelsResponses(responses []*schemas.BifrostListModelsResponse
 
 // extractSuccessfulListModelsResponses extracts successful responses from a results channel
 // and tracks per-key status information. This utility reduces code duplication across providers
 // for handling multi-key ListModels requests.
-func extractSuccessfulListModelsResponses(
-	results chan schemas.ListModelsByKeyResult,
-	providerName schemas.ModelProvider,
-) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) {
+func extractSuccessfulListModelsResponses(results chan schemas.ListModelsByKeyResult, provider schemas.ModelProvider) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) {
 	var successfulResponses []*schemas.BifrostListModelsResponse
 	var keyStatuses []schemas.KeyStatus
 	var lastError *schemas.BifrostError
@@ -2419,7 +2379,7 @@ func extractSuccessfulListModelsResponses(
 			getLogger().Warn(fmt.Sprintf("failed to list models with key %s: %s", result.KeyID, errMsg))
 			keyStatuses = append(keyStatuses, schemas.KeyStatus{
 				KeyID:    result.KeyID,
-				Provider: providerName,
+				Provider: provider,
 				Status:   schemas.KeyStatusListModelsFailed,
 				Error:    result.Err,
 			})
@@ -2429,7 +2389,7 @@ func extractSuccessfulListModelsResponses(
 			keyStatuses = append(keyStatuses, schemas.KeyStatus{
 				KeyID:    result.KeyID,
-				Provider: providerName,
+				Provider: provider,
 				Status:   schemas.KeyStatusSuccess,
 			})
 			successfulResponses = append(successfulResponses, result.Response)
@@ -2444,10 +2404,6 @@ func extractSuccessfulListModelsResponses(
 			Error: &schemas.ErrorField{
 				Message: "all keys failed to list models",
 			},
-			ExtraFields: schemas.BifrostErrorExtraFields{
-				Provider:    providerName,
-				RequestType: schemas.ListModelsRequest,
-			},
 		}
 	}
@@ -2505,6 +2461,21 @@ func HandleMultipleListModelsRequests(
 		wg.Add(1)
 		go func(k schemas.Key) {
 			defer wg.Done()
+			// Should never panic, but if it does, we need to handle it gracefully
+			defer func() {
+				if r := recover(); r != nil {
+					getLogger().Error("panic in listModelsByKey for key %s (%s): %v", k.Name, k.ID, r)
+					results <- schemas.ListModelsByKeyResult{
+						Err: &schemas.BifrostError{
+							IsBifrostError: true,
+							Error: &schemas.ErrorField{
+								Message: "internal error while listing models for key",
+							},
+						},
+						KeyID: k.ID,
+					}
+				}
+			}()
 			resp, bifrostErr := listModelsByKey(ctx, k, request)
 			results <- schemas.ListModelsByKeyResult{Response: resp, Err: bifrostErr, KeyID: k.ID}
 		}(key)
@@ -2530,8 +2501,6 @@ func HandleMultipleListModelsRequests(
 	// Set ExtraFields
 	latency := time.Since(startTime)
-	response.ExtraFields.Provider = request.Provider
-	response.ExtraFields.RequestType = schemas.ListModelsRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 
 	return response, nil
@@ -2691,10 +2660,10 @@ func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostRe
 	if accumulatedResp != nil {
 		// Use accumulated response for attributes (includes full content, tool calls, etc.)
-		tracer.PopulateLLMResponseAttributes(handle, accumulatedResp, err)
+		tracer.PopulateLLMResponseAttributes(ctx, handle, accumulatedResp, err)
 	} else if result != nil {
 		// Fall back to final chunk if no accumulated data (shouldn't happen normally)
-		tracer.PopulateLLMResponseAttributes(handle, result, err)
+		tracer.PopulateLLMResponseAttributes(ctx, handle, result, err)
 	}
 
 	// Finalize aggregated post-hook spans before ending the LLM span
diff --git a/core/providers/utils/utils_test.go b/core/providers/utils/utils_test.go
index 5ca567fce0..e832980f4f 100644
--- a/core/providers/utils/utils_test.go
+++ b/core/providers/utils/utils_test.go
@@ -1107,7 +1107,8 @@ func TestBuildClientStreamChunk_ImageGenerationStripping(t *testing.T) {
 
 	t.Run("logging-only: raw fields stripped from image gen chunk, original preserved", func(t *testing.T) {
 		ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
-		ctx.SetValue(schemas.BifrostContextKeyRawRequestResponseForLogging, true)
+		ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
+		ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
 
 		chunk := BuildClientStreamChunk(ctx, response, nil)
 		if chunk.BifrostImageGenerationStreamResponse == nil {
@@ -1148,9 +1149,9 @@ func TestBuildClientStreamChunk_ImageGenerationStripping(t *testing.T) {
 }
 
 // TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChunk verifies
-// that when BifrostContextKeyRawRequestResponseForLogging is set, ProcessAndSendResponse
-// strips RawRequest and RawResponse from the outgoing stream chunk, while leaving other
-// ExtraFields intact. It also verifies that the original BifrostResponse is not mutated
+// that when drop-raw context flags are set, ProcessAndSendResponse strips RawRequest and
+// RawResponse from the outgoing stream chunk, while leaving other ExtraFields intact.
+// It also verifies that the original BifrostResponse is not mutated
 // (shared object safety for PostLLMHook goroutines).
 func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChunk(t *testing.T) {
 	rawReq := json.RawMessage(`{"model":"gpt-4","messages":[]}`)
@@ -1177,7 +1178,8 @@ func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChu
 		t.Run(tt.name, func(t *testing.T) {
 			ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
 			if tt.loggingOnly {
-				ctx.SetValue(schemas.BifrostContextKeyRawRequestResponseForLogging, true)
+				ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
+				ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
 			}
 
 			response := &schemas.BifrostResponse{
@@ -1237,9 +1239,9 @@ func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChu
 }
 
 // TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk verifies
-// that when BifrostContextKeyRawRequestResponseForLogging is set, raw data is stripped
-// from BifrostError payloads embedded in stream chunks, without mutating the shared
-// BifrostError object (shared object safety for PostLLMHook goroutines).
+// that when drop-raw context flags are set, raw data is stripped from BifrostError
+// payloads embedded in stream chunks, without mutating the shared BifrostError object
+// (shared object safety for PostLLMHook goroutines).
 func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk(t *testing.T) {
 	rawReq := json.RawMessage(`{"model":"gpt-4"}`)
 	rawResp := json.RawMessage(`{"error":"rate limit exceeded"}`)
@@ -1265,7 +1267,8 @@ func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk(
 		t.Run(tt.name, func(t *testing.T) {
 			ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
 			if tt.loggingOnly {
-				ctx.SetValue(schemas.BifrostContextKeyRawRequestResponseForLogging, true)
+				ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
+				ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
 			}
 
 			// Use a postHookRunner that converts the response to a BifrostError with raw data
diff --git a/core/providers/vertex/embedding.go b/core/providers/vertex/embedding.go
index 0fc0ad598f..54662f50fe 100644
--- a/core/providers/vertex/embedding.go
+++ b/core/providers/vertex/embedding.go
@@ -110,8 +110,6 @@ func (response *VertexEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.B
 		Data:  embeddings,
 		Usage: usage,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.EmbeddingRequest,
-			Provider:    schemas.Vertex,
 		},
 	}
 }
diff --git a/core/providers/vertex/errors.go b/core/providers/vertex/errors.go
index 6b255835d4..e0ed7f1d3d 100644
--- a/core/providers/vertex/errors.go
+++ b/core/providers/vertex/errors.go
@@ -10,25 +10,13 @@ import (
 	"github.com/valyala/fasthttp"
 )
 
-func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
-	var providerName schemas.ModelProvider
-	if meta != nil {
-		providerName = meta.Provider
-	}
-
+func parseVertexError(resp *fasthttp.Response) *schemas.BifrostError {
 	var openAIErr schemas.BifrostError
 	var vertexErr []VertexError
 
 	decodedBody, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
-		if meta != nil {
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       meta.Provider,
-				ModelRequested: meta.Model,
-				RequestType:    meta.RequestType,
-			}
-		}
+		bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 		return bifrostErr
 	}
@@ -42,13 +30,6 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 				Message: schemas.ErrProviderResponseEmpty,
 			},
 		}
-		if meta != nil {
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       meta.Provider,
-				ModelRequested: meta.Model,
-				RequestType:    meta.RequestType,
-			}
-		}
 		return bifrostErr
 	}
@@ -61,26 +42,20 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 				Message: schemas.ErrProviderResponseHTML,
 				Error:   errors.New(string(decodedBody)),
 			},
-		}
-		if meta != nil {
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       meta.Provider,
-				ModelRequested: meta.Model,
-				RequestType:    meta.RequestType,
-			}
+			ExtraFields: schemas.BifrostErrorExtraFields{
+				RawResponse: string(decodedBody),
+			},
 		}
 		return bifrostErr
 	}
 
 	createError := func(message string) *schemas.BifrostError {
-		bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), providerName, nil, nil)
-		if meta != nil {
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       meta.Provider,
-				ModelRequested: meta.Model,
-				RequestType:    meta.RequestType,
-			}
+		bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), nil, nil)
+		var rawResponse interface{}
+		if err := sonic.Unmarshal(decodedBody, &rawResponse); err != nil {
+			rawResponse = string(decodedBody)
 		}
+		bifrostErr.ExtraFields.RawResponse = rawResponse
 		return bifrostErr
 	}
@@ -93,14 +68,7 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 	// Try VertexValidationError format (validation errors from Mistral endpoint)
 	var validationErr VertexValidationError
 	if err := sonic.Unmarshal(decodedBody, &validationErr); err != nil {
-		bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
-		if meta != nil {
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       meta.Provider,
-				ModelRequested: meta.Model,
-				RequestType:    meta.RequestType,
-			}
-		}
+		bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 		return bifrostErr
 	}
 	if len(validationErr.Detail) > 0 {
diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go
index 28b5598022..2fbe83979d 100644
--- a/core/providers/vertex/models.go
+++ b/core/providers/vertex/models.go
@@ -1,12 +1,10 @@
 package vertex
 
 import (
-	"slices"
 	"strings"
 
+	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
 	"github.com/maximhq/bifrost/core/schemas"
-	"golang.org/x/text/cases"
-	"golang.org/x/text/language"
 )
 
 // VertexRankRequest represents the Discovery Engine rank API request.
@@ -56,49 +54,6 @@ type vertexRerankOptions struct {
 	UserLabels map[string]string
 }
 
-// formatDeploymentName converts a deployment alias into a human-readable name.
-// It splits the alias by "-" or "_", capitalizes each word, and joins them with spaces.
-// Example: "gemini-pro" → "Gemini Pro", "claude_3_opus" → "Claude 3 Opus"
-func formatDeploymentName(alias string) string {
-	caser := cases.Title(language.English)
-
-	// Try splitting by hyphen first, then underscore
-	var parts []string
-	if strings.Contains(alias, "-") {
-		parts = strings.Split(alias, "-")
-	} else if strings.Contains(alias, "_") {
-		parts = strings.Split(alias, "_")
-	} else {
-		// No delimiter found, just capitalize the whole string
-		return caser.String(strings.ToLower(alias))
-	}
-
-	// Capitalize each part
-	for i, part := range parts {
-		if part != "" {
-			parts[i] = caser.String(strings.ToLower(part))
-		}
-	}
-
-	return strings.Join(parts, " ")
-}
-
-// findDeploymentMatch finds a matching deployment value in the deployments map.
-// Returns the deployment value and alias if found, empty strings otherwise.
-func findDeploymentMatch(deployments map[string]string, customModelID string) (deploymentValue, alias string) {
-	// Check exact match by deployment value
-	for aliasKey, depValue := range deployments {
-		if depValue == customModelID {
-			return depValue, aliasKey
-		}
-	}
-	// Check exact match by alias/key
-	if deployment, ok := deployments[customModelID]; ok {
-		return deployment, customModelID
-	}
-	return "", ""
-}
-
 // ToBifrostListModelsResponse converts a Vertex AI list models response to Bifrost's format.
 // It processes both custom models (from the API response) and non-custom models (from deployments and allowedModels).
 //
@@ -114,7 +69,7 @@ func findDeploymentMatch(deployments map[string]string, customModelID string) (d
 // - If allowedModels is empty, all models are allowed
 // - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included
 // - Deployments map is used to match model IDs to aliases and filter accordingly
-func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
+func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -123,10 +78,22 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod
 		Data: make([]schemas.Model, 0, len(response.Models)),
 	}
 
-	// Track which model IDs have been added to avoid duplicates
-	addedModelIDs := make(map[string]bool)
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       schemas.Vertex,
+		MatchFns:          providerUtils.DefaultMatchFns(),
+	}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
+	}
+
+	included := make(map[string]bool)
 
-	// First pass: Process all models from the Vertex AI API response (custom models)
+	// Process all models from the Vertex AI API response (custom deployed models).
+	// The model ID is extracted from the endpoint URL last segment.
 	for _, model := range response.Models {
 		if len(model.DeployedModels) == 0 {
 			continue
@@ -142,110 +109,28 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod
 			continue
 		}
 
-		// Filter if model is not present in both lists (when both are non-empty)
-		// Empty lists mean "allow all" for that dimension
-		var deploymentValue, deploymentAlias string
-		shouldFilter := false
-		if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 {
-			// Both lists are present: model must be in allowedModels AND deployments
-			// AND the deployment alias must also be in allowedModels
-			deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID)
-			inDeployments := deploymentAlias != ""
-
-			// Check if deployment alias is also in allowedModels (direct string match)
-			deploymentAliasInAllowedModels := false
-			if deploymentAlias != "" {
-				deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias)
+		for _, result := range pipeline.FilterModel(customModelID) {
+			resolvedKey := strings.ToLower(result.ResolvedID)
+			if included[resolvedKey] {
+				continue
 			}
-
-			// Filter if: model not in deployments OR deployment alias not in allowedModels
-			shouldFilter = !inDeployments || !deploymentAliasInAllowedModels
-		} else if !unfiltered && len(allowedModels) > 0 {
-			// Only allowedModels is present: filter if model is not in allowedModels
-			shouldFilter = !slices.Contains(allowedModels, customModelID)
-		} else if !unfiltered && len(deployments) > 0 {
-			// Only deployments is present: filter if model is not in deployments
-			deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID)
-			shouldFilter = deploymentValue == ""
-		}
-		// If both are empty, shouldFilter remains false (allow all)
-
-		if shouldFilter {
-			continue
-		}
-
-		modelID := customModelID
-
-		if !unfiltered && (slices.Contains(blacklistedModels, customModelID) || slices.Contains(blacklistedModels, deploymentAlias)) {
-			continue
-		}
-
-		modelEntry := schemas.Model{
-			ID:          string(schemas.Vertex) + "/" + modelID,
-			Name:        schemas.Ptr(model.DisplayName),
-			Description: schemas.Ptr(model.Description),
-			Created:     schemas.Ptr(model.VersionCreateTime.Unix()),
-		}
-		// Set deployment info if matched via deployments
-		if deploymentValue != "" && deploymentAlias != "" {
-			modelEntry.ID = string(schemas.Vertex) + "/" + deploymentAlias
-			modelEntry.Deployment = schemas.Ptr(deploymentValue)
-		}
-		bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
-		addedModelIDs[modelEntry.ID] = true
-	}
-}
-
-	// Second pass: Backfill deployments that were not matched from the API response
-	if !unfiltered && len(deployments) > 0 {
-		for alias, deploymentValue := range deployments {
-			// Check if model already exists in the list
-			modelID := string(schemas.Vertex) + "/" + alias
-			if addedModelIDs[modelID] {
-				continue
-			}
-			// If allowedModels is non-empty, only include if alias is in the list
-			if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) {
-				continue
-			}
-			if slices.Contains(blacklistedModels, alias) {
-				continue
-			}
-
-			modelName := formatDeploymentName(alias)
-			modelEntry := schemas.Model{
-				ID:         modelID,
-				Name:       schemas.Ptr(modelName),
-				Deployment: schemas.Ptr(deploymentValue),
+			modelEntry := schemas.Model{
+				ID:          string(schemas.Vertex) + "/" + result.ResolvedID,
+				Name:        schemas.Ptr(model.DisplayName),
+				Description: schemas.Ptr(model.Description),
+				Created:     schemas.Ptr(model.VersionCreateTime.Unix()),
+			}
+			if result.AliasValue != "" {
+				modelEntry.Alias = schemas.Ptr(result.AliasValue)
+			}
+			bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
+			included[resolvedKey] = true
 			}
-
-			bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
-			addedModelIDs[modelID] = true
-		}
 	}
 	}
 
-	// Third pass: Backfill allowed models that were not in the response or deployments
-	if !unfiltered && len(allowedModels) > 0 {
-		for _, allowedModel := range allowedModels {
-			// Check if model already exists in the list
-			modelID := string(schemas.Vertex) + "/" + allowedModel
-			if addedModelIDs[modelID] {
-				continue
-			}
-			if slices.Contains(blacklistedModels, allowedModel) {
-				continue
-			}
-
-			modelName := formatDeploymentName(allowedModel)
-			modelEntry := schemas.Model{
-				ID:   modelID,
-				Name: schemas.Ptr(modelName),
-			}
-
-			bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
-			addedModelIDs[modelID] = true
-		}
-	}
+	bifrostResponse.Data = append(bifrostResponse.Data,
+		pipeline.BackfillModels(included)...)
 
 	bifrostResponse.NextPageToken = response.NextPageToken
@@ -254,7 +139,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod
 
 // ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format.
 // This is for foundation models from the Model Garden (publishers.models.list endpoint).
-func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
+func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -263,8 +148,19 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a
 		Data: make([]schemas.Model, 0, len(response.PublisherModels)),
 	}
 
-	// Track which model IDs have been added to avoid duplicates
-	addedModelIDs := make(map[string]bool)
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       schemas.Vertex,
+		MatchFns:          providerUtils.DefaultMatchFns(),
+	}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
+	}
+
+	included := make(map[string]bool)
 
 	for _, model := range response.PublisherModels {
 		// Extract model ID from name (format: "publishers/google/models/gemini-1.5-pro")
@@ -273,36 +169,28 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a
 			continue
 		}
 
-		// Filter based on allowedModels if specified
-		if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelID) {
-			continue
-		}
-		if !unfiltered && slices.Contains(blacklistedModels, modelID) {
-			continue
-		}
-
-		// Skip if already added (shouldn't happen, but safety check)
-		fullModelID := string(schemas.Vertex) + "/" + modelID
-		if addedModelIDs[fullModelID] {
-			continue
-		}
-
-		// Extract display name from supported actions if available
-		displayName := modelID
-		if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" {
-			displayName =
model.SupportedActions.Deploy.ModelDisplayName - } - - modelEntry := schemas.Model{ - ID: fullModelID, - Name: schemas.Ptr(displayName), + for _, result := range pipeline.FilterModel(modelID) { + // Extract display name from supported actions if available + displayName := result.ResolvedID + if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" { + displayName = model.SupportedActions.Deploy.ModelDisplayName + } + modelEntry := schemas.Model{ + ID: string(schemas.Vertex) + "/" + result.ResolvedID, + Name: schemas.Ptr(displayName), + } + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[strings.ToLower(result.ResolvedID)] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[fullModelID] = true } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + bifrostResponse.NextPageToken = response.NextPageToken return bifrostResponse -} +} \ No newline at end of file diff --git a/core/providers/vertex/rerank.go b/core/providers/vertex/rerank.go index 74372658b2..b06430fcac 100644 --- a/core/providers/vertex/rerank.go +++ b/core/providers/vertex/rerank.go @@ -83,7 +83,7 @@ func getVertexRerankOptions(projectID string, params *schemas.RerankParameters) } // ToVertexRankRequest converts a Bifrost rerank request to Discovery Engine rank API format. 
-func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, modelDeployment string, options *vertexRerankOptions) (*VertexRankRequest, error) { +func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, options *vertexRerankOptions) (*VertexRankRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost rerank request is nil") } @@ -132,7 +132,7 @@ func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, modelDeployme rankRequest.TopN = &topN } - if trimmedModel := strings.TrimSpace(modelDeployment); trimmedModel != "" { + if trimmedModel := strings.TrimSpace(bifrostReq.Model); trimmedModel != "" { rankRequest.Model = &trimmedModel } diff --git a/core/providers/vertex/rerank_test.go b/core/providers/vertex/rerank_test.go index afd8ed225e..3f2efcec52 100644 --- a/core/providers/vertex/rerank_test.go +++ b/core/providers/vertex/rerank_test.go @@ -42,7 +42,6 @@ func TestToVertexRankRequest(t *testing.T) { TopN: schemas.Ptr(10), }, }, - "semantic-ranker-default@latest", &vertexRerankOptions{ RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config", IgnoreRecordDetailsInResponse: true, @@ -77,7 +76,6 @@ func TestToVertexRankRequestTooManyRecords(t *testing.T) { Query: "q", Documents: docs, }, - "", &vertexRerankOptions{ RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config", IgnoreRecordDetailsInResponse: true, diff --git a/core/providers/vertex/types.go b/core/providers/vertex/types.go index 97d6de7fa2..bbdb89d17f 100644 --- a/core/providers/vertex/types.go +++ b/core/providers/vertex/types.go @@ -192,23 +192,23 @@ type VertexModelLabels struct { // These types are for the publishers.models.list endpoint (Model Garden) type VertexPublisherModel struct { - Name string `json:"name"` - VersionID string `json:"versionId"` - OpenSourceCategory string `json:"openSourceCategory"` - LaunchStage string `json:"launchStage"` - VersionState string `json:"versionState"` - 
PublisherModelTemplate string `json:"publisherModelTemplate"` - SupportedActions *VertexPublisherModelActions `json:"supportedActions"` + Name string `json:"name"` + VersionID string `json:"versionId"` + OpenSourceCategory string `json:"openSourceCategory"` + LaunchStage string `json:"launchStage"` + VersionState string `json:"versionState"` + PublisherModelTemplate string `json:"publisherModelTemplate"` + SupportedActions *VertexPublisherModelActions `json:"supportedActions"` } type VertexPublisherModelActions struct { - OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"` - OpenGenie *VertexPublisherModelURI `json:"openGenie"` - OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"` - OpenNotebook *VertexPublisherModelURI `json:"openNotebook"` - OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"` - Deploy *VertexPublisherModelDeploy `json:"deploy"` - OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"` + OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"` + OpenGenie *VertexPublisherModelURI `json:"openGenie"` + OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"` + OpenNotebook *VertexPublisherModelURI `json:"openNotebook"` + OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"` + Deploy *VertexPublisherModelDeploy `json:"deploy"` + OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"` } type VertexPublisherModelURI struct { diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go index 2e4e225f4d..0dfda6763d 100644 --- a/core/providers/vertex/utils.go +++ b/core/providers/vertex/utils.go @@ -9,7 +9,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, 
isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) { +func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader — skip all body building // (matches CheckContextAndGetRequestBody guard). if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -26,74 +26,74 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s if isCountTokens { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } else { // Add max_tokens if not present if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens)) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) 
+ return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Remap unsupported tool versions for Vertex (e.g., web_search_20260209 → web_search_20250305) jsonBody, err = anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex) if err != nil { - return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(err.Error(), nil) } // Add anthropic_version if not present if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { // Validate tools are supported by Vertex if request.Params != nil && request.Params.Tools != nil { if toolErr := anthropic.ValidateToolsForProvider(request.Params.Tools, schemas.Vertex); toolErr != nil { - return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil) } } // Convert request to Anthropic format reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } reqBody.Model = deployment @@ -109,44 +109,44 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s // Marshal struct to JSON bytes jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add anthropic_version if not present (using sjson to preserve order) if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } if isCountTokens { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, 
"max_tokens") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } else { // Remove model field for Vertex API (it's in URL) jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(providerExtraHeaders, ctx), schemas.Vertex, betaHeaderOverrides); len(betaHeaders) > 0 { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_beta", betaHeaders) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } @@ -178,29 +178,25 @@ func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID // buildResponseFromConfig builds a list models response from configured deployments and allowedModels. // This is used when the user has explicitly configured which models they want to use. 
-func buildResponseFromConfig(deployments map[string]string, allowedModels []string, blacklistedModels []string) *schemas.BifrostListModelsResponse { +func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList) *schemas.BifrostListModelsResponse { response := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } + if blacklistedModels.IsBlockAll() { + return response + } + addedModelIDs := make(map[string]bool) - // Build allowlist set for O(1) lookup - allowedSet := make(map[string]bool, len(allowedModels)) - for _, m := range allowedModels { - allowedSet[m] = true - } - blacklistedSet := make(map[string]bool, len(blacklistedModels)) - for _, m := range blacklistedModels { - blacklistedSet[m] = true - } + restrictAllowed := allowedModels.IsRestricted() // First add models from deployments (filtered by allowedModels when set) for alias, deploymentValue := range deployments { - if len(allowedSet) > 0 && !allowedSet[alias] { + if restrictAllowed && !allowedModels.Contains(alias) { continue } - if len(blacklistedSet) > 0 && blacklistedSet[alias] { + if blacklistedModels.IsBlocked(alias) { continue } modelID := string(schemas.Vertex) + "/" + alias @@ -208,28 +204,31 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels []stri continue } - modelName := formatDeploymentName(alias) + modelName := providerUtils.ToDisplayName(alias) modelEntry := schemas.Model{ - ID: modelID, - Name: schemas.Ptr(modelName), - Deployment: schemas.Ptr(deploymentValue), + ID: modelID, + Name: schemas.Ptr(modelName), + Alias: schemas.Ptr(deploymentValue), } response.Data = append(response.Data, modelEntry) addedModelIDs[modelID] = true } - // Then add models from allowedModels that aren't already in deployments + // Then add models from allowedModels that aren't already in deployments (only when restricted) + if !restrictAllowed { + return response + } for _, allowedModel := range 
allowedModels { - if len(blacklistedSet) > 0 && blacklistedSet[allowedModel] { - continue - } modelID := string(schemas.Vertex) + "/" + allowedModel if addedModelIDs[modelID] { continue } + if blacklistedModels.IsBlocked(allowedModel) { + continue + } - modelName := formatDeploymentName(allowedModel) + modelName := providerUtils.ToDisplayName(allowedModel) modelEntry := schemas.Model{ ID: modelID, Name: schemas.Ptr(modelName), diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 1f5682d692..9a4792eb6a 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -114,9 +114,6 @@ const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" // It uses the JWT config if auth credentials are provided. // It returns an error if the token source creation fails. func getAuthTokenSource(key schemas.Key) (oauth2.TokenSource, error) { - if key.VertexKeyConfig == nil { - return nil, fmt.Errorf("vertex key config is not set") - } authCredentials := key.VertexKeyConfig.AuthCredentials var tokenSource oauth2.TokenSource if authCredentials.GetValue() == "" { @@ -176,23 +173,21 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { // 1. If deployments or allowedModels are configured, return those (no API call needed) // 2. 
Otherwise, fetch from the publishers.models.list API endpoint (Model Garden) func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } - deployments := key.VertexKeyConfig.Deployments + deployments := key.Aliases allowedModels := key.Models + if !request.Unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || key.BlacklistedModels.IsBlockAll()) { + return &schemas.BifrostListModelsResponse{Data: make([]schemas.Model, 0)}, nil + } + // If deployments or allowedModels are configured, return those directly without API call // Skip this fast path when Unfiltered is set so the full Vertex catalog can be retrieved - if !request.Unfiltered && (len(deployments) > 0 || len(allowedModels) > 0) { + if !request.Unfiltered && (len(deployments) > 0 || allowedModels.IsRestricted()) { return buildResponseFromConfig(deployments, allowedModels, key.BlacklistedModels), nil } @@ -213,11 +208,11 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err) } token, err := tokenSource.Token() if err != nil { - return 
nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err) } // Iterate over all supported Vertex publishers to include Google, Anthropic, and Mistral models @@ -246,13 +241,14 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key _, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) if bifrostErr != nil { wait() + respBody := append([]byte(nil), resp.Body()...) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) // Non-Google publishers may not be available in all regions; skip on error if publisher != "google" { break } - return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) @@ -280,9 +276,9 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key var errorResp VertexError if err := sonic.Unmarshal(respBody, &errorResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, schemas.Vertex, nil, nil), nil, respBody, provider.sendBackRawRequest, 
provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse Vertex's publisher models response @@ -322,7 +318,7 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key PublisherModels: allPublisherModels, } - response := aggregatedResponse.ToBifrostListModelsResponse(nil, key.BlacklistedModels, request.Unfiltered) + response := aggregatedResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequests @@ -368,18 +364,6 @@ func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -389,7 +373,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key var extraParams map[string]interface{} var err error - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Use centralized Anthropic converter reqBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request) if convErr != nil { @@ -399,7 +383,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Add provider-aware beta headers for Vertex anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) // Marshal to JSON bytes, preserving struct field order @@ -426,7 +409,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key if err != nil { return nil, fmt.Errorf("failed to delete model field: %w", err) } - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody, err := gemini.ToGeminiChatCompletionRequest(request) if err != nil { return nil, err @@ -435,7 +418,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, 
fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes @@ -450,7 +432,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Marshal to JSON bytes rawBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { @@ -465,26 +446,26 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Remap unsupported tool versions for Vertex (handles raw passthrough bodies) - if schemas.IsAnthropicModel(deployment) && jsonBody != nil { + if schemas.IsAnthropicModel(request.Model) && jsonBody != nil { remappedBody, remapErr := anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex) if remapErr != nil { - return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil) } jsonBody = remappedBody } @@ -493,43 +474,43 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key authQuery := "" // Determine the URL based on model type 
var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { // Custom Fine-tuned models use OpenAPI endpoint projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsAnthropicModel(deployment) { + } else if schemas.IsAnthropicModel(request.Model) { // Claude models use Anthropic publisher if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + completeURL = 
fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } - } else if schemas.IsMistralModel(deployment) { + } else if schemas.IsMistralModel(request.Model) { // Mistral models use mistralai publisher with rawPredict if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { // Gemini models support api key if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, 
request.Model) } } else { if region == "global" { @@ -564,11 +545,11 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -597,14 +578,10 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -613,16 +590,13 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, 
ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Create response object from pool anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) @@ -636,17 +610,9 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key response := anthropicResponse.ToBifrostChatResponse(ctx) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: providerName, - ModelRequested: request.Model, - Latency: latency.Milliseconds(), - } - - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment + Latency: latency.Milliseconds(), + ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), } - response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -659,7 +625,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } return response, nil - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -668,12 
+634,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } response := geminiResponse.ToBifrostChatResponse() - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -695,12 +655,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -723,35 +677,17 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Returns a channel of BifrostStreamChunk objects for streaming results or an error if the request fails. 
func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after + return nil, providerUtils.NewConfigurationError("region is not set in key config") } - postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Use Anthropic-style streaming for Claude models jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -766,8 +702,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment - reqBody.Stream = schemas.Ptr(true) + reqBody.Stream = schemas.Ptr(true) // Add provider-aware beta headers for Vertex anthropic.AddMissingBetaHeadersToContext(ctx, reqBody,
schemas.Vertex) @@ -803,7 +738,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -813,15 +748,15 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext var remapErr error jsonData, remapErr = anthropic.RemapRawToolVersionsForProvider(jsonData, schemas.Vertex) if remapErr != nil { - return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil) } } var completeURL string if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) } // Prepare headers for Vertex Anthropic @@ -834,11 +769,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Adding authorization header tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error 
getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken @@ -855,15 +790,10 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), providerName, postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }, ) - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { // Use Gemini-style streaming for Gemini models jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -876,12 +806,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext if reqBody == nil { return nil, fmt.Errorf("chat completion input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -894,12 +823,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini streaming - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, 
projectNumber, ":streamGenerateContent") + completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent") // Add alt=sse parameter if authQuery != "" { @@ -918,11 +847,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext if authQuery == "" { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken } @@ -940,7 +869,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext provider.GetProviderKey(), request.Model, postHookRunner, - postResponseConverter, + nil, provider.logger, ) } else { @@ -949,12 +878,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsMistralModel(deployment) { + if schemas.IsMistralModel(request.Model) { // Mistral models use mistralai publisher with streamRawPredict if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, deployment) + 
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, request.Model) } } else { // Other models use OpenAPI endpoint for gemini models @@ -974,22 +903,17 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } authHeader = map[string]string{ "Authorization": "Bearer " + token.AccessToken, } } - postRequestConverter := func(reqBody *openai.OpenAIChatRequest) *openai.OpenAIChatRequest { - reqBody.Model = deployment - return reqBody - } - // Use shared OpenAI streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, @@ -1005,8 +929,8 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext nil, nil, nil, - postRequestConverter, - postResponseConverter, + nil, + nil, provider.logger, ) } @@ -1014,40 +938,28 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Responses performs a responses request to the Vertex API. 
func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - - if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + if schemas.IsAnthropicModel(request.Model) { + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Claude models use Anthropic publisher var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) } else { - url = 
fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } // Create HTTP request for streaming @@ -1068,11 +980,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1100,14 +1012,10 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, 
provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1115,9 +1023,6 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostResponsesResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1137,13 +1042,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem response := anthropicResponse.ToBifrostResponsesResponse(ctx) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: providerName, - ModelRequested: request.Model, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), } - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1154,12 +1055,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { response.ExtraFields.RawResponse = rawResponse } - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } return response, nil - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -1171,24 +1069,23 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if reqBody == nil { return nil, 
fmt.Errorf("responses input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } authQuery := "" @@ -1198,11 +1095,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - url := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":generateContent") + url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":generateContent") // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1227,11 +1124,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, 
providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -1260,14 +1157,10 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1275,9 +1168,6 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostResponsesResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1292,16 +1182,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key 
schem } response := geminiResponse.ToResponsesBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { response.ExtraFields.RawResponse = rawResponse @@ -1319,52 +1202,33 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response, nil } } // ResponsesStream performs a streaming responses request to the Vertex API. 
func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } - jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model) } else { - url 
= fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, deployment) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) } // Prepare headers for Vertex Anthropic @@ -1377,22 +1241,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Adding authorization header tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - // Use shared streaming logic from Anthropic return anthropic.HandleAnthropicResponsesStream( ctx, @@ -1406,23 +1262,18 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if 
schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } // Use Gemini-style streaming for Gemini models @@ -1437,12 +1288,11 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos if reqBody == nil { return nil, fmt.Errorf("responses input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -1455,12 +1305,12 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini streaming - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":streamGenerateContent") + completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent") // Add alt=sse parameter if authQuery != "" { completeURL = fmt.Sprintf("%s?alt=sse&%s", 
completeURL, authQuery) @@ -1478,23 +1328,15 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos if authQuery == "" { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken } - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - // Use shared streaming logic from Gemini return gemini.HandleGeminiResponsesStream( ctx, @@ -1508,7 +1350,7 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos provider.GetProviderKey(), request.Model, postHookRunner, - postResponseConverter, + nil, provider.logger, ) } else { @@ -1526,18 +1368,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // All Vertex AI embedding models use the same response format regardless of the model type. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -1546,24 +1384,19 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem func() (providerUtils.RequestBodyWithExtraParams, error) { return ToVertexEmbeddingRequest(request), nil }, - providerName) + ) if bifrostErr != nil { return nil, bifrostErr } - deployment := provider.getModelDeployment(key, request.Model) - - // Remove google/ prefix from deployment - deployment = strings.TrimPrefix(deployment, "google/") - // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Build the native Vertex embedding API endpoint - url := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, 
":predict") + url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":predict") // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1586,11 +1419,11 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1626,7 +1459,7 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Try to parse Vertex's error format var vertexError map[string]interface{} if err := sonic.Unmarshal(errBody, &vertexError); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errorObj, exists := vertexError["error"]; exists { @@ -1640,10 +1473,10 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem } } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), schemas.Vertex, nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1651,9 +1484,6 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostEmbeddingResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1663,28 +1493,21 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Parse Vertex's native embedding response using typed response var vertexResponse VertexEmbeddingResponse if err := sonic.Unmarshal(responseBody, &vertexResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Use centralized Vertex converter bifrostResponse := vertexResponse.ToBifrostEmbeddingResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = 
schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) - if bifrostResponse.ExtraFields.ModelRequested != deployment { - bifrostResponse.ExtraFields.ModelDeployment = deployment - } - // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { var rawResponseMap map[string]interface{} if err := sonic.Unmarshal(resp.Body(), &rawResponseMap); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.ExtraFields.RawResponse = rawResponseMap } @@ -1699,30 +1522,23 @@ func (provider *VertexProvider) Speech(ctx *schemas.BifrostContext, key schemas. // Rerank performs a rerank request using Vertex Discovery Engine ranking API. 
func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } options, err := getVertexRerankOptions(projectID, request.Params) if err != nil { - return nil, providerUtils.NewConfigurationError(err.Error(), providerName) + return nil, providerUtils.NewConfigurationError(err.Error()) } - modelDeployment := provider.getModelDeployment(key, request.Model) jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToVertexRankRequest(request, modelDeployment, options) + return ToVertexRankRequest(request, options) }, - providerName) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -1748,11 +1564,11 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1780,11 +1596,7 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. } errorMessage := parseDiscoveryEngineErrorMessage(resp.Body()) - parsedError := parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.RerankRequest, - }) + parsedError := parseVertexError(resp) if strings.TrimSpace(errorMessage) != "" { shouldOverride := parsedError == nil || @@ -1794,19 +1606,14 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
parsedError.Error.Message == schemas.ErrProviderResponseUnmarshal if shouldOverride { - parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), providerName, nil, nil) - parsedError.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, - } + parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil) } } return nil, providerUtils.EnrichError(ctx, parsedError, jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1815,9 +1622,6 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. return &schemas.BifrostRerankResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1833,16 +1637,9 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments bifrostResponse, err := vertexResponse.ToBifrostRerankResponse(request.Documents, returnDocuments) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err, providerName), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - if request.Model != modelDeployment { - bifrostResponse.ExtraFields.ModelDeployment = modelDeployment - } - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -1878,21 +1675,9 @@ func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext, } func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Validate model type before processing - if !schemas.IsGeminiModel(deployment) && !schemas.IsAllDigitsASCII(deployment) && 
!schemas.IsImagenModel(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", deployment), providerName) + if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", request.Model)) } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -1903,13 +1688,12 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key var extraParams map[string]interface{} var err error - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody := gemini.ToGeminiImageGenerationRequest(request) if reqBody == nil { return nil, fmt.Errorf("image generation input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes, preserving key order @@ -1917,7 +1701,7 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { reqBody := gemini.ToImagenImageGenerationRequest(request) if reqBody == nil { return nil, fmt.Errorf("image generation input is not provided") @@ -1937,58 +1721,58 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID 
== "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Auth query is used for fine-tuned models to pass the API key in the query string authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { // Custom Fine-tuned models use OpenAPI endpoint projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { // Imagen models are published models, use 
publishers/google/models path if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model) } } @@ -2015,11 +1799,11 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2048,14 +1832,10 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -2063,16 +1843,13 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: 
providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -2085,12 +1862,6 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ImageGenerationRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2113,12 +1884,6 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.RequestType = schemas.ImageGenerationRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2142,20 +1907,9 @@ func (provider *VertexProvider) ImageGenerationStream(ctx *schemas.BifrostContex // ImageEdit edits images for the given input text(s) using Vertex AI. 
// Returns a BifrostResponse containing the images and any error that occurred. func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Validate model type before processing - if !schemas.IsGeminiModel(deployment) && !schemas.IsAllDigitsASCII(deployment) && !schemas.IsImagenModel(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", deployment), providerName) + if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", request.Model)) } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -2166,13 +1920,12 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem var extraParams map[string]interface{} var err error - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody := gemini.ToGeminiImageEditRequest(request) if reqBody == nil { return nil, fmt.Errorf("image edit input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes, preserving key order @@ -2180,7 +1933,7 @@ func (provider 
*VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { reqBody := gemini.ToImagenImageEditRequest(request) if reqBody == nil { return nil, fmt.Errorf("image edit input is not provided") @@ -2200,19 +1953,19 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } authQuery := "" @@ -2221,27 +1974,27 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", 
projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, 
request.Model) } } @@ -2267,11 +2020,11 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2299,14 +2052,10 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -2314,16 +2063,13 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: 
schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -2336,12 +2082,6 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ImageEditRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2364,12 +2104,6 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.RequestType = schemas.ImageEditRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2400,18 +2134,9 @@ func (provider 
*VertexProvider) ImageVariation(ctx *schemas.BifrostContext, key func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, bifrostReq.Model) - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Only Gemini models support video generation in Vertex - if !schemas.IsVeoModel(deployment) && !schemas.IsAllDigitsASCII(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", deployment), providerName) + if !schemas.IsVeoModel(bifrostReq.Model) && !schemas.IsAllDigitsASCII(bifrostReq.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", bifrostReq.Model)) } // Convert Bifrost request to Gemini format (reusing Gemini converters) @@ -2421,7 +2146,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { return gemini.ToGeminiVideoGenerationRequest(bifrostReq) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2429,12 +2153,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set 
in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Auth query is used to pass the API key in the query string @@ -2445,12 +2169,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(bifrostReq.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini video generation using predictLongRunning - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":predictLongRunning") + completeURL := getCompleteURLForGeminiEndpoint(bifrostReq.Model, region, projectID, projectNumber, ":predictLongRunning") // Create HTTP request req := fasthttp.AcquireRequest() @@ -2469,11 +2193,11 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2493,17 +2217,13 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == 
fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: bifrostReq.Model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var operation gemini.GenerateVideosOperation @@ -2519,12 +2239,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.ModelRequested = bifrostReq.Model - if bifrostReq.Model != deployment { - bifrostResp.ExtraFields.ModelDeployment = deployment - } - bifrostResp.ExtraFields.RequestType = schemas.VideoGenerationRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2540,18 +2254,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key // VideoRetrieve retrieves the status of a video generation operation. // Uses the fetchPredictOperation endpoint for Vertex AI. 
func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Construct base URL based on region @@ -2567,12 +2275,12 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // projects/PROJECT_ID/locations/REGION/publishers/google/models/MODEL_ID/operations/OPERATION_ID // We need to extract the model path from it to construct the fetchPredictOperation endpoint // Extract: projects/.../models/MODEL_ID from the operation name - taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, providerName) + taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, provider.GetProviderKey()) var modelPath string if idx := strings.Index(taskID, "/operations/"); idx != -1 { modelPath = taskID[:idx] } else { - return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil) } // Construct the URL: https://REGION-aiplatform.googleapis.com/v1/{modelPath}:fetchPredictOperation @@ -2587,7 +2295,7 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Create request body with operation name (using sjson to avoid map 
marshaling) jsonBody, err := providerUtils.SetJSONField([]byte(`{}`), "operationName", taskID) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err) } // Create HTTP request @@ -2607,11 +2315,11 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2631,10 +2339,7 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoRetrieveRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -2648,10 +2353,8 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if bifrostErr != nil { return nil, bifrostErr } - bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) + bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey()) 
bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) if sendBackRawResponse { @@ -2665,9 +2368,8 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // First retrieves the video status to get the URL, then downloads the content. // Handles both regular URLs and data URLs (base64-encoded videos). func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() if request == nil || request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } // Retrieve operation first to get the video URL bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ @@ -2681,12 +2383,10 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s if videoResp.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", videoResp.Status), - nil, - providerName, - ) + nil) } if len(videoResp.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var content []byte var latency time.Duration @@ -2698,7 +2398,7 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s startTime := time.Now() decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("failed to decode base64 video data", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err) } content = decoded contentType = videoResp.Videos[0].ContentType @@ -2728,11 +2428,11 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2747,19 +2447,17 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } contentType = string(resp.Header.ContentType()) content = append([]byte(nil), body...) 
providerResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) } else { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } bifrostResp := &schemas.BifrostVideoDownloadResponse{ @@ -2769,8 +2467,6 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders return bifrostResp, nil @@ -2808,19 +2504,6 @@ func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequ } } -func (provider *VertexProvider) getModelDeployment(key schemas.Key, model string) string { - if key.VertexKeyConfig == nil { - return model - } - - if key.VertexKeyConfig.Deployments != nil { - if deployment, ok := key.VertexKeyConfig.Deployments[model]; ok { - return deployment - } - } - return model -} - // BatchCreate is not supported by Vertex AI provider. func (provider *VertexProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) @@ -2880,25 +2563,13 @@ func (provider *VertexProvider) FileContent(_ *schemas.BifrostContext, _ []schem // CountTokens counts the number of tokens in the provided content using Vertex AI's countTokens endpoint. // Supports Gemini models with both text and image content. 
func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - var ( jsonBody []byte bifrostErr *schemas.BifrostError ) - if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + if schemas.IsAnthropicModel(request.Model) { + jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } @@ -2909,7 +2580,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch func() (providerUtils.RequestBodyWithExtraParams, error) { return gemini.ToGeminiResponsesRequest(request) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2927,38 +2597,38 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not 
set in key config") } authQuery := "" var completeURL string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { if region == "global" { completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/count-tokens:rawPredict", projectID) } else { completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", region, projectID, region) } - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - completeURL = getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":countTokens") + completeURL = getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":countTokens") } if completeURL == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model/deployment: %s", deployment), providerName) + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model: %s", request.Model)) } req := fasthttp.AcquireRequest() @@ -2980,11 +2650,11 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3012,14 +2682,10 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -3027,16 +2693,13 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch respOwned = false return &schemas.BifrostCountTokensResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.CountTokensRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: 
providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := &anthropic.AnthropicCountTokensResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -3045,12 +2708,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } response := anthropicResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -3073,12 +2730,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } response := vertexResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -3143,14 +2794,9 @@ func (provider *VertexProvider) Passthrough( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewBifrostOperationError("vertex key config is not set", nil, schemas.Vertex) - } - 
projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("project ID is not set") } keyRegion := key.VertexKeyConfig.Region.GetValue() @@ -3216,12 +2862,12 @@ func (provider *VertexProvider) Passthrough( tokenSource, err := getAuthTokenSource(key) if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3271,7 +2917,7 @@ func (provider *VertexProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { @@ -3285,9 +2931,6 @@ func (provider *VertexProvider) Passthrough( } bifrostResponse.ExtraFields.ProviderResponseHeaders = headers - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest - bifrostResponse.ExtraFields.ModelRequested = req.Model bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, 
provider.sendBackRawRequest) { @@ -3303,13 +2946,9 @@ func (provider *VertexProvider) PassthroughStream( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewBifrostOperationError("vertex key config is not set", nil, schemas.Vertex) - } - projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("project ID is not set") } keyRegion := key.VertexKeyConfig.Region.GetValue() @@ -3375,13 +3014,13 @@ func (provider *VertexProvider) PassthroughStream( if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3421,9 +3060,9 @@ func (provider *VertexProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, 
provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { @@ -3438,9 +3077,7 @@ func (provider *VertexProvider) PassthroughStream( providerUtils.ReleaseStreamingResponse(resp) return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + fmt.Errorf("provider returned an empty stream body")) } // Set stream idle timeout from provider config. @@ -3453,11 +3090,7 @@ func (provider *VertexProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -3469,9 +3102,9 @@ func (provider *VertexProvider) PassthroughStream( defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -3520,10 +3153,10 @@ func (provider *VertexProvider) PassthroughStream( } 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } }() return ch, nil -} +} \ No newline at end of file diff --git a/core/providers/vertex/vertex_test.go b/core/providers/vertex/vertex_test.go index 03baf347fa..d754f33d22 100644 --- a/core/providers/vertex/vertex_test.go +++ b/core/providers/vertex/vertex_test.go @@ -27,7 +27,7 @@ func TestVertex(t *testing.T) { testConfig := llmtests.ComprehensiveTestConfig{ Provider: schemas.Vertex, - ChatModel: "google/gemini-2.0-flash-001", + ChatModel: "gemini-2.5-pro", PromptCachingModel: "claude-sonnet-4-5", VisionModel: "claude-sonnet-4-5", TextModel: "", // Vertex doesn't support text completion in newer models @@ -38,12 +38,12 @@ func TestVertex(t *testing.T) { ImageEditModel: "imagen-3.0-capability-001", VideoGenerationModel: "veo-3.1-generate-preview", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, End2EndToolCalling: true, diff --git a/core/providers/vllm/utils.go b/core/providers/vllm/utils.go index d2cefce786..ab6d694938 100644 --- a/core/providers/vllm/utils.go +++ b/core/providers/vllm/utils.go @@ -13,9 +13,6 @@ func HandleVLLMResponse[T any](responseBody []byte, response *T, requestBody []b return rawRequest, rawResponse, bifrostErr } if err := sonic.Unmarshal(responseBody, &errorResp); err == nil && errorResp.Error != 
nil && errorResp.Error.Message != "" { - errorResp.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: schemas.VLLM, - } return rawRequest, rawResponse, &errorResp } return rawRequest, rawResponse, nil diff --git a/core/providers/vllm/vllm.go b/core/providers/vllm/vllm.go index a6b7b1c48c..548c6a0dc7 100644 --- a/core/providers/vllm/vllm.go +++ b/core/providers/vllm/vllm.go @@ -76,9 +76,7 @@ func (provider *VLLMProvider) baseURLOrError(key schemas.Key) (string, *schemas. if u == "" { return "", providerUtils.NewBifrostOperationError( "no base URL configured: set vllm_key_config.url on the key", - nil, - provider.GetProviderKey(), - ) + nil) } return u, nil } @@ -246,9 +244,6 @@ func (provider *VLLMProvider) Responses(ctx *schemas.BifrostContext, key schemas return nil, err } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -314,12 +309,14 @@ func (provider *VLLMProvider) callVLLMRerankEndpoint( statusCode := resp.StatusCode() if statusCode != fasthttp.StatusOK { - return nil, nil, nil, nil, statusCode, latency, openai.ParseOpenAIError(resp, schemas.RerankRequest, provider.GetProviderKey(), request.Model) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, nil, nil, rawErrBody, statusCode, latency, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, nil, nil, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + rawErrBody := append([]byte(nil), resp.Body()...) 
+ return nil, nil, nil, rawErrBody, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -336,16 +333,12 @@ func (provider *VLLMProvider) callVLLMRerankEndpoint( // Rerank performs a rerank request to vLLM's API. func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToVLLMRerankRequest(request), nil - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -358,6 +351,9 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke resolvedPath = "/" + resolvedPath } + sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) + sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) + responsePayload, rawRequest, rawResponse, responseBody, statusCode, latency, bifrostErr := provider.callVLLMRerankEndpoint(ctx, key, request, resolvedPath, jsonData) if bifrostErr != nil && !hasPathOverride && isRerankFallbackStatus(statusCode) { var fallbackLatency time.Duration @@ -365,7 +361,7 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke latency += fallbackLatency } if bifrostErr != nil { - return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, sendBackRawRequest, sendBackRawResponse) } returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments 
@@ -373,19 +369,16 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke if err != nil { return nil, providerUtils.EnrichError( ctx, - providerUtils.NewBifrostOperationError("error converting rerank response", err, providerName), + providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonData, responseBody, - provider.sendBackRawRequest, - provider.sendBackRawResponse, + sendBackRawRequest, + sendBackRawResponse, ) } // Keep requested model as the canonical model in Bifrost response. bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -440,7 +433,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Use centralized converter reqBody := openai.ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) @@ -496,9 +489,9 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status 
check so error responses also forward them @@ -507,7 +500,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, openai.ParseOpenAIError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -527,9 +520,9 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -569,7 +562,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -586,11 +579,6 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p _, _, bifrostErr = HandleVLLMResponse(dataBytes, &response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - 
ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), dataBytes, false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)), responseChan, logger) return @@ -609,11 +597,8 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() diff --git a/core/providers/vllm/vllm_test.go b/core/providers/vllm/vllm_test.go index a9d7a1c17d..2f1d5b22c6 100644 --- a/core/providers/vllm/vllm_test.go +++ b/core/providers/vllm/vllm_test.go @@ -37,35 +37,35 @@ func TestVLLM(t *testing.T) { EmbeddingModel: embeddingModel, RerankModel: rerankModel, Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - CompleteEnd2End: true, - Embedding: true, - Rerank: rerankModel != "", - ListModels: true, - Reasoning: true, - SpeechSynthesis: false, - SpeechSynthesisStream: false, - Transcription: true, - TranscriptionStream: false, - ImageGeneration: false, - 
ImageGenerationStream: false, - ImageEdit: false, - ImageEditStream: false, - ImageVariation: false, - ImageVariationStream: false, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + Embedding: true, + Rerank: rerankModel != "", + ListModels: true, + Reasoning: true, + SpeechSynthesis: false, + SpeechSynthesisStream: false, + Transcription: true, + TranscriptionStream: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, + ImageEditStream: false, + ImageVariation: false, + ImageVariationStream: false, }, } diff --git a/core/providers/xai/errors.go b/core/providers/xai/errors.go index 78b22463e0..38a46888a8 100644 --- a/core/providers/xai/errors.go +++ b/core/providers/xai/errors.go @@ -15,7 +15,7 @@ type XAIErrorResponse struct { // ParseXAIError parses xAI-specific error responses. // xAI returns errors in format: {"code": "...", "error": "..."} // Unlike OpenAI which uses: {"error": {"message": "...", "type": "...", "code": "..."}} -func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseXAIError(resp *fasthttp.Response) *schemas.BifrostError { // Try to parse xAI error format var xaiErr XAIErrorResponse bifrostErr := providerUtils.HandleProviderAPIError(resp, &xaiErr) @@ -35,10 +35,5 @@ func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, pro } } - // Set ExtraFields individually to preserve RawResponse from HandleProviderAPIError - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - return bifrostErr } diff --git a/core/providers/xai/xai.go b/core/providers/xai/xai.go index 6fa52d39f3..118b8589bf 100644 --- a/core/providers/xai/xai.go +++ b/core/providers/xai/xai.go @@ -65,7 +65,7 @@ func (provider 
*XAIProvider) GetProviderKey() schemas.ModelProvider { // ListModels performs a list models request to xAI's API. func (provider *XAIProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("base_url is not set") } return openai.HandleOpenAIListModelsRequest( ctx, diff --git a/core/providers/xai/xai_test.go b/core/providers/xai/xai_test.go index 81c479bb7f..d7e0447d28 100644 --- a/core/providers/xai/xai_test.go +++ b/core/providers/xai/xai_test.go @@ -30,29 +30,29 @@ func TestXAI(t *testing.T) { TextModel: "grok-3", VisionModel: "grok-4-1-fast-reasoning", EmbeddingModel: "", // XAI doesn't support embedding - ImageGenerationModel: "grok-2-image", + ImageGenerationModel: "grok-imagine-image", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - ImageGeneration: true, - ImageGenerationStream: false, - FileBase64: false, - FileURL: false, - MultipleImages: true, - CompleteEnd2End: true, - Reasoning: true, - Embedding: false, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + ImageGeneration: true, + ImageGenerationStream: false, + FileBase64: false, + FileURL: false, + MultipleImages: true, + CompleteEnd2End: true, + Reasoning: true, + Embedding: false, + ListModels: 
true, }, } diff --git a/core/schemas/account.go b/core/schemas/account.go index ceaeb2de8a..3dfb16a0b4 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -1,7 +1,12 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -import "context" +import ( + "context" + "fmt" + "slices" + "strings" +) type KeyStatusType string @@ -10,26 +15,172 @@ const ( KeyStatusListModelsFailed KeyStatusType = "list_models_failed" ) +// WhiteList is a list of values that are allowed to be used. +// Semantics: +// - "*" (alone) means all values are allowed. +// - Empty list means nothing is allowed. +// - Non-empty list (without "*") means only the listed values are allowed. +// +// This type is used generically for any field that needs whitelist behavior +// (e.g., allowed models, allowed tools). +type WhiteList []string + +// Contains reports whether value is in the whitelist. +// The comparison is case-insensitive. +func (wl WhiteList) Contains(value string) bool { + return slices.ContainsFunc(wl, func(s string) bool { + return strings.EqualFold(s, value) + }) } + +// IsAllowed reports whether value is permitted by the whitelist. +// Returns true if the whitelist is unrestricted ("*") or value is in the list. +func (wl WhiteList) IsAllowed(value string) bool { + return wl.IsUnrestricted() || wl.Contains(value) +} + +// IsEmpty reports whether the whitelist has no entries. +func (wl WhiteList) IsEmpty() bool { + return len(wl) == 0 +} + +// IsUnrestricted reports whether the whitelist contains only "*", +// meaning all values are allowed. +func (wl WhiteList) IsUnrestricted() bool { + return len(wl) == 1 && wl[0] == "*" +} + +// IsRestricted reports whether the whitelist is not unrestricted, +// meaning only the listed values (possibly none) are allowed. +func (wl WhiteList) IsRestricted() bool { + return !wl.IsUnrestricted() +} + +// Validate checks that the whitelist is well-formed. 
+// Returns an error if "*" is present alongside other values, or if there are duplicate entries. +func (wl WhiteList) Validate() error { + if wl.Contains("*") && len(wl) > 1 { + return fmt.Errorf("wildcard '*' cannot be used with other values in the whitelist") + } + seen := make(map[string]struct{}, len(wl)) + for _, v := range wl { + normalized := strings.ToLower(v) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate value '%s' in whitelist", v) + } + seen[normalized] = struct{}{} + } + return nil +} + +// BlackList is a list of values that are denied. +// Semantics: +// - "*" (alone) means all values are blocked. +// - Empty list means nothing is blocked. +// - Non-empty list (without "*") means only the listed values are blocked. +type BlackList []string + +func (bl BlackList) Contains(value string) bool { + return slices.ContainsFunc(bl, func(s string) bool { + return strings.EqualFold(s, value) + }) +} + +// IsBlocked reports whether value is blocked. +func (bl BlackList) IsBlocked(value string) bool { + return bl.IsBlockAll() || bl.Contains(value) +} + +// IsEmpty reports whether the blacklist has no entries (nothing is blocked). +func (bl BlackList) IsEmpty() bool { + return len(bl) == 0 +} + +// IsBlockAll reports whether the blacklist contains only "*", meaning all values are blocked. +func (bl BlackList) IsBlockAll() bool { + return len(bl) == 1 && bl[0] == "*" +} + +// Validate checks that the blacklist is well-formed. +func (bl BlackList) Validate() error { + if bl.Contains("*") && len(bl) > 1 { + return fmt.Errorf("wildcard '*' cannot be used with other values in the blacklist") + } + seen := make(map[string]struct{}, len(bl)) + for _, v := range bl { + normalized := strings.ToLower(v) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate value '%s' in blacklist", v) + } + seen[normalized] = struct{}{} + } + return nil +} + // Key represents an API key and its associated configuration for a provider. 
// It contains the key value, supported models, and a weight for load balancing. type Key struct { - ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key) - Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost) - Value EnvVar `json:"value"` // The actual API key value - Models []string `json:"models"` // List of models this key can access - BlacklistedModels []string `json:"blacklisted_models"` // List of models this key cannot access - Weight float64 `json:"weight"` // Weight for load balancing between multiple keys - AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration - VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration - BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration - HuggingFaceKeyConfig *HuggingFaceKeyConfig `json:"huggingface_key_config,omitempty"` // Hugging Face-specific key configuration - ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration - VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration - Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true) - UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true) - ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection - Status KeyStatusType `json:"status,omitempty"` // Status of key - Description string `json:"description,omitempty"` // Description of key + ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key) + Name string `json:"name"` // The name of the key (used by users to identify the key, 
not used by bifrost) + Value EnvVar `json:"value"` // The actual API key value + Models WhiteList `json:"models"` // List of models this key can access + BlacklistedModels BlackList `json:"blacklisted_models"` // List of models this key cannot access + Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + Aliases KeyAliases `json:"aliases,omitempty"` // Mapping of model identifiers to provider-specific aliases (e.g., deployments, inference profiles) + AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration + VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration + BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration + VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration + ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration + OllamaKeyConfig *OllamaKeyConfig `json:"ollama_key_config,omitempty"` // Ollama-specific key configuration + SGLKeyConfig *SGLKeyConfig `json:"sgl_key_config,omitempty"` // SGLang-specific key configuration + Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true) + UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true) + ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection + Status KeyStatusType `json:"status,omitempty"` // Status of key + Description string `json:"description,omitempty"` // Description of key +} + +type KeyAliases map[string]string + +func (ka KeyAliases) Validate() error { + seen := make(map[string]struct{}, len(ka)) + for from, to := range ka { + if strings.TrimSpace(from) == "" { + return fmt.Errorf("alias source cannot be empty") + } + if strings.TrimSpace(to) == "" { + 
return fmt.Errorf("alias target for %q cannot be empty", from) + } + if strings.TrimSpace(from) != from { + return fmt.Errorf("alias source %q cannot have leading or trailing whitespace", from) + } + if strings.TrimSpace(to) != to { + return fmt.Errorf("alias target for %q cannot have leading or trailing whitespace", from) + } + normalized := strings.ToLower(from) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate alias source %q (case-insensitive)", from) + } + seen[normalized] = struct{}{} + } + return nil +} + +func (ka KeyAliases) Resolve(model string) string { + if ka == nil { + return model + } + if alias, ok := ka[model]; ok { + return alias + } + // Fall back to case-insensitive lookup for consistency with WhiteList.Contains + for k, v := range ka { + if strings.EqualFold(k, model) { + return v + } + } + return model } type AzureAuthType string @@ -42,9 +193,8 @@ const ( // AzureKeyConfig represents the Azure-specific configuration. // It contains Azure-specific settings required for service access and deployment management. type AzureKeyConfig struct { - Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names - APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" + Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL + APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" ClientID *EnvVar `json:"client_id,omitempty"` // Azure client ID for authentication ClientSecret *EnvVar `json:"client_secret,omitempty"` // Azure client secret for authentication @@ -55,11 +205,10 @@ type AzureKeyConfig struct { // VertexKeyConfig represents the Vertex-specific configuration. // It contains Vertex-specific settings required for authentication and service access. 
type VertexKeyConfig struct { - ProjectID EnvVar `json:"project_id"` - ProjectNumber EnvVar `json:"project_number"` - Region EnvVar `json:"region"` - AuthCredentials EnvVar `json:"auth_credentials"` - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles + ProjectID EnvVar `json:"project_id"` + ProjectNumber EnvVar `json:"project_number"` + Region EnvVar `json:"region"` + AuthCredentials EnvVar `json:"auth_credentials"` } // NOTE: To use Vertex IAM role authentication, set AuthCredentials to empty string. @@ -90,21 +239,12 @@ type BedrockKeyConfig struct { ExternalID *EnvVar `json:"external_id,omitempty"` RoleSessionName *EnvVar `json:"session_name,omitempty"` - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles - BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations + BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations } // NOTE: To use Bedrock IAM role authentication, set both AccessKey and SecretKey to empty strings. // To use Bedrock API Key authentication, set Value in Key struct instead. -type HuggingFaceKeyConfig struct { - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to deployment names -} - -type ReplicateKeyConfig struct { - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to deployment names -} - // VLLMKeyConfig represents the vLLM-specific key configuration. // It allows each key to target a different vLLM server URL and model name, // enabling per-key routing and round-robin load balancing across multiple vLLM instances. 
@@ -113,6 +253,26 @@ type VLLMKeyConfig struct { ModelName string `json:"model_name"` // Exact model name served on this VLLM instance (used for key selection) } +// ReplicateKeyConfig represents the Replicate-specific key configuration. +// It contains Replicate-specific settings required for authentication and service access. +type ReplicateKeyConfig struct { + UseDeploymentsEndpoint bool `json:"use_deployments_endpoint"` // Whether to use the deployments endpoint instead of the models endpoint +} + +// OllamaKeyConfig represents the Ollama-specific key configuration. +// It allows each key to target a different Ollama server URL, +// enabling per-key routing and round-robin load balancing across multiple Ollama instances. +type OllamaKeyConfig struct { + URL EnvVar `json:"url"` // Ollama server base URL (required, supports env. prefix) +} + +// SGLKeyConfig represents the SGLang-specific key configuration. +// It allows each key to target a different SGLang server URL, +// enabling per-key routing and round-robin load balancing across multiple SGLang instances. +type SGLKeyConfig struct { + URL EnvVar `json:"url"` // SGLang server base URL (required, supports env. prefix) +} + // Account defines the interface for managing provider accounts and their configurations. // It provides methods to access provider-specific settings, API keys, and configurations. type Account interface { diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index c00c5c7078..04cca3915a 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -161,35 +161,47 @@ type BifrostContextKey string // BifrostContextKeyRequestType is a context key for the request type. 
const ( - BifrostContextKeySessionToken BifrostContextKey = "bifrost-session-token" // string (session token for authentication - set by auth middleware) - BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string - BifrostContextKeyAPIKeyName BifrostContextKey = "x-bf-api-key" // string (explicit key name selection) - BifrostContextKeyAPIKeyID BifrostContextKey = "x-bf-api-key-id" // string (explicit key ID selection, takes priority over name) - BifrostContextKeyRequestID BifrostContextKey = "request-id" // string - BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string - BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct - BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceVirtualKeyName BifrostContextKey = "bifrost-governance-virtual-key-name" // string (to store the virtual key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceTeamID BifrostContextKey = "bifrost-governance-team-id" // string (to store the team ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceTeamName BifrostContextKey = "bifrost-governance-team-name" // string (to store the team name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceCustomerID BifrostContextKey = "bifrost-governance-customer-id" // string (to store the customer ID (set by bifrost 
governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceCustomerName BifrostContextKey = "bifrost-governance-customer-name" // string (to store the customer name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceUserID BifrostContextKey = "bifrost-governance-user-id" // string (to store the user ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. 
- BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyStreamIdleTimeout BifrostContextKey = "bifrost-stream-idle-timeout" // time.Duration (per-chunk idle timeout for streaming) - BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) - BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string - BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string + BifrostContextKeySessionToken BifrostContextKey = "bifrost-session-token" // string (session token for authentication - set by auth middleware) + BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string + BifrostContextKeyAPIKeyName BifrostContextKey = "x-bf-api-key" // string (explicit key name selection) + BifrostContextKeyAPIKeyID BifrostContextKey = "x-bf-api-key-id" // string (explicit key ID selection, takes priority over name) + BifrostContextKeyRequestID BifrostContextKey = "request-id" // string + BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string + BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct + + // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). + // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. + // Request context filtering takes priority over client config - context can override client exclusions. 
+ MCPContextKeyIncludeClients BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering + MCPContextKeyIncludeTools BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) + + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceVirtualKeyName BifrostContextKey = "bifrost-governance-virtual-key-name" // string (to store the virtual key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceTeamID BifrostContextKey = "bifrost-governance-team-id" // string (to store the team ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceTeamName BifrostContextKey = "bifrost-governance-team-name" // string (to store the team name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceCustomerID BifrostContextKey = "bifrost-governance-customer-id" // string (to store the customer ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceCustomerName BifrostContextKey = "bifrost-governance-customer-name" // string (to store the customer name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceBusinessUnitID BifrostContextKey = "bifrost-governance-business-unit-id" // string (to store 
the business unit ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceBusinessUnitName BifrostContextKey = "bifrost-governance-business-unit-name" // string (to store the business unit name (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptName BifrostContextKey = "bifrost-selected-prompt-name" // string (display name of the selected prompt (set by prompts plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptVersion BifrostContextKey = "bifrost-selected-prompt-version" // string (numeric version as string, e.g. "3" (set by prompts plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptID BifrostContextKey = "bifrost-selected-prompt-id" // string (id of the selected prompt (set by prompts plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. 
+ BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyStreamIdleTimeout BifrostContextKey = "bifrost-stream-idle-timeout" // time.Duration (per-chunk idle timeout for streaming) + BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) + BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string + BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" - BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool - BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool + BifrostContextKeyChangeRequestType BifrostContextKey = "bifrost-change-request-type" // RequestType (set by plugins to trigger request type conversion in core, e.g. text->chat or chat->responses) + BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool (per-request override — read by bifrost.go, never overwritten) + BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool (per-request override — read by bifrost.go, never overwritten) BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) 
BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostMCPAgentOriginalRequestID BifrostContextKey = "bifrost-mcp-agent-original-request-id" // string (to store the original request ID for MCP agent mode) @@ -202,32 +214,51 @@ const ( BifrostContextKeyStreamStartTime BifrostContextKey = "bifrost-stream-start-time" // time.Time (start time for streaming TTFT calculation - set by bifrost) BifrostContextKeyTracer BifrostContextKey = "bifrost-tracer" // Tracer (tracer instance for completing deferred spans - set by bifrost) BifrostContextKeyDeferTraceCompletion BifrostContextKey = "bifrost-defer-trace-completion" // bool (signals trace completion should be deferred for streaming - set by streaming handlers) - BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) + BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func([]PluginLogEntry) (callback to complete trace after streaming, receives transport plugin logs - set by tracing middleware) BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) - BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) + BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token, automatically generated by bifrost) + BifrostContextKeyMCPUserID BifrostContextKey = "bifrost-mcp-user-id" // 
string (per-user OAuth user identifier from X-Bf-User-Id header) + BifrostContextKeyOAuthRedirectURI BifrostContextKey = "bifrost-oauth-redirect-uri" // string (OAuth callback URL, e.g. https://host/api/oauth/callback - set by HTTP middleware) + BifrostContextKeyIsMCPGateway BifrostContextKey = "bifrost-is-mcp-gateway" // bool (true when request is being handled via the MCP gateway path) + BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) BifrostContextKeySkipDBUpdate BifrostContextKey = "bifrost-skip-db-update" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernancePluginName BifrostContextKey = "governance-plugin-name" // string (name of the governance plugin that processed the request - set by bifrost) + BifrostContextKeyPromptsPluginName BifrostContextKey = "prompts-plugin-name" // string (name of the prompts plugin to use - set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyIsEnterprise BifrostContextKey = "is-enterprise" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyAvailableProviders BifrostContextKey = "available-providers" // []ModelProvider (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyRawRequestResponseForLogging BifrostContextKey = "bifrost-raw-request-response-for-logging" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyStoreRawRequestResponse BifrostContextKey = "bifrost-store-raw-request-response" // bool (per-request override — read by bifrost.go, never overwritten) + BifrostContextKeyCaptureRawRequest BifrostContextKey = "bifrost-capture-raw-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when providers should capture raw request bytes + BifrostContextKeyCaptureRawResponse BifrostContextKey = "bifrost-capture-raw-response" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when 
providers should capture raw response bytes + BifrostContextKeyDropRawRequestFromClient BifrostContextKey = "bifrost-drop-raw-request-from-client" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when raw request should be stripped from the client-facing response + BifrostContextKeyDropRawResponseFromClient BifrostContextKey = "bifrost-drop-raw-response-from-client" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when raw response should be stripped from the client-facing response + BifrostContextKeyShouldStoreRawInLogs BifrostContextKey = "bifrost-should-store-raw-in-logs" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when raw request/response should be persisted in log records BifrostContextKeyRetryDBFetch BifrostContextKey = "bifrost-retry-db-fetch" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyIsCustomProvider BifrostContextKey = "bifrost-is-custom-provider" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyHTTPRequestType BifrostContextKey = "bifrost-http-request-type" // RequestType (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyPassthroughExtraParams BifrostContextKey = "bifrost-passthrough-extra-params" // bool BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.) 
BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries + BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks) + BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request + BifrostContextKeyParentRequestID BifrostContextKey = "bifrost-parent-request-id" // string (parent linkage for grouped request logs like realtime turns) + BifrostContextKeyRealtimeSessionID BifrostContextKey = "bifrost-realtime-session-id" // string + BifrostContextKeyRealtimeProviderSessionID BifrostContextKey = "bifrost-realtime-provider-session-id" // string + BifrostContextKeyRealtimeSource BifrostContextKey = "bifrost-realtime-source" // string ("ei" or "lm") + BifrostContextKeyRealtimeEventType BifrostContextKey = "bifrost-realtime-event-type" // string BifrostIsAsyncRequest BifrostContextKey = "bifrost-is-async-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an async request (only used in gateway) BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys) BifrostContextKeySkipListModelsGovernanceFiltering BifrostContextKey = "bifrost-skip-list-models-governance-filtering" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeySCIMClaims BifrostContextKey = "scim_claims" - BifrostContextKeyUserID BifrostContextKey = "user_id" + BifrostContextKeyUserID BifrostContextKey = "bifrost-user-id" // string (to store the 
user ID (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) + BifrostContextKeyUserName BifrostContextKey = "bifrost-user-name" // string (to store the user name (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) BifrostContextKeyTargetUserID BifrostContextKey = "target_user_id" BifrostContextKeyIsAzureUserAgent BifrostContextKey = "bifrost-is-azure-user-agent" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an Azure user agent (only used in gateway) BifrostContextKeyVideoOutputRequested BifrostContextKey = "bifrost-video-output-requested" BifrostContextKeyValidateKeys BifrostContextKey = "bifrost-validate-keys" // bool (triggers additional key validation during provider add/update) BifrostContextKeyProviderResponseHeaders BifrostContextKey = "bifrost-provider-response-headers" // map[string]string (set by provider handlers for response header forwarding) + BifrostContextKeyMCPAddedTools BifrostContextKey = "bifrost-mcp-added-tools" // []string (set by bifrost - DO NOT SET THIS MANUALLY)) - list of tools added to the request by MCP; all the tools are in the format "clientName-toolName" BifrostContextKeyLargePayloadMode BifrostContextKey = "bifrost-large-payload-mode" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) indicates large payload streaming mode is active BifrostContextKeyLargePayloadReader BifrostContextKey = "bifrost-large-payload-reader" // io.Reader (set by bifrost - DO NOT SET THIS MANUALLY)) upstream reader for large payloads BifrostContextKeyLargePayloadContentLength BifrostContextKey = "bifrost-large-payload-content-length" // int (set by bifrost - DO NOT SET THIS MANUALLY)) content length for large payloads @@ -248,7 +279,13 @@ const ( BifrostContextKeySSEReaderFactory BifrostContextKey = "bifrost-sse-reader-factory" // *providerUtils.SSEReaderFactory (set by enterprise — replaces default bufio.Scanner SSE readers with streaming readers) BifrostContextKeySessionID BifrostContextKey = 
"bifrost-session-id" // string session ID for the request (session stickiness) BifrostContextKeySessionTTL BifrostContextKey = "bifrost-session-ttl" // time.Duration session TTL for the request (session stickiness) + BifrostContextKeyMCPExtraHeaders BifrostContextKey = "bifrost-mcp-extra-headers" // map[string][]string (these headers are forwarded to the MCP only during tool execution, if they are in the MCP client's allowlist) BifrostContextKeyMCPLogID BifrostContextKey = "bifrost-mcp-log-id" // string (unique UUID for each MCP tool log entry - set per goroutine by agent executor - DO NOT SET THIS MANUALLY) + BifrostContextKeyCompatConvertTextToChat BifrostContextKey = "bifrost-compat-convert-text-to-chat" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatConvertChatToResponses BifrostContextKey = "bifrost-compat-convert-chat-to-responses" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatShouldDropParams BifrostContextKey = "bifrost-compat-should-drop-params" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatShouldConvertParams BifrostContextKey = "bifrost-compat-should-convert-params" // bool (per-request override from x-bf-compat header) + BifrostContextKeyAttemptTrail BifrostContextKey = "bifrost-attempt-trail" // []KeyAttemptRecord (set by bifrost - DO NOT SET THIS MANUALLY) - per-attempt key selection history ) const ( @@ -264,6 +301,18 @@ const ( RoutingEngineLoadbalancing = "loadbalancing" ) +// KeyAttemptRecord captures the outcome of a single request attempt within executeRequestWithRetries. +// One record is appended per attempt regardless of whether the key changed between attempts. +// FailReason is supplementary retry metadata: it is populated only when another retry will be +// attempted (i.e. a non-terminal attempt), and is nil on any terminal attempt — including success, +// non-retryable failure, or a retryable error when no retries remain. 
+type KeyAttemptRecord struct { + Attempt int `json:"attempt"` + KeyID string `json:"key_id"` + KeyName string `json:"key_name"` + FailReason *string `json:"fail_reason,omitempty"` +} + // RoutingEngineLogEntry represents a log entry from a routing engine // format: [timestamp] [engine] - message type RoutingEngineLogEntry struct { @@ -272,6 +321,27 @@ type RoutingEngineLogEntry struct { Timestamp int64 // Unix milliseconds } +// PluginLogEntry represents a structured log entry emitted by a plugin via ctx.Log(). +type PluginLogEntry struct { + PluginName string `json:"plugin_name"` + Level LogLevel `json:"level"` + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` // Unix milliseconds +} + +// GroupPluginLogsByName groups a flat slice of plugin log entries by plugin name. +// Returns nil if the input is empty. +func GroupPluginLogsByName(logs []PluginLogEntry) map[string][]PluginLogEntry { + if len(logs) == 0 { + return nil + } + grouped := make(map[string][]PluginLogEntry, min(len(logs), 4)) + for _, entry := range logs { + grouped[entry.PluginName] = append(grouped[entry.PluginName], entry) + } + return grouped +} + // NOTE: for custom plugin implementation dealing with streaming short circuit, // make sure to mark BifrostContextKeyStreamEndIndicator as true at the end of the stream. 
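The grouping semantics of the new `GroupPluginLogsByName` helper (per-plugin buckets, input order preserved within each bucket, and a nil map rather than an empty one for empty input) can be sketched in isolation. The types below are minimal local copies taken from this diff for illustration; they are not imports of the real `schemas` package, and the sample plugin names are hypothetical:

```go
package main

import "fmt"

// Minimal local copy of the LogLevel type assumed by PluginLogEntry.
type LogLevel string

// Minimal local copy of the PluginLogEntry struct added in this diff.
type PluginLogEntry struct {
	PluginName string   `json:"plugin_name"`
	Level      LogLevel `json:"level"`
	Message    string   `json:"message"`
	Timestamp  int64    `json:"timestamp"` // Unix milliseconds
}

// GroupPluginLogsByName buckets a flat slice of entries by PluginName,
// preserving per-plugin order, and returns nil for empty input.
// The min(len(logs), 4) size hint just caps the initial map allocation.
func GroupPluginLogsByName(logs []PluginLogEntry) map[string][]PluginLogEntry {
	if len(logs) == 0 {
		return nil
	}
	grouped := make(map[string][]PluginLogEntry, min(len(logs), 4))
	for _, entry := range logs {
		grouped[entry.PluginName] = append(grouped[entry.PluginName], entry)
	}
	return grouped
}

func main() {
	logs := []PluginLogEntry{
		{PluginName: "governance", Level: "info", Message: "virtual key resolved", Timestamp: 1},
		{PluginName: "compat", Level: "warn", Message: "dropped unsupported param", Timestamp: 2},
		{PluginName: "governance", Level: "info", Message: "budget check passed", Timestamp: 3},
	}
	grouped := GroupPluginLogsByName(logs)
	fmt.Println(len(grouped))                      // 2 plugins
	fmt.Println(len(grouped["governance"]))        // 2 entries, input order kept
	fmt.Println(GroupPluginLogsByName(nil) == nil) // true: nil, not an empty map
}
```

Returning nil instead of an empty map lets callers that serialize the result (e.g. into trace metadata) omit the field entirely when no plugin logged anything.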
@@ -542,6 +612,10 @@ func (br *BifrostRequest) SetModel(model string) { br.ImageVariationRequest.Model = model case br.VideoGenerationRequest != nil: br.VideoGenerationRequest.Model = model + case br.BatchCreateRequest != nil: + if br.BatchCreateRequest.Model != nil { + br.BatchCreateRequest.Model = new(model) + } } } @@ -784,6 +858,213 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { return &BifrostResponseExtraFields{} } +func (r *BifrostResponse) PopulateExtraFields(requestType RequestType, provider ModelProvider, originalModelRequested string, resolvedModelUsed string) { + if r == nil { + return + } + resolvedModel := resolvedModelUsed + if resolvedModel == "" { + resolvedModel = originalModelRequested + } + switch { + case r.ListModelsResponse != nil: + r.ListModelsResponse.ExtraFields.RequestType = requestType + r.ListModelsResponse.ExtraFields.Provider = provider + r.ListModelsResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ListModelsResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TextCompletionResponse != nil: + r.TextCompletionResponse.ExtraFields.RequestType = requestType + r.TextCompletionResponse.ExtraFields.Provider = provider + r.TextCompletionResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TextCompletionResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ChatResponse != nil: + r.ChatResponse.ExtraFields.RequestType = requestType + r.ChatResponse.ExtraFields.Provider = provider + r.ChatResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ChatResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ResponsesResponse != nil: + r.ResponsesResponse.ExtraFields.RequestType = requestType + r.ResponsesResponse.ExtraFields.Provider = provider + r.ResponsesResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ResponsesResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ResponsesStreamResponse != nil: 
+ r.ResponsesStreamResponse.ExtraFields.RequestType = requestType + r.ResponsesStreamResponse.ExtraFields.Provider = provider + r.ResponsesStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ResponsesStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.CountTokensResponse != nil: + r.CountTokensResponse.ExtraFields.RequestType = requestType + r.CountTokensResponse.ExtraFields.Provider = provider + r.CountTokensResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.CountTokensResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.EmbeddingResponse != nil: + r.EmbeddingResponse.ExtraFields.RequestType = requestType + r.EmbeddingResponse.ExtraFields.Provider = provider + r.EmbeddingResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.EmbeddingResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.RerankResponse != nil: + r.RerankResponse.ExtraFields.RequestType = requestType + r.RerankResponse.ExtraFields.Provider = provider + r.RerankResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.RerankResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.SpeechResponse != nil: + r.SpeechResponse.ExtraFields.RequestType = requestType + r.SpeechResponse.ExtraFields.Provider = provider + r.SpeechResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.SpeechResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.SpeechStreamResponse != nil: + r.SpeechStreamResponse.ExtraFields.RequestType = requestType + r.SpeechStreamResponse.ExtraFields.Provider = provider + r.SpeechStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.SpeechStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TranscriptionResponse != nil: + r.TranscriptionResponse.ExtraFields.RequestType = requestType + r.TranscriptionResponse.ExtraFields.Provider = provider + r.TranscriptionResponse.ExtraFields.OriginalModelRequested 
= originalModelRequested + r.TranscriptionResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TranscriptionStreamResponse != nil: + r.TranscriptionStreamResponse.ExtraFields.RequestType = requestType + r.TranscriptionStreamResponse.ExtraFields.Provider = provider + r.TranscriptionStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TranscriptionStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ImageGenerationResponse != nil: + r.ImageGenerationResponse.ExtraFields.RequestType = requestType + r.ImageGenerationResponse.ExtraFields.Provider = provider + r.ImageGenerationResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ImageGenerationResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ImageGenerationStreamResponse != nil: + r.ImageGenerationStreamResponse.ExtraFields.RequestType = requestType + r.ImageGenerationStreamResponse.ExtraFields.Provider = provider + r.ImageGenerationStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ImageGenerationStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoGenerationResponse != nil: + r.VideoGenerationResponse.ExtraFields.RequestType = requestType + r.VideoGenerationResponse.ExtraFields.Provider = provider + r.VideoGenerationResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoGenerationResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoDownloadResponse != nil: + r.VideoDownloadResponse.ExtraFields.RequestType = requestType + r.VideoDownloadResponse.ExtraFields.Provider = provider + r.VideoDownloadResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoDownloadResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoListResponse != nil: + r.VideoListResponse.ExtraFields.RequestType = requestType + r.VideoListResponse.ExtraFields.Provider = provider + r.VideoListResponse.ExtraFields.OriginalModelRequested = 
originalModelRequested + r.VideoListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoDeleteResponse != nil: + r.VideoDeleteResponse.ExtraFields.RequestType = requestType + r.VideoDeleteResponse.ExtraFields.Provider = provider + r.VideoDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileUploadResponse != nil: + r.FileUploadResponse.ExtraFields.RequestType = requestType + r.FileUploadResponse.ExtraFields.Provider = provider + r.FileUploadResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileUploadResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileListResponse != nil: + r.FileListResponse.ExtraFields.RequestType = requestType + r.FileListResponse.ExtraFields.Provider = provider + r.FileListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileRetrieveResponse != nil: + r.FileRetrieveResponse.ExtraFields.RequestType = requestType + r.FileRetrieveResponse.ExtraFields.Provider = provider + r.FileRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileDeleteResponse != nil: + r.FileDeleteResponse.ExtraFields.RequestType = requestType + r.FileDeleteResponse.ExtraFields.Provider = provider + r.FileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileContentResponse != nil: + r.FileContentResponse.ExtraFields.RequestType = requestType + r.FileContentResponse.ExtraFields.Provider = provider + r.FileContentResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileContentResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchCreateResponse != nil: + 
r.BatchCreateResponse.ExtraFields.RequestType = requestType + r.BatchCreateResponse.ExtraFields.Provider = provider + r.BatchCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchListResponse != nil: + r.BatchListResponse.ExtraFields.RequestType = requestType + r.BatchListResponse.ExtraFields.Provider = provider + r.BatchListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchRetrieveResponse != nil: + r.BatchRetrieveResponse.ExtraFields.RequestType = requestType + r.BatchRetrieveResponse.ExtraFields.Provider = provider + r.BatchRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchCancelResponse != nil: + r.BatchCancelResponse.ExtraFields.RequestType = requestType + r.BatchCancelResponse.ExtraFields.Provider = provider + r.BatchCancelResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchCancelResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchDeleteResponse != nil: + r.BatchDeleteResponse.ExtraFields.RequestType = requestType + r.BatchDeleteResponse.ExtraFields.Provider = provider + r.BatchDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchResultsResponse != nil: + r.BatchResultsResponse.ExtraFields.RequestType = requestType + r.BatchResultsResponse.ExtraFields.Provider = provider + r.BatchResultsResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchResultsResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerCreateResponse != nil: + r.ContainerCreateResponse.ExtraFields.RequestType = requestType + r.ContainerCreateResponse.ExtraFields.Provider = provider + 
r.ContainerCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerListResponse != nil: + r.ContainerListResponse.ExtraFields.RequestType = requestType + r.ContainerListResponse.ExtraFields.Provider = provider + r.ContainerListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerRetrieveResponse != nil: + r.ContainerRetrieveResponse.ExtraFields.RequestType = requestType + r.ContainerRetrieveResponse.ExtraFields.Provider = provider + r.ContainerRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerDeleteResponse != nil: + r.ContainerDeleteResponse.ExtraFields.RequestType = requestType + r.ContainerDeleteResponse.ExtraFields.Provider = provider + r.ContainerDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileCreateResponse != nil: + r.ContainerFileCreateResponse.ExtraFields.RequestType = requestType + r.ContainerFileCreateResponse.ExtraFields.Provider = provider + r.ContainerFileCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileListResponse != nil: + r.ContainerFileListResponse.ExtraFields.RequestType = requestType + r.ContainerFileListResponse.ExtraFields.Provider = provider + r.ContainerFileListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileRetrieveResponse != nil: + r.ContainerFileRetrieveResponse.ExtraFields.RequestType = requestType + 
r.ContainerFileRetrieveResponse.ExtraFields.Provider = provider + r.ContainerFileRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileContentResponse != nil: + r.ContainerFileContentResponse.ExtraFields.RequestType = requestType + r.ContainerFileContentResponse.ExtraFields.Provider = provider + r.ContainerFileContentResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileContentResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileDeleteResponse != nil: + r.ContainerFileDeleteResponse.ExtraFields.RequestType = requestType + r.ContainerFileDeleteResponse.ExtraFields.Provider = provider + r.ContainerFileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.PassthroughResponse != nil: + r.PassthroughResponse.ExtraFields.RequestType = requestType + r.PassthroughResponse.ExtraFields.Provider = provider + r.PassthroughResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.PassthroughResponse.ExtraFields.ResolvedModelUsed = resolvedModel + } +} + // BifrostMCPResponse is the response struct for all MCP responses. // only ONE of the following fields should be set: // - ChatMessage @@ -796,18 +1077,19 @@ type BifrostMCPResponse struct { // BifrostResponseExtraFields contains additional fields in a response. type BifrostResponseExtraFields struct { - RequestType RequestType `json:"request_type"` - Provider ModelProvider `json:"provider,omitempty"` - ModelRequested string `json:"model_requested,omitempty"` - ModelDeployment string `json:"model_deployment,omitempty"` // only present for providers which use model deployments (e.g. 
Azure, Bedrock) - Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) - ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` - CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` - ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) + RequestType RequestType `json:"request_type"` + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` // the model alias the caller sent in the request + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` // the actual provider API identifier used (equals OriginalModelRequested when no alias mapping exists) + Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + RawRequest interface{} `json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` + ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results + ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` + DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` // 
params dropped by the compat plugin based on model catalog + ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) } type BifrostMCPResponseExtraFields struct { @@ -895,6 +1177,20 @@ type BifrostError struct { ExtraFields BifrostErrorExtraFields `json:"extra_fields"` } +func (e *BifrostError) PopulateExtraFields(requestType RequestType, provider ModelProvider, originalModelRequested string, resolvedModelUsed string) { + if e == nil { + return + } + e.ExtraFields.RequestType = requestType + e.ExtraFields.Provider = provider + e.ExtraFields.OriginalModelRequested = originalModelRequested + if resolvedModelUsed != "" { + e.ExtraFields.ResolvedModelUsed = resolvedModelUsed + } else { + e.ExtraFields.ResolvedModelUsed = originalModelRequested + } +} + // StreamControl represents stream control options. type StreamControl struct { LogError *bool `json:"log_error,omitempty"` // Optional: Controls logging of error @@ -968,11 +1264,14 @@ func (e *ErrorField) UnmarshalJSON(data []byte) error { // BifrostErrorExtraFields contains additional fields in an error response. 
type BifrostErrorExtraFields struct { - Provider ModelProvider `json:"provider,omitempty"` - ModelRequested string `json:"model_requested,omitempty"` - RequestType RequestType `json:"request_type,omitempty"` - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` + RequestType RequestType `json:"request_type,omitempty"` + RawRequest interface{} `json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` + DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` + KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication } diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index be70991155..320fa15ac1 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -81,7 +81,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -114,7 +115,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: 
cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -150,7 +152,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -167,13 +170,15 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion SystemFingerprint: cr.SystemFingerprint, Usage: cr.Usage, ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, }, } } @@ -726,7 +731,6 @@ type AdditionalPropertiesStruct struct { // MarshalJSON implements custom JSON marshalling for AdditionalPropertiesStruct. // It marshals either AdditionalPropertiesBool or AdditionalPropertiesMap based on which is set. 
func (a AdditionalPropertiesStruct) MarshalJSON() ([]byte, error) { - // if both are set, return an error if a.AdditionalPropertiesBool != nil && a.AdditionalPropertiesMap != nil { return nil, fmt.Errorf("both AdditionalPropertiesBool and AdditionalPropertiesMap are set; only one should be non-nil") @@ -1490,7 +1494,7 @@ type BifrostLLMUsage struct { CompletionTokens int `json:"completion_tokens,omitempty"` CompletionTokensDetails *ChatCompletionTokensDetails `json:"completion_tokens_details,omitempty"` TotalTokens int `json:"total_tokens"` - Cost *BifrostCost `json:"cost,omitempty"` //Only for the providers which support cost calculation + Cost *BifrostCost `json:"cost,omitempty"` // Only for the providers which support cost calculation } type ChatPromptTokensDetails struct { diff --git a/core/schemas/context.go b/core/schemas/context.go index 1ff4663eae..68ac7c435e 100644 --- a/core/schemas/context.go +++ b/core/schemas/context.go @@ -24,6 +24,29 @@ var reservedKeys = []any{ BifrostContextKeySkipKeySelection, BifrostContextKeyURLPath, BifrostContextKeyDeferTraceCompletion, + BifrostContextKeyAttemptTrail, +} + +// pluginLogStore holds plugin log entries accumulated during request processing. +// It is shared between the root BifrostContext and all scoped contexts derived from it. +// Uses a flat slice (not map) to minimize heap allocations. +type pluginLogStore struct { + mu sync.Mutex + logs []PluginLogEntry +} + +// pluginLogStorePool pools pluginLogStore instances to reduce per-request allocations. +var pluginLogStorePool = sync.Pool{ + New: func() any { + return &pluginLogStore{logs: make([]PluginLogEntry, 0, 8)} + }, +} + +// pluginScopePool pools BifrostContext instances used as scoped plugin contexts. +var pluginScopePool = sync.Pool{ + New: func() any { + return &BifrostContext{} + }, } // BifrostContext is a custom context.Context implementation that tracks user-set values. 
@@ -40,6 +63,11 @@ type BifrostContext struct { userValues map[any]any valuesMu sync.RWMutex blockRestrictedWrites atomic.Bool + + // Plugin scoping fields + pluginScope *string // Non-nil when this is a scoped plugin context + pluginLogs atomic.Pointer[pluginLogStore] // Shared log store; lazily initialized on root, shared by scoped contexts + valueDelegate *BifrostContext // For scoped contexts: delegate Value/SetValue to this root context } // NewBifrostContext creates a new BifrostContext with the given parent context and deadline. @@ -166,8 +194,12 @@ func (bc *BifrostContext) cancel(err error) { } // Deadline returns the deadline for this context. +// For scoped contexts, delegates to the root context. // If both this context and the parent have deadlines, the earlier one is returned. func (bc *BifrostContext) Deadline() (time.Time, bool) { + if bc.valueDelegate != nil { + return bc.valueDelegate.Deadline() + } parentDeadline, parentHasDeadline := bc.parent.Deadline() if !bc.hasDeadline && !parentHasDeadline { @@ -195,16 +227,24 @@ func (bc *BifrostContext) Done() <-chan struct{} { } // Err returns the error explaining why the context was cancelled. +// For scoped contexts, delegates to the root context. // Returns nil if the context has not been cancelled. func (bc *BifrostContext) Err() error { + if bc.valueDelegate != nil { + return bc.valueDelegate.Err() + } bc.errMu.RLock() defer bc.errMu.RUnlock() return bc.err } // Value returns the value associated with the key. -// It first checks the internal userValues map, then delegates to the parent context. +// For scoped contexts, delegates to the root context via valueDelegate. +// Otherwise checks the internal userValues map, then delegates to the parent context. 
func (bc *BifrostContext) Value(key any) any { + if bc.valueDelegate != nil { + return bc.valueDelegate.Value(key) + } bc.valuesMu.RLock() if val, ok := bc.userValues[key]; ok { bc.valuesMu.RUnlock() @@ -212,12 +252,21 @@ func (bc *BifrostContext) Value(key any) any { } bc.valuesMu.RUnlock() + if bc.parent == nil { + return nil + } + return bc.parent.Value(key) } // SetValue sets a value in the internal userValues map. +// For scoped contexts, delegates to the root context via valueDelegate. // This is thread-safe and can be called concurrently. func (bc *BifrostContext) SetValue(key, value any) { + if bc.valueDelegate != nil { + bc.valueDelegate.SetValue(key, value) + return + } // Check if the key is a reserved key if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) { // we silently drop writes for these reserved keys @@ -232,7 +281,12 @@ func (bc *BifrostContext) SetValue(key, value any) { } // ClearValue clears a value from the internal userValues map. +// For scoped contexts, delegates to the root context via valueDelegate. func (bc *BifrostContext) ClearValue(key any) { + if bc.valueDelegate != nil { + bc.valueDelegate.ClearValue(key) + return + } // Check if the key is a reserved key if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) { // we silently drop writes for these reserved keys @@ -245,8 +299,12 @@ func (bc *BifrostContext) ClearValue(key any) { } } -// GetAndSetValue gets a value from the internal userValues map and sets it +// GetAndSetValue gets a value from the internal userValues map and sets it. +// For scoped contexts, delegates to the root context via valueDelegate. 
func (bc *BifrostContext) GetAndSetValue(key any, value any) any { + if bc.valueDelegate != nil { + return bc.valueDelegate.GetAndSetValue(key, value) + } bc.valuesMu.Lock() defer bc.valuesMu.Unlock() // Check if the key is a reserved key @@ -340,3 +398,104 @@ func AppendToContextList[T any](ctx *BifrostContext, key BifrostContextKey, valu } ctx.SetValue(key, append(existingValues, value)) } + +// WithPluginScope returns a lightweight scoped BifrostContext from the pool. +// The scoped context shares the root's pluginLogs store and delegates all +// Value/SetValue operations to the root context. +// Call ReleasePluginScope() when done to return the scoped context to the pool. +func (bc *BifrostContext) WithPluginScope(name *string) *BifrostContext { + // Lazily initialize the plugin log store on the root context (CAS to avoid race) + if bc.pluginLogs.Load() == nil { + newStore := pluginLogStorePool.Get().(*pluginLogStore) + if !bc.pluginLogs.CompareAndSwap(nil, newStore) { + // Another goroutine initialized first — return unused store to pool + pluginLogStorePool.Put(newStore) + } + } + + scoped := pluginScopePool.Get().(*BifrostContext) + scoped.parent = bc.parent + scoped.done = bc.done + scoped.pluginScope = name + scoped.pluginLogs.Store(bc.pluginLogs.Load()) + scoped.valueDelegate = bc + return scoped +} + +// ReleasePluginScope returns a scoped context to the pool. +// Safe no-op if called on a non-scoped context. +// Do not use the scoped context after calling this method. +func (bc *BifrostContext) ReleasePluginScope() { + if bc.valueDelegate == nil { + return // not a scoped context + } + bc.parent = nil + bc.done = nil + bc.pluginScope = nil + bc.pluginLogs.Store(nil) + bc.valueDelegate = nil + pluginScopePool.Put(bc) +} + +// Log appends a structured log entry for the current plugin scope. +// No-op if the context is not scoped to a plugin or has no log store. 
+func (bc *BifrostContext) Log(level LogLevel, msg string) { + store := bc.pluginLogs.Load() + if bc.pluginScope == nil || store == nil { + return + } + store.mu.Lock() + store.logs = append(store.logs, PluginLogEntry{ + PluginName: *bc.pluginScope, + Level: level, + Message: msg, + Timestamp: time.Now().UnixMilli(), + }) + store.mu.Unlock() +} + +// GetPluginLogs returns a deep copy of all accumulated plugin log entries. +// Thread-safe. Returns nil if no logs have been recorded. +func (bc *BifrostContext) GetPluginLogs() []PluginLogEntry { + store := bc.pluginLogs.Load() + if store == nil { + return nil + } + store.mu.Lock() + defer store.mu.Unlock() + if len(store.logs) == 0 { + return nil + } + copied := make([]PluginLogEntry, len(store.logs)) + copy(copied, store.logs) + return copied +} + +// DrainPluginLogs transfers ownership of the plugin log slice to the caller. +// The internal log store is returned to the pool after draining. +// Returns nil if no logs have been recorded. +// This should be called once on the root context after all plugin hooks have completed. 
+func (bc *BifrostContext) DrainPluginLogs() []PluginLogEntry { + if bc.valueDelegate != nil { + return nil // scoped contexts must not drain the shared log store + } + store := bc.pluginLogs.Load() + if store == nil { + return nil + } + bc.pluginLogs.Store(nil) + + store.mu.Lock() + logs := store.logs + // Reset with fresh pre-allocated slice before returning to pool + store.logs = make([]PluginLogEntry, 0, 8) + store.mu.Unlock() + + // Return the store to the pool for reuse + pluginLogStorePool.Put(store) + + if len(logs) == 0 { + return nil + } + return logs +} diff --git a/core/schemas/context_test.go b/core/schemas/context_test.go index 75e52e2061..108da2ced0 100644 --- a/core/schemas/context_test.go +++ b/core/schemas/context_test.go @@ -207,3 +207,125 @@ func TestNewBifrostContext_NilParent(t *testing.T) { t.Errorf("Cancelled context should have Canceled error, got %v", ctx.Err()) } } + +// Plugin logging tests + +func TestPluginLog_NoScopeIsNoop(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + ctx.Log(LogLevelInfo, "should be ignored") + logs := ctx.GetPluginLogs() + if logs != nil { + t.Errorf("expected nil logs without plugin scope, got %v", logs) + } +} + +func TestPluginLog_SinglePlugin(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + name := "test-plugin" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "hello") + scoped.Log(LogLevelError, "oops") + scoped.ReleasePluginScope() + + logs := ctx.GetPluginLogs() + if len(logs) != 2 { + t.Fatalf("expected 2 logs, got %d", len(logs)) + } + if logs[0].PluginName != "test-plugin" || logs[0].Level != LogLevelInfo || logs[0].Message != "hello" { + t.Errorf("unexpected first log: %+v", logs[0]) + } + if logs[1].Level != LogLevelError || logs[1].Message != "oops" { + t.Errorf("unexpected second log: %+v", logs[1]) + } +} + +func TestPluginLog_MultiplePlugins(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + + 
name1 := "plugin-a" + s1 := ctx.WithPluginScope(&name1) + s1.Log(LogLevelDebug, "a-msg") + s1.ReleasePluginScope() + + name2 := "plugin-b" + s2 := ctx.WithPluginScope(&name2) + s2.Log(LogLevelWarn, "b-msg") + s2.ReleasePluginScope() + + logs := ctx.GetPluginLogs() + if len(logs) != 2 { + t.Fatalf("expected 2 logs, got %d", len(logs)) + } + if logs[0].PluginName != "plugin-a" { + t.Errorf("expected plugin-a, got %s", logs[0].PluginName) + } + if logs[1].PluginName != "plugin-b" { + t.Errorf("expected plugin-b, got %s", logs[1].PluginName) + } +} + +func TestPluginLog_DrainTransfersOwnership(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + name := "drain-test" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "msg1") + scoped.ReleasePluginScope() + + drained := ctx.DrainPluginLogs() + if len(drained) != 1 { + t.Fatalf("expected 1 drained log, got %d", len(drained)) + } + + // After drain, GetPluginLogs should return nil + after := ctx.GetPluginLogs() + if after != nil { + t.Errorf("expected nil after drain, got %v", after) + } + + // Second drain should return nil + second := ctx.DrainPluginLogs() + if second != nil { + t.Errorf("expected nil on second drain, got %v", second) + } +} + +func TestPluginLog_ScopedContextValueDelegation(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + ctx.SetValue(BifrostContextKeyTraceID, "trace-123") + + name := "delegate-test" + scoped := ctx.WithPluginScope(&name) + + // Scoped should read from root + val := scoped.Value(BifrostContextKeyTraceID) + if val != "trace-123" { + t.Errorf("expected trace-123, got %v", val) + } + + // Scoped should write to root + type testContextKey string + const customKey testContextKey = "custom-key" + scoped.SetValue(customKey, "custom-val") + if ctx.Value(customKey) != "custom-val" { + t.Errorf("SetValue on scoped did not delegate to root") + } + + scoped.ReleasePluginScope() +} + +func TestPluginLog_PoolReuse(t *testing.T) { 
+ ctx := NewBifrostContext(context.Background(), NoDeadline) + + // Create and release multiple scoped contexts to exercise the pool + for i := 0; i < 100; i++ { + name := "pool-test" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "pooled") + scoped.ReleasePluginScope() + } + + logs := ctx.DrainPluginLogs() + if len(logs) != 100 { + t.Errorf("expected 100 logs from pool reuse, got %d", len(logs)) + } +} diff --git a/core/schemas/embedding.go b/core/schemas/embedding.go index 1d8890dd1f..9ca2fb2cdf 100644 --- a/core/schemas/embedding.go +++ b/core/schemas/embedding.go @@ -116,14 +116,16 @@ type EmbeddingParameters struct { type EmbeddingData struct { Index int `json:"index"` Object string `json:"object"` // "embedding" - Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64 or [][]float64 + Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32 } type EmbeddingStruct struct { // Embedding responses preserve provider precision in normalized API output. 
- EmbeddingStr *string - EmbeddingArray []float64 - Embedding2DArray [][]float64 + EmbeddingStr *string + EmbeddingArray []float64 + Embedding2DArray [][]float64 + EmbeddingInt8Array []int8 // for int8 / binary formats + EmbeddingInt32Array []int32 // for uint8 / ubinary formats } func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { @@ -136,6 +138,12 @@ func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { if be.Embedding2DArray != nil { return MarshalSorted(be.Embedding2DArray) } + if be.EmbeddingInt8Array != nil { + return Marshal(be.EmbeddingInt8Array) + } + if be.EmbeddingInt32Array != nil { + return Marshal(be.EmbeddingInt32Array) + } return nil, fmt.Errorf("no embedding found") } @@ -161,5 +169,19 @@ func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error { return nil } - return fmt.Errorf("embedding field is neither a string nor an array of float64 nor a 2D array of float64") + // Try to unmarshal as a direct array of int8 + var int8Content []int8 + if err := Unmarshal(data, &int8Content); err == nil { + be.EmbeddingInt8Array = int8Content + return nil + } + + // Try to unmarshal as a direct array of int32 + var int32Content []int32 + if err := Unmarshal(data, &int32Content); err == nil { + be.EmbeddingInt32Array = int32Content + return nil + } + + return fmt.Errorf("embedding field is neither a string, []float64, [][]float64, []int8, nor []int32") } diff --git a/core/schemas/envvar.go b/core/schemas/envvar.go index 6c5f996f08..c8fe249699 100644 --- a/core/schemas/envvar.go +++ b/core/schemas/envvar.go @@ -117,6 +117,9 @@ func (e *EnvVar) Equals(other *EnvVar) bool { // Redacted returns a new SecretKey with the value redacted. func (e *EnvVar) Redacted() *EnvVar { + if e == nil { + return nil + } if e.Val == "" { return &EnvVar{ Val: "", @@ -144,6 +147,34 @@ func (e *EnvVar) Redacted() *EnvVar { } } +// MarshalJSON serializes the EnvVar to JSON. 
+// SECURITY: When the value was sourced from an environment variable, the resolved +// value is automatically redacted before being serialized. This ensures that secrets +// injected via env vars are never leaked through any JSON API response, regardless +// of whether the surrounding code remembered to call Redacted() explicitly. +// +// Plain (non-env) values are still emitted as-is — callers that want to mask those +// must continue using Redacted() at the field level (this matches the existing +// per-provider redaction logic). +// +// This does NOT affect: +// - GORM persistence (uses the Value() driver method, not JSON) +// - Encryption (operates on the Val field directly) +// - Internal LLM request paths (use GetValue() directly) +func (e EnvVar) MarshalJSON() ([]byte, error) { + type envVarAlias EnvVar + out := envVarAlias(e) + if e.FromEnv { + // Redact the resolved value but keep the env var reference and from_env flag + // so the UI still knows which env var backs this field. + redacted := e.Redacted() + if redacted != nil { + out = envVarAlias(*redacted) + } + } + return sonic.Marshal(out) +} + // UnmarshalJSON unmarshals the value from JSON. func (e *EnvVar) UnmarshalJSON(data []byte) error { // This is always going to be value @@ -259,6 +290,17 @@ func (e *EnvVar) IsFromEnv() bool { return e.FromEnv } +// IsSet returns true if the EnvVar has a resolved value or an environment variable reference. +// This should be used instead of GetValue() != "" when checking whether a field was configured, +// because env var references may have an empty Val before resolution (e.g., when the env var +// is not available in the current environment). +func (e *EnvVar) IsSet() bool { + if e == nil { + return false + } + return e.Val != "" || e.EnvVar != "" +} + // GetValue returns the value. 
func (e *EnvVar) GetValue() string { if e == nil { @@ -298,3 +340,14 @@ func (e *EnvVar) CoerceBool(defaultValue bool) bool { } return val } + +// IsDefined returns true if the EnvVar has a source (static value or env key) +func (e *EnvVar) IsDefined() bool { + if e == nil { + return false + } + if e.IsFromEnv() { + return e.EnvVar != "" + } + return e.Val != "" +} diff --git a/core/schemas/envvar_test.go b/core/schemas/envvar_test.go index 5a451b5058..9b22673ae8 100644 --- a/core/schemas/envvar_test.go +++ b/core/schemas/envvar_test.go @@ -419,3 +419,191 @@ func TestEnvVar_IsRedacted(t *testing.T) { }) } } + +// TestEnvVar_IsSet verifies the semantic difference between GetValue() != "" and IsSet(). +// IsSet() must return true when the EnvVar references an env var (regardless of whether +// that env var has been resolved to a non-empty Val). This is the property that the +// BeforeSave hooks rely on so env var references survive persistence. +func TestEnvVar_IsSet(t *testing.T) { + tests := []struct { + name string + input *EnvVar + expected bool + }{ + { + name: "nil envvar", + input: nil, + expected: false, + }, + { + name: "completely empty", + input: &EnvVar{}, + expected: false, + }, + { + name: "only Val set (plain value)", + input: &EnvVar{Val: "abc"}, + expected: true, + }, + { + name: "only EnvVar reference set (env not resolved on this server)", + input: &EnvVar{EnvVar: "env.MISSING", FromEnv: true}, + expected: true, + }, + { + name: "Val and EnvVar both set (env was resolved)", + input: &EnvVar{Val: "resolved-secret", EnvVar: "env.X", FromEnv: true}, + expected: true, + }, + { + name: "FromEnv true but no reference and no value", + input: &EnvVar{FromEnv: true}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.input.IsSet(); got != tt.expected { + t.Errorf("IsSet() = %v, want %v", got, tt.expected) + } + }) + } +} + +// TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues verifies that any 
EnvVar marshaled +// to JSON with FromEnv=true is automatically masked, regardless of whether the +// surrounding code remembered to call Redacted() explicitly. This is the defense-in-depth +// guarantee that prevents env-resolved secrets from leaking through unredacted fields. +func TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues(t *testing.T) { + tests := []struct { + name string + input EnvVar + wantValue string + wantEnvVar string + wantFromEnv bool + }{ + { + name: "env-backed long secret is redacted", + input: EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.OPENAI_API_KEY", FromEnv: true}, + wantValue: "sk-1************************mnop", + wantEnvVar: "env.OPENAI_API_KEY", + wantFromEnv: true, + }, + { + name: "env-backed short secret is fully masked", + input: EnvVar{Val: "12345678", EnvVar: "env.SHORT", FromEnv: true}, + wantValue: "********", + wantEnvVar: "env.SHORT", + wantFromEnv: true, + }, + { + name: "env-backed unresolved on this server keeps empty value", + input: EnvVar{Val: "", EnvVar: "env.MISSING", FromEnv: true}, + wantValue: "", + wantEnvVar: "env.MISSING", + wantFromEnv: true, + }, + { + name: "plain value (not from env) is NOT redacted", + input: EnvVar{Val: "2024-10-21", EnvVar: "", FromEnv: false}, + wantValue: "2024-10-21", + wantEnvVar: "", + wantFromEnv: false, + }, + { + name: "empty plain value passes through", + input: EnvVar{Val: "", EnvVar: "", FromEnv: false}, + wantValue: "", + wantEnvVar: "", + wantFromEnv: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.input) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var got struct { + Value string `json:"value"` + EnvVar string `json:"env_var"` + FromEnv bool `json:"from_env"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal of marshaled output failed: %v", err) + } + if got.Value != tt.wantValue { + t.Errorf("value: got %q, want %q", got.Value, tt.wantValue) + } 
+ if got.EnvVar != tt.wantEnvVar { + t.Errorf("env_var: got %q, want %q", got.EnvVar, tt.wantEnvVar) + } + if got.FromEnv != tt.wantFromEnv { + t.Errorf("from_env: got %v, want %v", got.FromEnv, tt.wantFromEnv) + } + }) + } +} + +// TestEnvVar_MarshalJSON_DoesNotMutateOriginal ensures the auto-redaction in MarshalJSON +// does not mutate the receiver. The inference path calls GetValue() to build the actual +// HTTP request to the LLM provider, so the original Val must remain intact. +func TestEnvVar_MarshalJSON_DoesNotMutateOriginal(t *testing.T) { + original := EnvVar{Val: "real-secret-value", EnvVar: "env.SECRET", FromEnv: true} + if _, err := json.Marshal(original); err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if original.Val != "real-secret-value" { + t.Errorf("MarshalJSON mutated Val: got %q, want %q", original.Val, "real-secret-value") + } + if original.GetValue() != "real-secret-value" { + t.Errorf("GetValue() returns mutated value: got %q", original.GetValue()) + } +} + +// TestEnvVar_MarshalJSON_RoundTripIsRedacted verifies that a marshaled-then-unmarshaled +// env-backed EnvVar is recognized as redacted. The merge logic in provider_keys.go relies +// on this so it can detect "the UI sent back the same redacted value, don't overwrite". 
+func TestEnvVar_MarshalJSON_RoundTripIsRedacted(t *testing.T) { + original := EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.KEY", FromEnv: true} + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var roundTripped EnvVar + if err := json.Unmarshal(data, &roundTripped); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if !roundTripped.IsRedacted() { + t.Errorf("Round-tripped env-backed value should be IsRedacted, got Val=%q", roundTripped.Val) + } + if roundTripped.EnvVar != "env.KEY" { + t.Errorf("env_var reference lost in round-trip: got %q, want %q", roundTripped.EnvVar, "env.KEY") + } +} + +// TestEnvVar_MarshalJSON_DoesNotAffectGetValue is a critical safety net: marshaling an +// EnvVar to JSON must NOT change what GetValue() returns. The inference path uses +// GetValue() to build outgoing LLM requests; if marshaling were to mutate the value, +// every request after a UI fetch would silently start using the redacted mask as the +// API key. +func TestEnvVar_MarshalJSON_DoesNotAffectGetValue(t *testing.T) { + os.Setenv("MY_REAL_API_KEY", "sk-real-secret-1234567890abcdef") + defer os.Unsetenv("MY_REAL_API_KEY") + + ev := NewEnvVar("env.MY_REAL_API_KEY") + if ev.GetValue() != "sk-real-secret-1234567890abcdef" { + t.Fatalf("setup: GetValue() = %q, want resolved env value", ev.GetValue()) + } + + // Marshaling would redact in the JSON output, but must not touch the in-memory Val. 
+ if _, err := json.Marshal(ev); err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + if ev.GetValue() != "sk-real-secret-1234567890abcdef" { + t.Errorf("GetValue() returns mutated value after MarshalJSON: got %q", ev.GetValue()) + } +} diff --git a/core/schemas/images.go b/core/schemas/images.go index 1944f96d3f..fba3c2c08a 100644 --- a/core/schemas/images.go +++ b/core/schemas/images.go @@ -112,6 +112,22 @@ func (r *BifrostImageGenerationResponse) BackfillParams(req *BifrostRequest) { } } +// getModelFromRequest extracts the model from any image-related request. +func getModelFromRequest(req *BifrostRequest) string { + if req == nil { + return "" + } + switch { + case req.ImageGenerationRequest != nil: + return req.ImageGenerationRequest.Model + case req.ImageEditRequest != nil: + return req.ImageEditRequest.Model + case req.ImageVariationRequest != nil: + return req.ImageVariationRequest.Model + } + return "" +} + // getNumInputImagesSizeAndQualityFromRequest extracts request params for cost calculation. // Quality is only returned when it is one of low, medium, high, auto. 
func getNumInputImagesSizeAndQualityFromRequest(req *BifrostRequest) (numInputImages int, size string, quality string) { @@ -167,10 +183,12 @@ func normalizeImageQuality(q string) string { } type ImageGenerationResponseParameters struct { - Background string `json:"background,omitempty"` - OutputFormat string `json:"output_format,omitempty"` - Quality string `json:"quality,omitempty"` - Size string `json:"size,omitempty"` + Background string `json:"background,omitempty"` + OutputFormat string `json:"output_format,omitempty"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + FinishReasons []*string `json:"finish_reasons,omitempty"` + Seeds []int `json:"seeds,omitempty"` } type ImageData struct { @@ -270,7 +288,7 @@ type ImageInput struct { } type ImageEditParameters struct { - Type *string `json:"type,omitempty"` // "inpainting", "outpainting", "background_removal", + Type *string `json:"type,omitempty"` // "inpainting", "outpainting", "background_removal", "remove_background", "erase_object", "recolor", "search_replace", "control_sketch", "control_structure", "style_guide", "style_transfer", "upscale_fast", "upscale_creative", "upscale_conservative" Background *string `json:"background,omitempty"` // "transparent", "opaque", "auto" InputFidelity *string `json:"input_fidelity,omitempty"` // "low", "high" Mask []byte `json:"mask,omitempty"` diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index 72b43cbc8e..af87cdc743 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -20,8 +20,25 @@ var ( ErrOAuth2TokenExpired = errors.New("oauth2 token expired") ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid") ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed") + ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session") + ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server") + ErrPerUserOAuthPendingFlowExpired = 
errors.New("per-user oauth pending flow has expired") ) +// MCPUserOAuthRequiredError is returned when a per-user OAuth MCP server requires +// the user to authenticate before tool execution can proceed. +type MCPUserOAuthRequiredError struct { + MCPClientID string `json:"mcp_client_id"` + MCPClientName string `json:"mcp_client_name"` + AuthorizeURL string `json:"authorize_url"` + SessionID string `json:"session_id"` + Message string `json:"message"` +} + +func (e *MCPUserOAuthRequiredError) Error() string { + return e.Message +} + // MCPConfig represents the configuration for MCP integration in Bifrost. // It enables tool auto-discovery and execution from local and external MCP servers. type MCPConfig struct { @@ -46,9 +63,10 @@ type MCPConfig struct { } type MCPToolManagerConfig struct { - ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` - MaxAgentDepth int `json:"max_agent_depth"` - CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" + ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` + MaxAgentDepth int `json:"max_agent_depth"` + CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" + DisableAutoToolInject bool `json:"disable_auto_tool_inject,omitempty"` // When true, MCP tools are not injected into requests by default } const ( @@ -68,41 +86,48 @@ const ( type MCPAuthType string const ( - MCPAuthTypeNone MCPAuthType = "none" // No authentication - MCPAuthTypeHeaders MCPAuthType = "headers" // Header-based authentication (API keys, etc.) - MCPAuthTypeOauth MCPAuthType = "oauth" // OAuth 2.0 authentication + MCPAuthTypeNone MCPAuthType = "none" // No authentication + MCPAuthTypeHeaders MCPAuthType = "headers" // Header-based authentication (API keys, etc.) 
+ MCPAuthTypeOauth MCPAuthType = "oauth" // OAuth 2.0 authentication (server-level, admin authenticates once) + MCPAuthTypePerUserOauth MCPAuthType = "per_user_oauth" // Per-user OAuth 2.0 authentication (each user authenticates individually) ) // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { - ID string `json:"client_id"` // Client ID - Name string `json:"name"` // Client name - IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client - ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) - ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) - StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) - AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth) - OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table) - State string `json:"state,omitempty"` // Connection state (connected, disconnected, error) - Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) - InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) - ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Include-only list. 
+ ID string `json:"client_id"` // Client ID + Name string `json:"name"` // Client name + IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) + ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) + StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) + AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, oauth, or per_user_oauth) + OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table) + State string `json:"state,omitempty"` // Connection state (connected, disconnected, error, pending_tools) + Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) + AllowedExtraHeaders WhiteList `json:"allowed_extra_headers,omitempty"` // Allowlist of request-level headers that callers may forward to this MCP server at execution time + InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) + ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list. // ToolsToExecute semantics: // - ["*"] => all tools are included // - [] => no tools are included (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => include only the specified tools - ToolsToAutoExecute []string `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. + ToolsToAutoExecute WhiteList `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. 
// ToolsToAutoExecute semantics: // - ["*"] => all tools are auto-executed // - [] => no tools are auto-executed (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => auto-execute only the specified tools // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. - IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks. - ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) - ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) - ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true. + ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) + ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys + + // Discovered tools for per-user OAuth clients (persisted so they survive restart) + DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name + DiscoveredToolNameMapping map[string]string `json:"-"` // Mapping from sanitized tool names to original MCP names } // NewMCPClientConfigFromMap creates a new MCP client config from a map[string]any. 
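The include-only semantics documented for `ToolsToExecute` and `ToolsToAutoExecute` can be sketched as a standalone check. This is illustrative only: the real `WhiteList` type in `core/schemas` may implement matching differently, and `toolAllowed` is a hypothetical helper, not part of the diff.

```go
package main

import "fmt"

// toolAllowed mirrors the documented allowlist semantics: ["*"] includes every
// tool, nil or [] denies by default, and any other list is an exact-match
// include-only list.
func toolAllowed(list []string, tool string) bool {
	for _, entry := range list {
		if entry == "*" || entry == tool {
			return true
		}
	}
	return false // nil or [] => deny-by-default
}

func main() {
	fmt.Println(toolAllowed([]string{"*"}, "search"))              // true
	fmt.Println(toolAllowed(nil, "search"))                        // false
	fmt.Println(toolAllowed([]string{"search", "fetch"}, "fetch")) // true

	// Per the note above: a tool in ToolsToAutoExecute but not in
	// ToolsToExecute is skipped, because both checks must pass.
	execute := []string{"search"}
	autoExecute := []string{"search", "fetch"}
	fmt.Println(toolAllowed(execute, "fetch") && toolAllowed(autoExecute, "fetch")) // false
}
```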
@@ -147,6 +172,9 @@ func (c *MCPClientConfig) HttpHeaders(ctx context.Context, oauth2Provider OAuth2 for key, value := range c.Headers { headers[key] = value.GetValue() } + case MCPAuthTypePerUserOauth: + // Per-user OAuth: headers are injected per-call in executeToolInternal, not at connection level + return headers, nil case MCPAuthTypeNone: // No headers to add default: @@ -179,9 +207,10 @@ type MCPStdioConfig struct { type MCPConnectionState string const ( - MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use - MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected - MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used + MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use + MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected + MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used + MCPConnectionStatePendingTools MCPConnectionState = "pending_tools" // Connected but tools not yet populated ) // MCPClientState represents a connected MCP client with its configuration and tools. diff --git a/core/schemas/models.go b/core/schemas/models.go index 5a0e8588c1..32b82bb104 100644 --- a/core/schemas/models.go +++ b/core/schemas/models.go @@ -138,7 +138,7 @@ type Model struct { ID string `json:"id"` CanonicalSlug *string `json:"canonical_slug,omitempty"` Name *string `json:"name,omitempty"` - Deployment *string `json:"deployment,omitempty"` // Name of the actual deployment + Alias *string `json:"alias,omitempty"` // Provider API identifier this model alias maps to (e.g. 
Azure deployment name, Bedrock ARN) Created *int64 `json:"created,omitempty"` ContextLength *int `json:"context_length,omitempty"` MaxInputTokens *int `json:"max_input_tokens,omitempty"` diff --git a/core/schemas/models_test.go b/core/schemas/models_test.go index 3e60fdda76..b9748952bd 100644 --- a/core/schemas/models_test.go +++ b/core/schemas/models_test.go @@ -94,7 +94,7 @@ func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) { Error: &ErrorField{Message: "unauthorized"}, ExtraFields: BifrostErrorExtraFields{ Provider: "openai", - ModelRequested: "gpt-4", + OriginalModelRequested: "gpt-4", }, } keyStatus := KeyStatus{ @@ -112,6 +112,6 @@ func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) { // Error fields other than key_statuses should be preserved dataStr := string(data) assert.Contains(t, dataStr, `"unauthorized"`) - assert.Contains(t, dataStr, `"model_requested":"gpt-4"`) + assert.Contains(t, dataStr, `"original_model_requested":"gpt-4"`) assert.Contains(t, dataStr, `"status_code":401`) } diff --git a/core/schemas/mux.go b/core/schemas/mux.go index f899f41739..24943d3fbd 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -1258,6 +1258,10 @@ func (responsesResp *BifrostResponsesResponse) ToBifrostChatResponse() *BifrostC Videos: responsesResp.Videos, } + if responsesResp.ID != nil { + chatResp.ID = *responsesResp.ID + } + // Create Choices from ResponsesResponse if len(responsesResp.Output) > 0 { // Convert ResponsesMessages back to ChatMessages @@ -2013,3 +2017,362 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse(state *ChatToRes return responses } + +// ToBifrostChatResponse converts a BifrostResponsesStreamResponse chunk to a BifrostChatResponse (chat.completion.chunk). 
+func (rsr *BifrostResponsesStreamResponse) ToBifrostChatResponse() *BifrostChatResponse { + if rsr == nil { + return nil + } + + extraFields := rsr.ExtraFields + extraFields.RequestType = ChatCompletionStreamRequest + + resp := &BifrostChatResponse{ + Object: "chat.completion.chunk", + ExtraFields: extraFields, + SearchResults: rsr.SearchResults, + Videos: rsr.Videos, + Citations: rsr.Citations, + } + + if rsr.Response != nil { + if rsr.Response.ID != nil { + resp.ID = *rsr.Response.ID + } + resp.Created = rsr.Response.CreatedAt + resp.Model = rsr.Response.Model + } + + switch rsr.Type { + case ResponsesStreamResponseTypeOutputTextDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Content: rsr.Delta, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeReasoningSummaryTextDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Reasoning: rsr.Delta, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeRefusalDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Refusal: rsr.Refusal, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeOutputItemAdded: + if rsr.Item == nil || rsr.Item.Type == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + + switch *rsr.Item.Type { + case ResponsesMessageTypeFunctionCall: + if rsr.Item.ResponsesToolMessage == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + funcType := 
"function" + var idx uint16 + if rsr.OutputIndex != nil && *rsr.OutputIndex > 0 { + idx = uint16(*rsr.OutputIndex - 1) + } + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + ToolCalls: []ChatAssistantMessageToolCall{ + { + Index: idx, + Type: &funcType, + ID: rsr.Item.ResponsesToolMessage.CallID, + Function: ChatAssistantMessageToolCallFunction{ + Name: rsr.Item.ResponsesToolMessage.Name, + }, + }, + }, + }, + }, + }, + } + return resp + + case ResponsesMessageTypeMessage: + role := "assistant" + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Role: &role, + }, + }, + }, + } + return resp + + default: + // reasoning, file_search_call, web_search_call, etc. — no chat equivalent, + // actual content arrives via separate delta events. + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + + case ResponsesStreamResponseTypeFunctionCallArgumentsDelta: + if rsr.Delta == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + var idx uint16 + if rsr.OutputIndex != nil && *rsr.OutputIndex > 0 { + idx = uint16(*rsr.OutputIndex - 1) + } + + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + ToolCalls: []ChatAssistantMessageToolCall{ + { + Index: idx, + Function: ChatAssistantMessageToolCallFunction{ + Arguments: *rsr.Delta, + }, + }, + }, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeCompleted, ResponsesStreamResponseTypeIncomplete: + finishReason := 
string(BifrostFinishReasonStop) + if rsr.Type == ResponsesStreamResponseTypeIncomplete { + finishReason = string(BifrostFinishReasonLength) + } + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + FinishReason: &finishReason, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + if rsr.Response != nil { + if rsr.Response.Usage != nil { + resp.Usage = rsr.Response.Usage.ToBifrostLLMUsage() + } + // Check for tool_calls finish reason + if rsr.Type == ResponsesStreamResponseTypeCompleted { + for _, output := range rsr.Response.Output { + if output.Type != nil && *output.Type == ResponsesMessageTypeFunctionCall { + finishReason = string(BifrostFinishReasonToolCalls) + resp.Choices[0].FinishReason = &finishReason + break + } + } + } + } + return resp + + default: + // Lifecycle events (created, in_progress, content_part.added/done, output_text.done, + // output_item.done, function_call_arguments.done, etc.) → empty chat chunk with no content. 
+ resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } +} + +// ============================================================================= +// RESPONSE CONVERSION METHODS +// ============================================================================= + +// ToBifrostTextCompletionResponse converts a BifrostChatResponse to a BifrostTextCompletionResponse +func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCompletionResponse { + if cr == nil { + return nil + } + + if len(cr.Choices) == 0 { + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + choice := cr.Choices[0] + + // Handle streaming response choice + if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Choices: []BifrostResponseChoice{ + { + Index: 0, + TextCompletionResponseChoice: &TextCompletionResponseChoice{ + Text: choice.ChatStreamResponseChoice.Delta.Content, + }, + FinishReason: choice.FinishReason, + LogProbs: choice.LogProbs, + }, + }, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + 
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + // Handle non-streaming response choice + if choice.ChatNonStreamResponseChoice != nil { + msg := choice.ChatNonStreamResponseChoice.Message + var textContent *string + if msg != nil && msg.Content != nil { + if msg.Content.ContentStr != nil { + textContent = msg.Content.ContentStr + } else if len(msg.Content.ContentBlocks) > 0 { + var sb strings.Builder + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil { + sb.WriteString(*block.Text) + } + } + if sb.Len() > 0 { + s := sb.String() + textContent = &s + } + } + } + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Choices: []BifrostResponseChoice{ + { + Index: 0, + TextCompletionResponseChoice: &TextCompletionResponseChoice{ + Text: textContent, + }, + FinishReason: choice.FinishReason, + LogProbs: choice.LogProbs, + }, + }, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + // Fallback case - return basic response structure + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + 
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } +} diff --git a/core/schemas/oauth.go b/core/schemas/oauth.go index 1a953c3cc6..7ff3bb9362 100644 --- a/core/schemas/oauth.go +++ b/core/schemas/oauth.go @@ -7,7 +7,7 @@ import ( // OauthProvider interface defines OAuth operations type OAuth2Provider interface { - // GetAccessToken retrieves the access token for a given oauth_config_id + // GetAccessToken retrieves the access token for a given oauth_config_id (server-level OAuth) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) // RefreshAccessToken refreshes the access token for a given oauth_config_id @@ -18,6 +18,31 @@ type OAuth2Provider interface { // RevokeToken revokes the OAuth token RevokeToken(ctx context.Context, oauthConfigID string) error + + // Per-user OAuth methods + + // GetUserAccessToken retrieves the access token for a per-user OAuth session. + // If the token is expired, it automatically attempts a refresh. + GetUserAccessToken(ctx context.Context, sessionToken string) (string, error) + + // GetUserAccessTokenByIdentity retrieves the upstream access token for a user + // identified by virtualKeyID, userID, or sessionToken (fallback), for a specific + // MCP client. Tokens looked up by identity persist across sessions. + GetUserAccessTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (string, error) + + // InitiateUserOAuthFlow creates a per-user OAuth session and returns the authorization URL. + // Returns (flow initiation details, session ID for polling, error). 
+ InitiateUserOAuthFlow(ctx context.Context, oauthConfigID string, mcpClientID string, redirectURI string) (*OAuth2FlowInitiation, string, error) + + // CompleteUserOAuthFlow handles the OAuth callback for a per-user flow. + // Returns the session token that the user should send on subsequent requests. + CompleteUserOAuthFlow(ctx context.Context, state string, code string) (string, error) + + // RefreshUserAccessToken refreshes a per-user OAuth access token. + RefreshUserAccessToken(ctx context.Context, sessionToken string) error + + // RevokeUserToken revokes a per-user OAuth token and marks the session as revoked. + RevokeUserToken(ctx context.Context, sessionToken string) error } // OauthConfig represents OAuth client configuration diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index f9ea18a4b3..5e0d068718 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -313,9 +313,15 @@ type ObservabilityPlugin interface { // // Implementations should: // - Convert the trace to their backend's format - // - Send the trace to the backend (can be async) + // - Send the trace to the backend (can be async, but see retention note below) // - Handle errors gracefully (log and continue) // // The context passed is a fresh background context, not the request context. + // + // Retention: implementations MUST NOT retain the *Trace pointer after Inject + // returns. The caller releases the trace back to a sync.Pool immediately after + // Inject completes, so any background goroutine that still references it will + // race with pool reuse. If a plugin needs to forward the trace asynchronously, + // it must copy the data it needs before returning. 
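The retention rule added to the `Inject` doc comment can be illustrated with a minimal sketch. The `Trace` fields here are assumptions for the example, not the real schema; the point is that a deep copy must be taken before `Inject` returns, since the caller hands the `*Trace` back to a `sync.Pool` immediately afterwards.

```go
package main

import "fmt"

// Trace stands in for schemas.Trace; these fields are illustrative assumptions.
type Trace struct {
	ID    string
	Spans []string
}

// snapshot deep-copies the fields an exporter needs. Per the retention rule on
// ObservabilityPlugin.Inject, this copy must happen before Inject returns,
// because the caller releases the *Trace back to a sync.Pool right after.
func snapshot(t *Trace) Trace {
	return Trace{
		ID:    t.ID,
		Spans: append([]string(nil), t.Spans...), // copy the backing array
	}
}

func main() {
	pooled := &Trace{ID: "req-1", Spans: []string{"root", "db"}}
	snap := snapshot(pooled)

	// Simulate pool reuse: the original is reset for the next request.
	pooled.ID = "req-2"
	pooled.Spans = pooled.Spans[:0]

	// The snapshot shares no memory with the pooled trace, so a background
	// goroutine holding `snap` cannot race with pool reuse.
	fmt.Println(snap.ID, len(snap.Spans)) // req-1 2
}
```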
Inject(ctx context.Context, trace *Trace) error } diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 6fe0615b06..5d28002f9f 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -8,18 +8,18 @@ import ( ) const ( - DefaultMaxRetries = 0 - DefaultRetryBackoffInitial = 500 * time.Millisecond - DefaultRetryBackoffMax = 5 * time.Second + DefaultMaxRetries = 0 + DefaultRetryBackoffInitial = 500 * time.Millisecond + DefaultRetryBackoffMax = 5 * time.Second DefaultRequestTimeoutInSeconds = 30 - DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops - DefaultBufferSize = 5000 - DefaultConcurrency = 1000 - DefaultStreamBufferSize = 256 - DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection - DefaultMaxConnsPerHost = 5000 - MaxConnsPerHostUpperBound = 10000 - DefaultMaxIdleConnsPerHost = 40 + DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops + DefaultBufferSize = 5000 + DefaultConcurrency = 1000 + DefaultStreamBufferSize = 256 + DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection + DefaultMaxConnsPerHost = 5000 + MaxConnsPerHostUpperBound = 10000 + DefaultMaxIdleConnsPerHost = 40 ) // Pre-defined errors for provider operations @@ -52,18 +52,18 @@ const ( // - When marshaling to JSON: a time.Duration is converted to milliseconds type NetworkConfig struct { // BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama) - BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) - DefaultRequestTimeoutInSeconds int 
`json:"default_request_timeout_in_seconds"` // Default timeout for requests - MaxRetries int `json:"max_retries"` // Maximum number of retries - RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds) - RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds) - InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections - CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections + BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) + DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests + MaxRetries int `json:"max_retries"` // Maximum number of retries + RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds) + RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds) + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections + CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"` // Idle timeout per stream chunk (0 = use default 60s) - MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000) - EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock) - 
BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-") + MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000) + EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock) + BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-") } // UnmarshalJSON customizes JSON unmarshaling for NetworkConfig. @@ -409,67 +409,6 @@ type CustomProviderConfig struct { RequestPathOverrides map[RequestType]string `json:"request_path_overrides,omitempty"` // Mapping of request type to its custom path which will override the default path of the provider (not allowed for Bedrock) } -type PricingOverrideMatchType string - -const ( - PricingOverrideMatchExact PricingOverrideMatchType = "exact" - PricingOverrideMatchWildcard PricingOverrideMatchType = "wildcard" - PricingOverrideMatchRegex PricingOverrideMatchType = "regex" -) - -// ProviderPricingOverride contains a partial pricing patch applied at lookup time. -// Any nil field falls back to the base pricing data. 
-type ProviderPricingOverride struct { - ModelPattern string `json:"model_pattern"` - MatchType PricingOverrideMatchType `json:"match_type"` - RequestTypes []RequestType `json:"request_types,omitempty"` - - // Basic token pricing - InputCostPerToken *float64 `json:"input_cost_per_token,omitempty"` - OutputCostPerToken *float64 `json:"output_cost_per_token,omitempty"` - - // Additional pricing for media - InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` - InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` - - // Character-based pricing - InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` - - // Pricing above 128k tokens - InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` - InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` - InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` - InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` - OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` - - // Pricing above 200k tokens - InputCostPerTokenAbove200kTokens *float64 `json:"input_cost_per_token_above_200k_tokens,omitempty"` - OutputCostPerTokenAbove200kTokens *float64 `json:"output_cost_per_token_above_200k_tokens,omitempty"` - CacheCreationInputTokenCostAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"` - CacheReadInputTokenCostAbove200kTokens *float64 `json:"cache_read_input_token_cost_above_200k_tokens,omitempty"` - - // Cache and batch pricing - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` - CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost,omitempty"` - InputCostPerTokenBatches *float64 
`json:"input_cost_per_token_batches,omitempty"` - OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` - - // Image generation pricing - InputCostPerImageToken *float64 `json:"input_cost_per_image_token,omitempty"` - OutputCostPerImageToken *float64 `json:"output_cost_per_image_token,omitempty"` - InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` - OutputCostPerImage *float64 `json:"output_cost_per_image,omitempty"` - OutputCostPerImageAbove1024x1024Pixels *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"` - OutputCostPerImageAbove1024x1024PixelsPremium *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"` - OutputCostPerImageAbove2048x2048Pixels *float64 `json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"` - OutputCostPerImageAbove4096x4096Pixels *float64 `json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"` - OutputCostPerImageLowQuality *float64 `json:"output_cost_per_image_low_quality,omitempty"` - OutputCostPerImageMediumQuality *float64 `json:"output_cost_per_image_medium_quality,omitempty"` - OutputCostPerImageHighQuality *float64 `json:"output_cost_per_image_high_quality,omitempty"` - OutputCostPerImageAutoQuality *float64 `json:"output_cost_per_image_auto_quality,omitempty"` - CacheReadInputImageTokenCost *float64 `json:"cache_read_input_image_token_cost,omitempty"` -} - // IsOperationAllowed checks if a specific operation is allowed for this custom provider func (cpc *CustomProviderConfig) IsOperationAllowed(operation RequestType) bool { if cpc == nil || cpc.AllowedRequests == nil { @@ -485,14 +424,13 @@ type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` // Network configuration ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings // Logger instance, can be provided by the user or bifrost default logger is used if not provided 
- Logger Logger `json:"-"` - ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false) - SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) - StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false) - CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` - OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` - PricingOverrides []ProviderPricingOverride `json:"pricing_overrides,omitempty"` + Logger Logger `json:"-"` + ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false) + SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) + StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false) + CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` + OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` } // OpenAIConfig holds OpenAI-specific provider configuration. 
diff --git a/core/schemas/realtime.go b/core/schemas/realtime.go index e1e20d7bf4..ec4fd6789d 100644 --- a/core/schemas/realtime.go +++ b/core/schemas/realtime.go @@ -19,33 +19,75 @@ const ( // Server-to-client event types (received from the provider, forwarded to client) const ( - RTEventSessionCreated RealtimeEventType = "session.created" - RTEventSessionUpdated RealtimeEventType = "session.updated" - RTEventConversationCreated RealtimeEventType = "conversation.created" - RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" - RTEventConversationItemDone RealtimeEventType = "conversation.item.done" - RTEventResponseCreated RealtimeEventType = "response.created" - RTEventResponseDone RealtimeEventType = "response.done" - RTEventResponseTextDelta RealtimeEventType = "response.text.delta" - RTEventResponseTextDone RealtimeEventType = "response.text.done" - RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" - RTEventResponseAudioDone RealtimeEventType = "response.audio.done" - RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" - RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" - RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" - RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" - RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" - RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" - RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" - RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" - RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed" - RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" - RTEventInputAudioBufferCleared RealtimeEventType = 
"input_audio_buffer.cleared" - RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" - RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" - RTEventError RealtimeEventType = "error" + RTEventSessionCreated RealtimeEventType = "session.created" + RTEventSessionUpdated RealtimeEventType = "session.updated" + RTEventConversationCreated RealtimeEventType = "conversation.created" + RTEventConversationItemAdded RealtimeEventType = "conversation.item.added" + RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" + RTEventConversationItemRetrieved RealtimeEventType = "conversation.item.retrieved" + RTEventConversationItemDone RealtimeEventType = "conversation.item.done" + RTEventResponseCreated RealtimeEventType = "response.created" + RTEventResponseDone RealtimeEventType = "response.done" + RTEventResponseTextDelta RealtimeEventType = "response.text.delta" + RTEventResponseTextDone RealtimeEventType = "response.text.done" + RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" + RTEventResponseAudioDone RealtimeEventType = "response.audio.done" + RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" + RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" + RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" + RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" + RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" + RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" + RTEventRateLimitsUpdated RealtimeEventType = "rate_limits.updated" + RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" + RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" + RTEventInputAudioTransFailed RealtimeEventType = 
"conversation.item.input_audio_transcription.failed" + RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" + RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared" + RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" + RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" + RTEventError RealtimeEventType = "error" ) +// IsRealtimeConversationItemEventType reports whether the event carries a +// canonical conversation item payload after provider translation. +func IsRealtimeConversationItemEventType(eventType RealtimeEventType) bool { + switch eventType { + case RTEventConversationItemCreate, + RTEventConversationItemAdded, + RTEventConversationItemCreated, + RTEventConversationItemRetrieved, + RTEventConversationItemDone: + return true + default: + return false + } +} + +// IsRealtimeUserInputEvent reports whether the event represents a finalized +// user input item in the canonical Bifrost realtime schema. +func IsRealtimeUserInputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Role == "user" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeToolOutputEvent reports whether the event represents a finalized +// tool output item in the canonical Bifrost realtime schema. +func IsRealtimeToolOutputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Type == "function_call_output" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeInputTranscriptEvent reports whether the event carries a finalized +// input-audio transcript in the canonical Bifrost realtime schema. +func IsRealtimeInputTranscriptEvent(event *BifrostRealtimeEvent) bool { + return event != nil && event.Type == RTEventInputAudioTransCompleted +} + // BifrostRealtimeEvent is the unified Bifrost envelope for all Realtime events. 
// Provider converters translate between this format and the provider-native protocol. type BifrostRealtimeEvent struct { @@ -58,36 +100,42 @@ type BifrostRealtimeEvent struct { Audio []byte `json:"audio,omitempty"` Error *RealtimeError `json:"error,omitempty"` + // ExtraParams preserves provider-specific top-level event fields that are not + // promoted into the common Bifrost schema. + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` + // RawData preserves the original provider event for pass-through or debugging. RawData json.RawMessage `json:"raw_data,omitempty"` } // RealtimeSession describes session configuration for the Realtime connection. type RealtimeSession struct { - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Modalities []string `json:"modalities,omitempty"` - Instructions string `json:"instructions,omitempty"` - Voice string `json:"voice,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` - TurnDetection json.RawMessage `json:"turn_detection,omitempty"` - InputAudioFormat string `json:"input_audio_format,omitempty"` - OutputAudioType string `json:"output_audio_type,omitempty"` - Tools json.RawMessage `json:"tools,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Instructions string `json:"instructions,omitempty"` + Voice string `json:"voice,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` + TurnDetection json.RawMessage `json:"turn_detection,omitempty"` + InputAudioFormat string `json:"input_audio_format,omitempty"` + OutputAudioType string `json:"output_audio_type,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeItem represents a conversation 
item in the Realtime protocol. type RealtimeItem struct { - ID string `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Role string `json:"role,omitempty"` - Status string `json:"status,omitempty"` - Content json.RawMessage `json:"content,omitempty"` - Name string `json:"name,omitempty"` - CallID string `json:"call_id,omitempty"` - Arguments string `json:"arguments,omitempty"` - Output string `json:"output,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Role string `json:"role,omitempty"` + Status string `json:"status,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + Name string `json:"name,omitempty"` + CallID string `json:"call_id,omitempty"` + Arguments string `json:"arguments,omitempty"` + Output string `json:"output,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeDelta carries incremental content for streaming events. @@ -103,10 +151,28 @@ type RealtimeDelta struct { // RealtimeError describes an error from the Realtime API. type RealtimeError struct { - Type string `json:"type,omitempty"` - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` - Param string `json:"param,omitempty"` + Type string `json:"type,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param string `json:"param,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` +} + +// RealtimeSessionEndpointType identifies the public ephemeral-token endpoint +// shape the client called so providers can preserve versioned behavior. +type RealtimeSessionEndpointType string + +const ( + RealtimeSessionEndpointClientSecrets RealtimeSessionEndpointType = "client_secrets" + RealtimeSessionEndpointSessions RealtimeSessionEndpointType = "sessions" +) + +// RealtimeSessionRoute describes a provider-registered public route for +// ephemeral-token creation. 
+type RealtimeSessionRoute struct { + Path string + EndpointType RealtimeSessionEndpointType + DefaultProvider ModelProvider } // RealtimeProvider is an optional interface that providers can implement to @@ -116,6 +182,129 @@ type RealtimeProvider interface { SupportsRealtimeAPI() bool RealtimeWebSocketURL(key Key, model string) string RealtimeHeaders(key Key) map[string]string + // SupportsRealtimeWebRTC reports whether the provider supports WebRTC SDP exchange. + SupportsRealtimeWebRTC() bool + // ExchangeRealtimeWebRTCSDP performs the provider-specific SDP signaling exchange. + // The provider owns the HTTP specifics (URL, headers, body format). + // session may be nil if the signaling format doesn't include session config. + ExchangeRealtimeWebRTCSDP(ctx *BifrostContext, key Key, model string, sdp string, session json.RawMessage) (string, *BifrostError) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*BifrostRealtimeEvent, error) ToProviderRealtimeEvent(bifrostEvent *BifrostRealtimeEvent) (json.RawMessage, error) + // ShouldStartRealtimeTurn reports whether the canonical client-side event + // should start pre-hooks. Providers without an explicit turn-start signal + // return false and rely on finalize-time fallback hooks. + ShouldStartRealtimeTurn(event *BifrostRealtimeEvent) bool + // RealtimeTurnFinalEvent returns the canonical provider event that completes + // a turn and should trigger post-hooks. + RealtimeTurnFinalEvent() RealtimeEventType + RealtimeWebRTCDataChannelLabel() string + RealtimeWebSocketSubprotocol() string + ShouldForwardRealtimeEvent(event *BifrostRealtimeEvent) bool + ShouldAccumulateRealtimeOutput(eventType RealtimeEventType) bool +} + +// RealtimeLegacyWebRTCProvider is an optional interface for providers that +// support the beta WebRTC handshake (e.g., OpenAI's /v1/realtime). +// Only checked for legacy integration routes via type assertion. 
+// Takes SDP offer + optional session JSON, same as ExchangeRealtimeWebRTCSDP +// but targets the provider's legacy/beta endpoint. +type RealtimeLegacyWebRTCProvider interface { + ExchangeLegacyRealtimeWebRTCSDP(ctx *BifrostContext, key Key, sdp string, session json.RawMessage, model string) (string, *BifrostError) +} + +// RealtimeUsageExtractor lets providers parse terminal-turn usage/output from +// their native wire payloads without coupling handlers to a specific protocol. +type RealtimeUsageExtractor interface { + ExtractRealtimeTurnUsage(terminalEventRaw []byte) *BifrostLLMUsage + ExtractRealtimeTurnOutput(terminalEventRaw []byte) *ChatMessage +} + +// RealtimeSessionProvider is an optional interface for providers that can mint +// short-lived client secrets for browser/client-side Realtime connections. +// Checked via type assertion: provider.(RealtimeSessionProvider). +type RealtimeSessionProvider interface { + CreateRealtimeClientSecret(ctx *BifrostContext, key Key, endpointType RealtimeSessionEndpointType, rawRequest json.RawMessage) (*BifrostPassthroughResponse, *BifrostError) +} + +// ParseRealtimeEvent decodes a client/provider realtime event while preserving +// unknown top-level fields in ExtraParams for provider-specific round-tripping. 
+func ParseRealtimeEvent(raw []byte) (*BifrostRealtimeEvent, error) { + type realtimeEventAlias struct { + Type RealtimeEventType `json:"type"` + EventID string `json:"event_id,omitempty"` + Session *RealtimeSession `json:"session,omitempty"` + Item *RealtimeItem `json:"item,omitempty"` + Delta *RealtimeDelta `json:"delta,omitempty"` + Audio []byte `json:"audio,omitempty"` + Error *RealtimeError `json:"error,omitempty"` + } + + var alias realtimeEventAlias + if err := Unmarshal(raw, &alias); err != nil { + return nil, err + } + + event := &BifrostRealtimeEvent{ + Type: alias.Type, + EventID: alias.EventID, + Session: alias.Session, + Item: alias.Item, + Delta: alias.Delta, + Audio: alias.Audio, + Error: alias.Error, + } + + var root map[string]json.RawMessage + if err := Unmarshal(raw, &root); err != nil { + return nil, err + } + savedSession := root["session"] + savedItem := root["item"] + savedError := root["error"] + for _, key := range []string{"type", "event_id", "session", "item", "delta", "audio", "error", "raw_data"} { + delete(root, key) + } + if len(root) > 0 { + event.ExtraParams = root + } + if event.Session != nil { + var sessionRoot map[string]json.RawMessage + if len(savedSession) > 0 && Unmarshal(savedSession, &sessionRoot) == nil { + for _, key := range []string{ + "id", "model", "modalities", "instructions", "voice", "temperature", + "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools", + } { + delete(sessionRoot, key) + } + if len(sessionRoot) > 0 { + event.Session.ExtraParams = sessionRoot + } + } + } + if event.Item != nil { + var itemRoot map[string]json.RawMessage + if len(savedItem) > 0 && Unmarshal(savedItem, &itemRoot) == nil { + for _, key := range []string{ + "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output", + } { + delete(itemRoot, key) + } + if len(itemRoot) > 0 { + event.Item.ExtraParams = itemRoot + } + } + } + if event.Error != nil { + var errorRoot 
map[string]json.RawMessage + if len(savedError) > 0 && Unmarshal(savedError, &errorRoot) == nil { + for _, key := range []string{"type", "code", "message", "param"} { + delete(errorRoot, key) + } + if len(errorRoot) > 0 { + event.Error.ExtraParams = errorRoot + } + } + } + + return event, nil } diff --git a/core/schemas/realtime_client_secrets.go b/core/schemas/realtime_client_secrets.go new file mode 100644 index 0000000000..ae97b573a1 --- /dev/null +++ b/core/schemas/realtime_client_secrets.go @@ -0,0 +1,66 @@ +package schemas + +import ( + "bytes" + "encoding/json" + "strings" +) + +// ParseRealtimeClientSecretBody parses a realtime client-secret request body +// into a mutable raw JSON map while preserving unknown fields. +func ParseRealtimeClientSecretBody(raw json.RawMessage) (map[string]json.RawMessage, *BifrostError) { + var root map[string]json.RawMessage + if err := Unmarshal(raw, &root); err != nil { + return nil, NewRealtimeClientSecretBodyError(400, "invalid_request_error", "invalid JSON body", err) + } + return root, nil +} + +// ExtractRealtimeClientSecretModel extracts the model from either session.model +// or the legacy top-level model field. 
+func ExtractRealtimeClientSecretModel(root map[string]json.RawMessage) (string, *BifrostError) { + if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 && !bytes.Equal(sessionJSON, []byte("null")) { + var session map[string]json.RawMessage + if err := Unmarshal(sessionJSON, &session); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session must be an object", err) + } + if modelJSON, ok := session["model"]; ok { + var sessionModel string + if err := Unmarshal(modelJSON, &sessionModel); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model must be a string", err) + } + if strings.TrimSpace(sessionModel) != "" { + return strings.TrimSpace(sessionModel), nil + } + } + } + + if modelJSON, ok := root["model"]; ok { + var model string + if err := Unmarshal(modelJSON, &model); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "model must be a string", err) + } + if strings.TrimSpace(model) != "" { + return strings.TrimSpace(model), nil + } + } + + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model or model is required", nil) +} + +// NewRealtimeClientSecretBodyError builds a standard invalid-request style error +// for HTTP realtime client-secret request parsing/validation. 
+func NewRealtimeClientSecretBodyError(status int, errorType, message string, err error) *BifrostError { + return &BifrostError{ + IsBifrostError: false, + StatusCode: Ptr(status), + Error: &ErrorField{ + Type: Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: BifrostErrorExtraFields{ + RequestType: RealtimeRequest, + }, + } +} diff --git a/core/schemas/realtime_client_secrets_test.go b/core/schemas/realtime_client_secrets_test.go new file mode 100644 index 0000000000..dfd8f8b1d3 --- /dev/null +++ b/core/schemas/realtime_client_secrets_test.go @@ -0,0 +1,40 @@ +package schemas + +import ( + "encoding/json" + "testing" +) + +func TestExtractRealtimeClientSecretModel(t *testing.T) { + t.Parallel() + + root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`)) + if err != nil { + t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err) + } + + model, err := ExtractRealtimeClientSecretModel(root) + if err != nil { + t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err) + } + if model != "openai/gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "openai/gpt-4o-realtime-preview") + } +} + +func TestExtractRealtimeClientSecretModelFallbackTopLevel(t *testing.T) { + t.Parallel() + + root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"model":"gpt-4o-realtime-preview"}`)) + if err != nil { + t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err) + } + + model, err := ExtractRealtimeClientSecretModel(root) + if err != nil { + t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } +} diff --git a/core/schemas/realtime_test.go b/core/schemas/realtime_test.go new file mode 100644 index 0000000000..69e9e403c8 --- /dev/null +++ b/core/schemas/realtime_test.go @@ -0,0 +1,68 @@ +package schemas + +import "testing" + +func 
TestIsRealtimeConversationItemEventType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + eventType RealtimeEventType + want bool + }{ + {name: "create", eventType: RTEventConversationItemCreate, want: true}, + {name: "added", eventType: RTEventConversationItemAdded, want: true}, + {name: "created", eventType: RTEventConversationItemCreated, want: true}, + {name: "retrieved", eventType: RTEventConversationItemRetrieved, want: true}, + {name: "done", eventType: RTEventConversationItemDone, want: true}, + {name: "response done", eventType: RTEventResponseDone, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := IsRealtimeConversationItemEventType(tt.eventType); got != tt.want { + t.Fatalf("IsRealtimeConversationItemEventType(%q) = %v, want %v", tt.eventType, got, tt.want) + } + }) + } +} + +func TestRealtimeCanonicalEventClassifiers(t *testing.T) { + t.Parallel() + + userEvent := &BifrostRealtimeEvent{ + Type: RTEventConversationItemAdded, + Item: &RealtimeItem{ + Role: "user", + Type: "message", + }, + } + if !IsRealtimeUserInputEvent(userEvent) { + t.Fatal("expected conversation.item.added user event to be classified as realtime user input") + } + if IsRealtimeToolOutputEvent(userEvent) { + t.Fatal("did not expect conversation.item.added user event to be classified as realtime tool output") + } + + toolEvent := &BifrostRealtimeEvent{ + Type: RTEventConversationItemRetrieved, + Item: &RealtimeItem{ + Type: "function_call_output", + }, + } + if !IsRealtimeToolOutputEvent(toolEvent) { + t.Fatal("expected function_call_output item to be classified as realtime tool output") + } + if IsRealtimeUserInputEvent(toolEvent) { + t.Fatal("did not expect function_call_output item to be classified as realtime user input") + } + + transcriptEvent := &BifrostRealtimeEvent{Type: RTEventInputAudioTransCompleted} + if !IsRealtimeInputTranscriptEvent(transcriptEvent) { + t.Fatal("expected input audio 
transcription completion to be classified as transcript event") + } + if IsRealtimeInputTranscriptEvent(&BifrostRealtimeEvent{Type: RTEventInputAudioTransDelta}) { + t.Fatal("did not expect input audio transcription delta to be classified as transcript event") + } +} diff --git a/core/schemas/trace.go b/core/schemas/trace.go index 9a69980d3c..d6862d4d4e 100644 --- a/core/schemas/trace.go +++ b/core/schemas/trace.go @@ -8,6 +8,7 @@ import ( // Trace represents a distributed trace that captures the full lifecycle of a request type Trace struct { + RequestID string // Request ID for the trace TraceID string // Unique identifier for this trace ParentID string // Parent trace ID from incoming W3C traceparent header RootSpan *Span // The root span of this trace @@ -15,6 +16,7 @@ type Trace struct { StartTime time.Time // When the trace started EndTime time.Time // When the trace completed Attributes map[string]any // Additional attributes for the trace + PluginLogs []PluginLogEntry // Plugin log entries accumulated during request processing mu sync.Mutex // Mutex for thread-safe span operations } @@ -37,15 +39,49 @@ func (t *Trace) GetSpan(spanID string) *Span { return nil } +// GetRequestID retrieves the request ID from the trace +func (t *Trace) GetRequestID() string { + t.mu.Lock() + defer t.mu.Unlock() + return t.RequestID +} + +// SetRequestID sets the request ID for the trace +func (t *Trace) SetRequestID(requestID string) { + t.mu.Lock() + defer t.mu.Unlock() + t.RequestID = requestID +} + // Reset clears the trace for reuse from pool func (t *Trace) Reset() { + t.mu.Lock() + defer t.mu.Unlock() + t.RequestID = "" t.TraceID = "" t.ParentID = "" t.RootSpan = nil + for i := range t.Spans { + t.Spans[i] = nil + } t.Spans = t.Spans[:0] t.StartTime = time.Time{} t.EndTime = time.Time{} t.Attributes = nil + for i := range t.PluginLogs { + t.PluginLogs[i] = PluginLogEntry{} + } + t.PluginLogs = t.PluginLogs[:0] +} + +// AppendPluginLogs appends plugin log entries to the 
trace in a thread-safe manner. +func (t *Trace) AppendPluginLogs(logs []PluginLogEntry) { + if len(logs) == 0 { + return + } + t.mu.Lock() + t.PluginLogs = append(t.PluginLogs, logs...) + t.mu.Unlock() } // Span represents a single operation within a trace diff --git a/core/schemas/tracer.go b/core/schemas/tracer.go index 06f6487f8c..23c5d4cc4c 100644 --- a/core/schemas/tracer.go +++ b/core/schemas/tracer.go @@ -14,7 +14,8 @@ type SpanHandle interface{} // This is the return type for tracer's streaming accumulation methods. type StreamAccumulatorResult struct { RequestID string // Request ID - Model string // Model used + RequestedModel string // Original model requested by the caller + ResolvedModel string // Actual model used by the provider (equals RequestedModel when no alias mapping exists) Provider ModelProvider // Provider used Status string // Status of the stream Latency int64 // Latency in milliseconds @@ -38,7 +39,8 @@ type StreamAccumulatorResult struct { type Tracer interface { // CreateTrace creates a new trace with optional parent ID and returns the trace ID. // The parentID can be extracted from W3C traceparent headers for distributed tracing. - CreateTrace(parentID string) string + // The requestID is optional and can be used to identify the request. + CreateTrace(parentID string, requestID ...string) string // EndTrace completes a trace and returns the trace data for observation/export. // After this call, the trace is removed from active tracking and returned for cleanup. @@ -68,7 +70,7 @@ type Tracer interface { // PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span. // This includes output messages, tokens, usage stats, and error information if present. 
- PopulateLLMResponseAttributes(handle SpanHandle, resp *BifrostResponse, err *BifrostError) + PopulateLLMResponseAttributes(ctx *BifrostContext, handle SpanHandle, resp *BifrostResponse, err *BifrostError) // StoreDeferredSpan stores a span handle for later completion (used for streaming requests). // The span handle is stored keyed by trace ID so it can be retrieved when the stream completes. @@ -111,6 +113,14 @@ type Tracer interface { // The ctx parameter must contain the stream end indicator for proper final chunk detection. ProcessStreamingChunk(traceID string, isFinalChunk bool, result *BifrostResponse, err *BifrostError) *StreamAccumulatorResult + // AttachPluginLogs appends plugin log entries to the trace identified by traceID. + // Thread-safe. Should be called after plugin hooks complete, before trace completion. + AttachPluginLogs(traceID string, logs []PluginLogEntry) + + // CompleteAndFlushTrace ends a trace, exports it to observability plugins, and + // releases the trace resources. Used by transports that bypass normal HTTP trace completion. + CompleteAndFlushTrace(traceID string) + // Stop releases resources associated with the tracer. // Should be called during shutdown to stop background goroutines. Stop() @@ -121,7 +131,7 @@ type Tracer interface { type NoOpTracer struct{} // CreateTrace returns an empty string (no trace created). -func (n *NoOpTracer) CreateTrace(_ string) string { return "" } +func (n *NoOpTracer) CreateTrace(_ string, _ ...string) string { return "" } // EndTrace returns nil (no trace to end). func (n *NoOpTracer) EndTrace(_ string) *Trace { return nil } @@ -144,7 +154,7 @@ func (n *NoOpTracer) AddEvent(_ SpanHandle, _ string, _ map[string]any) {} func (n *NoOpTracer) PopulateLLMRequestAttributes(_ SpanHandle, _ *BifrostRequest) {} // PopulateLLMResponseAttributes does nothing. 
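The new `AttachPluginLogs` and `CompleteAndFlushTrace` methods above imply a create → attach → complete-and-flush lifecycle for traces that bypass normal HTTP completion. A standalone toy sketch of that sequencing — the `toyTracer` type is illustrative and only the three method names mirror the `Tracer` interface; the real `CompleteAndFlushTrace` returns nothing and exports to observability plugins, whereas this toy returns the accumulated logs so the flow is observable:

```go
package main

import (
	"fmt"
	"sync"
)

// PluginLogEntry is an illustrative stand-in for the schema type of the
// same name; only Plugin and Message are modeled here.
type PluginLogEntry struct {
	Plugin  string
	Message string
}

// toyTracer keeps per-trace plugin logs in memory, guarded by a mutex as
// the interface's thread-safety contract requires.
type toyTracer struct {
	mu     sync.Mutex
	active map[string][]PluginLogEntry
	nextID int
}

func newToyTracer() *toyTracer {
	return &toyTracer{active: make(map[string][]PluginLogEntry)}
}

// CreateTrace registers a new trace and returns its ID; parentID and the
// optional requestID are accepted but ignored in this toy.
func (t *toyTracer) CreateTrace(parentID string, requestID ...string) string {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.nextID++
	id := fmt.Sprintf("trace-%d", t.nextID)
	t.active[id] = nil
	return id
}

// AttachPluginLogs appends plugin log entries to an active trace.
func (t *toyTracer) AttachPluginLogs(traceID string, logs []PluginLogEntry) {
	if len(logs) == 0 {
		return
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	t.active[traceID] = append(t.active[traceID], logs...)
}

// CompleteAndFlushTrace ends the trace and releases its resources. Unlike
// the real interface method, it returns the logs as a stand-in for export.
func (t *toyTracer) CompleteAndFlushTrace(traceID string) []PluginLogEntry {
	t.mu.Lock()
	defer t.mu.Unlock()
	logs := t.active[traceID]
	delete(t.active, traceID)
	return logs
}

func main() {
	tr := newToyTracer()
	id := tr.CreateTrace("")
	tr.AttachPluginLogs(id, []PluginLogEntry{{Plugin: "governance", Message: "allowed"}})
	fmt.Println(id, len(tr.CompleteAndFlushTrace(id))) // prints: trace-1 1
}
```

A second flush of the same ID yields nothing, matching the "releases the trace resources" contract.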
-func (n *NoOpTracer) PopulateLLMResponseAttributes(_ SpanHandle, _ *BifrostResponse, _ *BifrostError) { +func (n *NoOpTracer) PopulateLLMResponseAttributes(_ *BifrostContext, _ SpanHandle, _ *BifrostResponse, _ *BifrostError) { } // StoreDeferredSpan does nothing. @@ -176,6 +186,12 @@ func (n *NoOpTracer) ProcessStreamingChunk(_ string, _ bool, _ *BifrostResponse, return nil } +// AttachPluginLogs does nothing. +func (n *NoOpTracer) AttachPluginLogs(_ string, _ []PluginLogEntry) {} + +// CompleteAndFlushTrace does nothing. +func (n *NoOpTracer) CompleteAndFlushTrace(_ string) {} + // Stop does nothing. func (n *NoOpTracer) Stop() {} diff --git a/core/schemas/transcriptions.go b/core/schemas/transcriptions.go index 7308714ed5..1cd801be98 100644 --- a/core/schemas/transcriptions.go +++ b/core/schemas/transcriptions.go @@ -14,15 +14,37 @@ func (r *BifrostTranscriptionRequest) GetRawRequestBody() []byte { } type BifrostTranscriptionResponse struct { - Duration *float64 `json:"duration,omitempty"` // Duration in seconds - Language *string `json:"language,omitempty"` // e.g., "english" - LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"` - Segments []TranscriptionSegment `json:"segments,omitempty"` - Task *string `json:"task,omitempty"` // e.g., "transcribe" - Text string `json:"text"` - Usage *TranscriptionUsage `json:"usage,omitempty"` - Words []TranscriptionWord `json:"words,omitempty"` - ExtraFields BifrostResponseExtraFields `json:"extra_fields"` + Duration *float64 `json:"duration,omitempty"` // Duration in seconds + Language *string `json:"language,omitempty"` // e.g., "english" + LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"` + Segments []TranscriptionSegment `json:"segments,omitempty"` + Task *string `json:"task,omitempty"` // e.g., "transcribe" + Text string `json:"text"` + Usage *TranscriptionUsage `json:"usage,omitempty"` + Words []TranscriptionWord `json:"words,omitempty"` + ResponseFormat *string `json:"-"` // Set by provider for 
non-JSON formats (text, srt, vtt); used by integration response converters + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` +} + +func (r *BifrostTranscriptionResponse) BackfillParams(req *BifrostTranscriptionRequest) { + if r == nil || req == nil || req.Params == nil || req.Params.ResponseFormat == nil { + return + } + r.ResponseFormat = req.Params.ResponseFormat +} + +// IsPlainTextTranscriptionFormat returns true if the given response format +// produces a plain-text response body (not JSON). +func IsPlainTextTranscriptionFormat(format *string) bool { + if format == nil { + return false + } + switch *format { + case "text", "srt", "vtt": + return true + default: + return false + } } type TranscriptionInput struct { @@ -31,17 +53,17 @@ type TranscriptionInput struct { } type TranscriptionParameters struct { - Language *string `json:"language,omitempty"` - Prompt *string `json:"prompt,omitempty"` - ResponseFormat *string `json:"response_format,omitempty"` // Default is "json" - Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0) - TimestampGranularities []string `json:"timestamp_granularities,omitempty"` // "word" and/or "segment"; requires response_format=verbose_json - Include []string `json:"include,omitempty"` // Additional response info (e.g., logprobs) - Format *string `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini - MaxLength *int `json:"max_length,omitempty"` // Maximum length of the transcription used by HuggingFace - MinLength *int `json:"min_length,omitempty"` // Minimum length of the transcription used by HuggingFace - MaxNewTokens *int `json:"max_new_tokens,omitempty"` // Maximum new tokens to generate used by HuggingFace - MinNewTokens *int `json:"min_new_tokens,omitempty"` // Minimum new tokens to generate used by HuggingFace + Language *string `json:"language,omitempty"` + Prompt *string `json:"prompt,omitempty"` + ResponseFormat *string 
`json:"response_format,omitempty"` // Default is "json" + Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0) + TimestampGranularities []string `json:"timestamp_granularities,omitempty"` // "word" and/or "segment"; requires response_format=verbose_json + Include []string `json:"include,omitempty"` // Additional response info (e.g., logprobs) + Format *string `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini + MaxLength *int `json:"max_length,omitempty"` // Maximum length of the transcription used by HuggingFace + MinLength *int `json:"min_length,omitempty"` // Minimum length of the transcription used by HuggingFace + MaxNewTokens *int `json:"max_new_tokens,omitempty"` // Maximum new tokens to generate used by HuggingFace + MinNewTokens *int `json:"min_new_tokens,omitempty"` // Minimum new tokens to generate used by HuggingFace // Elevenlabs-specific fields AdditionalFormats []TranscriptionAdditionalFormat `json:"additional_formats,omitempty"` @@ -132,4 +154,3 @@ type BifrostTranscriptionStreamResponse struct { Usage *TranscriptionUsage `json:"usage,omitempty"` ExtraFields BifrostResponseExtraFields `json:"extra_fields"` } - diff --git a/core/utils.go b/core/utils.go index ed8f40ebf4..0e1b45fefd 100644 --- a/core/utils.go +++ b/core/utils.go @@ -11,6 +11,7 @@ import ( "math/rand" "net" "net/url" + "slices" "strings" "time" @@ -86,19 +87,19 @@ func Ptr[T any](v T) *T { } // providerRequiresKey returns true if the given provider requires an API key for authentication. -// Some providers like Ollama, SGL, and vLLM are keyless and don't require API keys. -func providerRequiresKey(providerKey schemas.ModelProvider, customConfig *schemas.CustomProviderConfig) bool { +func providerRequiresKey(customConfig *schemas.CustomProviderConfig) bool { // Keyless custom providers are not allowed for Bedrock. 
if customConfig != nil && customConfig.IsKeyLess && customConfig.BaseProviderType != schemas.Bedrock { return false } - return !IsKeylessProvider(providerKey) + return true } -// canProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty. -// Some providers like Vertex and Bedrock have their credentials in additional key configs.. +// CanProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty. +// Some providers like Vertex and Bedrock have their credentials in additional key configs. +// Ollama and SGL are keyless (API Key is optional) but use per-key server URLs. func CanProviderKeyValueBeEmpty(providerKey schemas.ModelProvider) bool { - return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure + return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure || providerKey == schemas.Ollama || providerKey == schemas.SGL } func isKeySkippingAllowed(providerKey schemas.ModelProvider) bool { @@ -131,6 +132,56 @@ func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError { return nil } +// validateKey validates the given key. +func validateKey(providerKey schemas.ModelProvider, key *schemas.Key) error { + // Validate the key for the provider + switch providerKey { + case schemas.Azure: + if key.AzureKeyConfig == nil { + return fmt.Errorf("azure_key_config is required") + } + if key.AzureKeyConfig.Endpoint.GetValue() == "" { + return fmt.Errorf("azure_key_config.endpoint is required") + } + case schemas.Bedrock: + // Key is valid if either: + // 1. BedrockKeyConfig is provided + // 2. 
Value is provided and is not empty + if key.BedrockKeyConfig == nil { + if key.Value.GetValue() == "" { + return fmt.Errorf("either value in key or bedrock_key_config is required") + } + key.BedrockKeyConfig = &schemas.BedrockKeyConfig{} + } + case schemas.Vertex: + if key.VertexKeyConfig == nil { + return fmt.Errorf("vertex_key_config is required") + } + case schemas.VLLM: + if key.VLLMKeyConfig == nil { + return fmt.Errorf("vllm_key_config is required") + } + if key.VLLMKeyConfig.URL.GetValue() == "" { + return fmt.Errorf("vllm_key_config.url is required") + } + case schemas.Ollama: + if key.OllamaKeyConfig == nil { + return fmt.Errorf("ollama_key_config is required") + } + if key.OllamaKeyConfig.URL.GetValue() == "" { + return fmt.Errorf("ollama_key_config.url is required") + } + case schemas.SGL: + if key.SGLKeyConfig == nil { + return fmt.Errorf("sgl_key_config is required") + } + if key.SGLKeyConfig.URL.GetValue() == "" { + return fmt.Errorf("sgl_key_config.url is required") + } + } + return nil +} + // IsRateLimitErrorMessage checks if an error message indicates a rate limit issue func IsRateLimitErrorMessage(errorMessage string) bool { if errorMessage == "" { @@ -175,7 +226,7 @@ func newBifrostErrorFromMsg(message string) *schemas.BifrostError { // newBifrostCtxDoneError creates a BifrostError from a cancelled/expired context. // It distinguishes DeadlineExceeded (504 RequestTimedOut) from Canceled (499 RequestCancelled). 
-func newBifrostCtxDoneError(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, requestType schemas.RequestType, stage string) *schemas.BifrostError { +func newBifrostCtxDoneError(ctx *schemas.BifrostContext, stage string) *schemas.BifrostError { var statusCode int var errorType string var message string @@ -199,11 +250,6 @@ func newBifrostCtxDoneError(ctx *schemas.BifrostContext, provider schemas.ModelP Message: message, Error: ctx.Err(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: provider, - ModelRequested: model, - }, } } @@ -230,6 +276,9 @@ func newBifrostMessageChan(message *schemas.BifrostResponse) chan *schemas.Bifro func clearCtxForFallback(ctx *schemas.BifrostContext) { ctx.ClearValue(schemas.BifrostContextKeyAPIKeyID) ctx.ClearValue(schemas.BifrostContextKeyAPIKeyName) + ctx.ClearValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys) + ctx.ClearValue(schemas.BifrostContextKeyChangeRequestType) + ctx.ClearValue(schemas.BifrostContextKeyAttemptTrail) } var supportedBaseProvidersSet = func() map[schemas.ModelProvider]struct{} { @@ -261,11 +310,6 @@ func IsStandardProvider(providerKey schemas.ModelProvider) bool { return ok } -// IsKeylessProvider reports whether providerKey is a keyless provider. -func IsKeylessProvider(providerKey schemas.ModelProvider) bool { - return providerKey == schemas.Ollama || providerKey == schemas.SGL -} - // IsStreamRequestType returns true if the given request type is a stream request. 
func IsStreamRequestType(reqType schemas.RequestType) bool { return reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.ImageGenerationStreamRequest || reqType == schemas.ImageEditStreamRequest || reqType == schemas.PassthroughStreamRequest || reqType == schemas.WebSocketResponsesRequest || reqType == schemas.RealtimeRequest @@ -336,14 +380,14 @@ func IsFinalChunk(ctx *schemas.BifrostContext) bool { return false } -// GetResponseFields extracts the request type, provider, and model from the result or error -func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, model string) { +// GetResponseFields extracts the request type, provider, original model, and resolved model from the result or error. +func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, originalModel string, resolvedModel string) { if result != nil { extraFields := result.GetExtraFields() - return extraFields.RequestType, extraFields.Provider, extraFields.ModelRequested + return extraFields.RequestType, extraFields.Provider, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed } if err != nil { - return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.ModelRequested + return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.OriginalModelRequested, err.ExtraFields.ResolvedModelUsed } return } @@ -366,7 +410,9 @@ func GetErrorMessage(err *schemas.BifrostError) string { if err == nil { return "" } - if err.StatusCode != nil { + if err.Error != nil && err.Error.Message != "" { + return err.Error.Message + } else if err.StatusCode != nil { switch *err.StatusCode { case 
401: return "unauthorized" @@ -392,8 +438,6 @@ func GetErrorMessage(err *schemas.BifrostError) string { } return fmt.Sprintf("HTTP %d error", *err.StatusCode) } - } else if err.Error != nil && err.Error.Message != "" { - return err.Error.Message } else if err.Type != nil { return *err.Type } else { @@ -544,3 +588,44 @@ func buildSessionKey(providerKey schemas.ModelProvider, sessionID string, model } return "session:" + string(providerKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator) } + +// isPromptOptionalImageEditType returns true for edit task types that do not require a text prompt. +// It normalises hyphenated variants (e.g. "erase-object") to underscore form before matching. +func isPromptOptionalImageEditType(t *string) bool { + if t == nil { + return false + } + normalized := strings.ToLower(strings.TrimSpace(*t)) + normalized = strings.ReplaceAll(normalized, "-", "_") + return slices.Contains( + []string{"background_removal", "remove_background", "remove_bg", "erase_object", "upscale_fast"}, + normalized, + ) +} + +// wrapConvertedStreamPostHookRunner wraps a PostHookRunner so that streaming +// responses produced by a type-converted request are converted back to the +// caller's original type before the post-hook runs. 
+func wrapConvertedStreamPostHookRunner(postHookRunner schemas.PostHookRunner, targetType schemas.RequestType) schemas.PostHookRunner { + return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + if result != nil { + switch targetType { + case schemas.ChatCompletionRequest: + // text→chat: convert chat stream chunk back to text completion + if result.ChatResponse != nil { + if converted := result.ChatResponse.ToBifrostTextCompletionResponse(); converted != nil { + result = &schemas.BifrostResponse{TextCompletionResponse: converted} + } + } + case schemas.ResponsesRequest: + // chat→responses: convert responses stream chunk back to chat + if result.ResponsesStreamResponse != nil { + if converted := result.ResponsesStreamResponse.ToBifrostChatResponse(); converted != nil { + result = &schemas.BifrostResponse{ChatResponse: converted} + } + } + } + } + return postHookRunner(ctx, result, bifrostErr) + } +} diff --git a/core/version b/core/version index 3e06e400e9..4cda8f19ed 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.4.22 \ No newline at end of file +1.5.2 diff --git a/docs/contributing/setting-up-repo.mdx b/docs/contributing/setting-up-repo.mdx index d4ea991fb0..a787c6ba19 100644 --- a/docs/contributing/setting-up-repo.mdx +++ b/docs/contributing/setting-up-repo.mdx @@ -91,6 +91,16 @@ This command will: The `make dev` command handles all setup automatically. You can skip the manual setup steps below if this works for you. +#### Alternative: Using Pulse + +If you prefer [Pulse](https://github.com/Pratham-Mishra04/pulse) over Air for hot reloading, use: + +```bash +make dev-pulse +``` + +This runs the same development environment but uses `pulse.yaml` for hot reloading instead of `.air.toml`. 
+ ### Manual Setup (Alternative) If you prefer to set up components manually: @@ -154,7 +164,8 @@ The Makefile provides numerous commands for development: ### Development Commands ```bash -make dev # Start complete development environment (recommended) +make dev # Start complete development environment using Air for hot reloading +make dev-pulse # Start complete development environment using Pulse for hot reloading make build # Build UI and bifrost-http binary make run # Build and run (no hot reload) make clean # Clean build artifacts @@ -373,6 +384,9 @@ which air || go install github.com/air-verse/air@latest # Check if .air.toml exists in transports/bifrost-http/ ls transports/bifrost-http/.air.toml + +# Alternatively, use Pulse instead of Air +make dev-pulse ``` ### Getting Help diff --git a/docs/deployment-guides/config-json.mdx b/docs/deployment-guides/config-json.mdx new file mode 100644 index 0000000000..413c0e5a51 --- /dev/null +++ b/docs/deployment-guides/config-json.mdx @@ -0,0 +1,313 @@ +--- +title: "Quick Start" +description: "Configure Bifrost using a config.json file — GitOps-friendly, no-UI deployments, and multinode OSS setups" +icon: "file-code" +--- + + +**Full schema reference:** [`https://www.getbifrost.ai/schema`](https://www.getbifrost.ai/schema) + + +`config.json` lets you configure every aspect of Bifrost through a single declarative file. It is the right choice for GitOps workflows, CI/CD pipelines, headless deployments, and multinode OSS setups where a central configuration file is shared across all replicas. + +--- + +## Two Configuration Modes + +Bifrost supports **two mutually exclusive modes**. You cannot run both at the same time. 
+ +| Mode | When | Behaviour | +|------|------|-----------| +| **Web UI / database** | No `config.json`, or `config.json` with `config_store` enabled | Full UI available, configuration stored in SQLite or PostgreSQL | +| **File-based (`config.json`)** | `config.json` present, `config_store` disabled | UI disabled, all config loaded from file at startup, restart required for changes | + + +See [Setting Up](/quickstart/gateway/setting-up#two-configuration-modes) for a full explanation of both modes and how `config_store` bootstrapping works. + + +--- + +## Minimal Working Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + "client": { + "drop_excess_requests": false, + "enable_logging": true + }, + "providers": { + "openai": { + "keys": [ + { + "name": "openai-primary", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + } + }, + "config_store": { + "enabled": false + } +} +``` + +Save this as `config.json` in your app directory and start Bifrost: + +```bash +# NPX +npx -y @maximhq/bifrost -app-dir ./data + +# Docker +docker run -p 8080:8080 \ + -v $(pwd)/data:/app/data \ + -e OPENAI_API_KEY=sk-... \ + -e BIFROST_ENCRYPTION_KEY=your-32-byte-key \ + maximhq/bifrost +``` + +Make your first call: + +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +--- + +## Environment Variable References + +Never put secrets directly in `config.json`. 
Use the `env.` prefix to reference any environment variable: + +```json +{ + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + "providers": { + "openai": { + "keys": [ + { + "name": "primary", + "value": "env.OPENAI_API_KEY", + "weight": 1.0 + } + ] + } + } +} +``` + +Set the actual values through your deployment platform — shell environment, Docker `-e`, Kubernetes Secrets mounted as env vars, or a `.env` file. + +--- + +## Schema Validation + +Add `$schema` to every `config.json` for IDE autocomplete and inline validation: + +```json +{ + "$schema": "https://www.getbifrost.ai/schema" +} +``` + +Editors (VS Code, JetBrains, Neovim with LSP) will show completions and flag invalid fields as you type. + +--- + +## Production Example + +A production-ready file with PostgreSQL storage, multi-provider setup, governance, and common plugins: + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "client": { + "initial_pool_size": 500, + "drop_excess_requests": true, + "enable_logging": true, + "log_retention_days": 90, + "enforce_auth_on_inference": true, + "allow_direct_keys": false, + "allowed_origins": ["https://app.yourcompany.com"] + }, + + "providers": { + "openai": { + "keys": [ + { + "name": "openai-primary", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 120, + "max_retries": 3 + } + }, + "anthropic": { + "keys": [ + { + "name": "anthropic-primary", + "value": "env.ANTHROPIC_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + } + }, + + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require" + } + }, + + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + 
"password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require" + } + } +} +``` + +--- + +## Example Configs + +Ready-to-use reference configurations from the [examples/configs](https://github.com/maximhq/bifrost/tree/main/examples/configs) directory on GitHub: + + + + + +| Example | Description | +|---------|-------------| +| [noconfigstorenologstore](https://github.com/maximhq/bifrost/blob/main/examples/configs/noconfigstorenologstore/config.json) | Bare-minimum file-only mode — no database, no UI, providers loaded from file | +| [partial](https://github.com/maximhq/bifrost/blob/main/examples/configs/partial/config.json) | SQLite config store with a minimal provider setup | +| [v1compat](https://github.com/maximhq/bifrost/blob/main/examples/configs/v1compat/config.json) | `"version": 1` for v1.4.x array semantics (empty = allow all) | + + + + + +| Example | Description | +|---------|-------------| +| [withconfigstore](https://github.com/maximhq/bifrost/blob/main/examples/configs/withconfigstore/config.json) | SQLite config store (Web UI enabled) | +| [withconfigstorelogsstorepostgres](https://github.com/maximhq/bifrost/blob/main/examples/configs/withconfigstorelogsstorepostgres/config.json) | PostgreSQL for both config store and logs store | +| [withlogstore](https://github.com/maximhq/bifrost/blob/main/examples/configs/withlogstore/config.json) | SQLite logs store | +| [withobjectstorages3](https://github.com/maximhq/bifrost/blob/main/examples/configs/withobjectstorages3/config.json) | S3 object storage offload for logs | +| [withobjectstoragegcs](https://github.com/maximhq/bifrost/blob/main/examples/configs/withobjectstoragegcs/config.json) | GCS object storage offload for logs | +| [withvectorstoreweaviate](https://github.com/maximhq/bifrost/blob/main/examples/configs/withvectorstoreweaviate/config.json) | Weaviate vector store (with 
[docker-compose](https://github.com/maximhq/bifrost/blob/main/examples/configs/withvectorstoreweaviate/docker-compose.yml)) | + + + + + +| Example | Description | +|---------|-------------| +| [withsemanticcache](https://github.com/maximhq/bifrost/blob/main/examples/configs/withsemanticcache/config.json) | Semantic cache backed by Weaviate | +| [withsemanticcachevalkey](https://github.com/maximhq/bifrost/blob/main/examples/configs/withsemanticcachevalkey/config.json) | Semantic cache backed by Valkey / Redis | + + + + + +| Example | Description | +|---------|-------------| +| [withauth](https://github.com/maximhq/bifrost/blob/main/examples/configs/withauth/config.json) | Admin username/password auth (`governance.auth_config`) | +| [withvirtualkeys](https://github.com/maximhq/bifrost/blob/main/examples/configs/withvirtualkeys/config.json) | Virtual keys with provider/model allowlists | +| [withteamscustomers](https://github.com/maximhq/bifrost/blob/main/examples/configs/withteamscustomers/config.json) | Teams and customers with budgets and rate limits | +| [withroutingrules](https://github.com/maximhq/bifrost/blob/main/examples/configs/withroutingrules/config.json) | CEL-based routing rules for dynamic provider/model selection | +| [withpricingoverridesnostore](https://github.com/maximhq/bifrost/blob/main/examples/configs/withpricingoverridesnostore/config.json) | Pricing overrides in file-only mode | +| [withpricingoverridessqlite](https://github.com/maximhq/bifrost/blob/main/examples/configs/withpricingoverridessqlite/config.json) | Pricing overrides with SQLite config store | + + + + + +| Example | Description | +|---------|-------------| +| [withobservability](https://github.com/maximhq/bifrost/blob/main/examples/configs/withobservability/config.json) | Prometheus metrics (telemetry always active, custom labels via `client.prometheus_labels`) | +| [withprompushgateway](https://github.com/maximhq/bifrost/blob/main/examples/configs/withprompushgateway/config.json) 
| Prometheus Push Gateway for multi-instance deployments | +| [withotel](https://github.com/maximhq/bifrost/blob/main/examples/configs/withotel/config.json) | OpenTelemetry traces and metrics | + + + + + +| Example | Description | +|---------|-------------| +| [withdynamicplugin](https://github.com/maximhq/bifrost/blob/main/examples/configs/withdynamicplugin/config.json) | Loading a custom `.so` plugin at startup | +| [withcompat](https://github.com/maximhq/bifrost/blob/main/examples/configs/withcompat/config.json) | SDK compatibility shims (`should_drop_params`, `convert_text_to_chat`) | +| [withframework](https://github.com/maximhq/bifrost/blob/main/examples/configs/withframework/config.json) | Custom model pricing catalog URL and sync interval | +| [withlargepayload](https://github.com/maximhq/bifrost/blob/main/examples/configs/withlargepayload/config.json) | Large payload optimization (streaming without full materialisation) | +| [withwebsocket](https://github.com/maximhq/bifrost/blob/main/examples/configs/withwebsocket/config.json) | WebSocket / Realtime API connection pool tuning | +| [withpostgresmcpclientsinconfig](https://github.com/maximhq/bifrost/blob/main/examples/configs/withpostgresmcpclientsinconfig/config.json) | MCP client definitions seeded from config.json with PostgreSQL store | +| [encryptionmigration](https://github.com/maximhq/bifrost/blob/main/examples/configs/encryptionmigration/config.json) | Migrating to a new encryption key | + + + + + +--- + +## Configuration Guides + + + + Every top-level key, its type, default, and where it is documented + + + Pool size, logging, CORS, header filtering, compat shims, MCP settings + + + OpenAI, Anthropic, Azure, Bedrock, Vertex, Groq, self-hosted + + + config_store, logs_store, vector_store — SQLite, PostgreSQL, object storage + + + Semantic cache, OTel, Maxim, Datadog, custom plugins + + + Virtual keys, budgets, rate limits, routing rules, admin auth + + + Content moderation providers and CEL-based 
rules (enterprise) + + + +--- + +## Next Steps + +1. Configure [provider keys](/providers/supported-providers/overview) +2. Enable [plugins](/plugins/getting-started) +3. Set up [observability](/features/observability/default) +4. Configure [governance](/features/governance/virtual-keys) +5. Deploy [multiple nodes](/deployment-guides/how-to/multinode) with a shared `config.json` diff --git a/docs/deployment-guides/config-json/client.mdx b/docs/deployment-guides/config-json/client.mdx new file mode 100644 index 0000000000..1a974df77b --- /dev/null +++ b/docs/deployment-guides/config-json/client.mdx @@ -0,0 +1,276 @@ +--- +title: "Client Configuration" +description: "Configure the Bifrost client in config.json — connection pool, logging, CORS, header filtering, compat shims, and MCP settings" +icon: "gear" +--- + +The `client` block controls how Bifrost manages its internal worker pool, request logging, authentication enforcement, header policies, SDK compatibility shims, and MCP agent behaviour. + +--- + +## Connection Pool + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `initial_pool_size` | integer | `300` | Pre-allocated worker goroutines per provider queue | +| `drop_excess_requests` | boolean | `false` | Drop requests when queue is full instead of waiting (returns HTTP 429) | + +A larger pool reduces latency spikes under burst load at the cost of higher baseline memory. `500–1000` is a common starting point for production workloads with multiple providers. 
+ +```json +{ + "client": { + "initial_pool_size": 1000, + "drop_excess_requests": true + } +} +``` + +--- + +## Request & Response Logging + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `enable_logging` | boolean | — | Log all LLM requests and responses | +| `disable_content_logging` | boolean | `false` | Strip message content from logs (keeps metadata only) | +| `log_retention_days` | integer | `365` | Days to retain log entries in the store | +| `logging_headers` | array of strings | `[]` | HTTP request headers to capture in log metadata | + +Set `disable_content_logging: true` for HIPAA / PCI compliance workloads where message content must not be persisted. + +```json +{ + "client": { + "enable_logging": true, + "disable_content_logging": true, + "log_retention_days": 90, + "logging_headers": ["x-request-id", "x-user-id"] + } +} +``` + +--- + +## Security & CORS + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `allowed_origins` | array | `["*"]` | CORS allowed origins (use URIs or `"*"`) | +| `allow_direct_keys` | boolean | `false` | Allow callers to pass provider keys directly in requests | +| `enforce_auth_on_inference` | boolean | `false` | Require auth (virtual key, API key, or user token) on `/v1/*` inference routes | +| `max_request_body_size_mb` | integer | `100` | Maximum allowed request body size in MB | +| `whitelisted_routes` | array of strings | `[]` | Routes that bypass auth middleware | +| `allowed_headers` | array of strings | `[]` | Additional headers permitted for CORS and WebSocket | + +```json +{ + "client": { + "allowed_origins": [ + "https://app.yourcompany.com", + "https://admin.yourcompany.com" + ], + "allow_direct_keys": false, + "enforce_auth_on_inference": true, + "max_request_body_size_mb": 50, + "whitelisted_routes": ["/health", "/metrics"] + } +} +``` + +--- + +## Header Filtering + +Controls which `x-bf-eh-*` extra headers are forwarded to upstream 
LLM providers. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `header_filter_config.allowlist` | array of strings | `[]` | Only these headers are forwarded (whitelist mode) | +| `header_filter_config.denylist` | array of strings | `[]` | These headers are always blocked | +| `required_headers` | array of strings | `[]` | Headers that must be present on every request (rejected with 400 if missing) | + +When both `allowlist` and `denylist` are empty, all `x-bf-eh-*` headers pass through. Specifying an `allowlist` enables strict whitelist mode — only listed headers are forwarded. + +```json +{ + "client": { + "header_filter_config": { + "allowlist": [ + "x-bf-eh-anthropic-version", + "x-bf-eh-openai-beta" + ], + "denylist": [] + }, + "required_headers": ["x-request-id"] + } +} +``` + +--- + +## Compat Shims + +Compatibility flags that let Bifrost silently adapt request/response shapes for SDK integrations. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `compat.convert_text_to_chat` | boolean | `false` | Wrap legacy `/v1/completions` text requests as chat messages | +| `compat.convert_chat_to_responses` | boolean | `false` | Translate chat completions to Responses API format | +| `compat.should_drop_params` | boolean | `false` | Silently drop unsupported parameters instead of erroring | +| `compat.should_convert_params` | boolean | `false` | Auto-convert parameter values across provider schemas | + +```json +{ + "client": { + "compat": { + "should_drop_params": true, + "convert_text_to_chat": true + } + } +} +``` + +--- + +## MCP Agent Settings + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `mcp_agent_depth` | integer | `10` | Maximum tool-call recursion depth for MCP agent mode | +| `mcp_tool_execution_timeout` | integer | `30` | Timeout per MCP tool execution in seconds | +| `mcp_code_mode_binding_level` | string | — | Code mode binding level: 
`"server"` or `"tool"` | +| `mcp_tool_sync_interval` | integer | `10` | Global tool sync interval in minutes (`0` = disabled) | +| `mcp_disable_auto_tool_inject` | boolean | `false` | When `true`, MCP tools are not automatically injected into requests | + +```json +{ + "client": { + "mcp_agent_depth": 15, + "mcp_tool_execution_timeout": 60, + "mcp_tool_sync_interval": 10 + } +} +``` + +--- + +## Async Jobs + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `async_job_result_ttl` | integer | `3600` | TTL (seconds) for async job results | +| `disable_db_pings_in_health` | boolean | `false` | Exclude database connectivity from `/health` endpoint checks | + +--- + +## Prometheus Labels + +Add custom labels to every Prometheus metric emitted by Bifrost: + +```json +{ + "client": { + "prometheus_labels": ["environment=production", "region=us-east-1"] + } +} +``` + +--- + +## Authentication + +`governance.auth_config` protects the Bifrost dashboard and management API with username/password auth. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `is_enabled` | boolean | `false` | Enable username/password auth | +| `admin_username` | string | — | Admin username | +| `admin_password` | string | — | Admin password (use `env.` reference) | +| `disable_auth_on_inference` | boolean | `false` | Skip auth check on `/v1/*` inference routes | + +```json +{ + "governance": { + "auth_config": { + "is_enabled": true, + "admin_username": "env.BIFROST_ADMIN_USERNAME", + "admin_password": "env.BIFROST_ADMIN_PASSWORD", + "disable_auth_on_inference": false + } + } +} +``` + + +A top-level `auth_config` is also accepted for backwards compatibility, but `governance.auth_config` is the preferred location. 
+ + +--- + +## Encryption Key + +```json +{ + "encryption_key": "env.BIFROST_ENCRYPTION_KEY" +} +``` + +| Notes | +|-------| +| Accepts any string; Bifrost derives a 32-byte AES-256 key using Argon2id | +| Can also be set via the `BIFROST_ENCRYPTION_KEY` environment variable | +| Once set and the database is populated, the key cannot be changed without clearing the database | +| Omitting the key stores data in plain text — not recommended for production | + +--- + +## Full Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "governance": { + "auth_config": { + "is_enabled": true, + "admin_username": "env.BIFROST_ADMIN_USERNAME", + "admin_password": "env.BIFROST_ADMIN_PASSWORD", + "disable_auth_on_inference": false + } + }, + + "client": { + "initial_pool_size": 1000, + "drop_excess_requests": true, + + "enable_logging": true, + "disable_content_logging": false, + "log_retention_days": 90, + "logging_headers": ["x-request-id", "x-user-id"], + + "allowed_origins": ["https://app.yourcompany.com"], + "allow_direct_keys": false, + "enforce_auth_on_inference": true, + "max_request_body_size_mb": 100, + + "header_filter_config": { + "allowlist": [], + "denylist": [] + }, + "required_headers": [], + + "compat": { + "should_drop_params": false + }, + + "prometheus_labels": ["environment=production"], + + "mcp_agent_depth": 10, + "mcp_tool_execution_timeout": 30, + + "async_job_result_ttl": 3600 + } +} +``` diff --git a/docs/deployment-guides/config-json/governance.mdx b/docs/deployment-guides/config-json/governance.mdx new file mode 100644 index 0000000000..16ed48115e --- /dev/null +++ b/docs/deployment-guides/config-json/governance.mdx @@ -0,0 +1,333 @@ +--- +title: "Governance" +description: "Seed virtual keys, budgets, rate limits, routing rules, and admin auth in config.json" +icon: "shield-check" +--- + +The `governance` block lets you seed all governance resources directly in `config.json`. 
On startup, Bifrost loads these into the configuration store. This is the recommended approach for GitOps workflows where governance state is managed as code. + + +**Governance enforcement is always active** in OSS — you do not need a plugin entry to enable it. To require a virtual key on every inference request, set `client.enforce_auth_on_inference: true`. This global setting applies unless a more specific inference-auth flag, such as `governance.auth_config.disable_auth_on_inference`, overrides it. + + +--- + +## Admin Authentication + +Protect the Bifrost dashboard and management API with username/password auth: + +```json +{ + "governance": { + "auth_config": { + "is_enabled": true, + "admin_username": "env.BIFROST_ADMIN_USERNAME", + "admin_password": "env.BIFROST_ADMIN_PASSWORD", + "disable_auth_on_inference": false + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `is_enabled` | `false` | Enable admin username/password auth | +| `admin_username` | — | Admin username (supports `env.` prefix) | +| `admin_password` | — | Admin password (supports `env.` prefix) | +| `disable_auth_on_inference` | `false` | Skip auth check on `/v1/*` inference routes | + +--- + +## Virtual Keys + +Virtual keys are issued to clients and act as scoped API tokens. Each key specifies which providers, models, and API keys the bearer is allowed to use.
+ +```json +{ + "governance": { + "virtual_keys": [ + { + "id": "vk-team-platform", + "name": "platform-team", + "value": "env.VK_PLATFORM_TEAM", + "is_active": true, + "provider_configs": [ + { + "provider": "openai", + "allowed_models": ["gpt-4o", "gpt-4o-mini"], + "key_ids": ["*"], + "weight": 1 + }, + { + "provider": "anthropic", + "allowed_models": ["*"], + "key_ids": ["*"], + "weight": 1 + } + ] + } + ] + } +} +``` + +### Virtual Key Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique virtual key ID (referenced by budgets / rate limits) | +| `name` | Yes | Human-readable name | +| `value` | No | The key token sent by clients (use `env.` prefix). Auto-generated if omitted | +| `is_active` | No | Default `true`. Set `false` to disable without deleting | +| `team_id` | No | Associate with a team (mutually exclusive with `customer_id`) | +| `customer_id` | No | Associate with a customer | +| `rate_limit_id` | No | Attach a rate limit | +| `calendar_aligned` | No | Snap budget resets to day/week/month/year boundaries | +| `provider_configs` | No | Allowed provider/model/key combinations (empty = deny all) | + +### Provider Config Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `provider` | Yes | Provider name (e.g. `"openai"`) | +| `allowed_models` | No | Model allow-list. `["*"]` = all models; `[]` = deny all | +| `key_ids` | No | Provider key names allowed for this VK. `["*"]` = all keys; `[]` = deny all. 
Use key `name` values (not UUIDs) in `config.json` | +| `weight` | No | Load-balancing weight when multiple provider configs are present | +| `rate_limit_id` | No | Attach a per-provider-config rate limit | + +--- + +## Budgets + +Budgets cap cumulative spend (in USD) for a virtual key or provider config over a rolling window: + +```json +{ + "governance": { + "budgets": [ + { + "id": "budget-platform-monthly", + "max_limit": 500.00, + "reset_duration": "1M", + "virtual_key_id": "vk-team-platform" + } + ] + } +} +``` + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique budget ID | +| `max_limit` | Yes | Maximum spend in USD | +| `reset_duration` | Yes | Window length: `"30s"`, `"5m"`, `"1h"`, `"1d"`, `"1w"`, `"1M"`, `"1Y"` | +| `virtual_key_id` | No | Attach to a virtual key (mutually exclusive with `provider_config_id`) | +| `provider_config_id` | No | Attach to a provider config ID | + +--- + +## Rate Limits + +Rate limits cap requests or tokens over a rolling window: + +```json +{ + "governance": { + "rate_limits": [ + { + "id": "rl-platform-hourly", + "request_max_limit": 1000, + "request_reset_duration": "1h", + "token_max_limit": 1000000, + "token_reset_duration": "1h" + } + ] + } +} +``` + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique rate limit ID | +| `request_max_limit` | No | Maximum requests in window | +| `request_reset_duration` | No | Window for request counter | +| `token_max_limit` | No | Maximum tokens (input + output) in window | +| `token_reset_duration` | No | Window for token counter | + +Attach a rate limit to a virtual key via `virtual_keys[].rate_limit_id`, or to a provider config via `virtual_keys[].provider_configs[].rate_limit_id`. + +--- + +## Routing Rules + +Routing rules dynamically select the provider and model for each request based on a [CEL](https://cel.dev) expression. 
They are evaluated in priority order before the request is dispatched. + +```json +{ + "governance": { + "routing_rules": [ + { + "id": "route-gpt4-to-azure", + "name": "Redirect GPT-4o to Azure", + "cel_expression": "request.model == 'gpt-4o'", + "targets": [ + { "provider": "azure", "model": "gpt-4o", "weight": 1.0 } + ] + }, + { + "id": "route-cost-split", + "name": "Split traffic 70/30 between providers", + "cel_expression": "true", + "targets": [ + { "provider": "openai", "weight": 0.7 }, + { "provider": "anthropic", "weight": 0.3 } + ] + } + ] + } +} +``` + +### Rule Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique rule ID | +| `name` | Yes | Human-readable name | +| `cel_expression` | No | CEL expression. `"true"` matches every request | +| `targets` | Yes | Weighted target list (weights must sum to `1.0`) | +| `enabled` | No | Default `true` | +| `priority` | No | Evaluation order within scope — lower numbers run first | +| `scope` | No | `"global"` (default), `"team"`, `"customer"`, `"virtual_key"` | +| `scope_id` | Conditional | Required when `scope` is not `"global"` | +| `chain_rule` | No | If `true`, re-evaluates the chain after this rule matches | +| `fallbacks` | No | Ordered fallback provider list if primary target fails | + +### Target Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `weight` | Yes | Fraction of traffic (all weights in a rule must sum to `1.0`) | +| `provider` | No | Target provider. Omit to keep the incoming request's provider | +| `model` | No | Target model. 
Omit to keep the incoming request's model | +| `key_id` | No | Pin a specific API key by name | + +--- + +## Customers & Teams + +Define organizational entities and attach budgets or rate limits to them: + +```json +{ + "governance": { + "customers": [ + { + "id": "customer-acme", + "name": "Acme Corp", + "budget_id": "budget-acme-monthly", + "rate_limit_id": "rl-acme-hourly" + } + ], + "teams": [ + { + "id": "team-ml", + "name": "ML Team", + "customer_id": "customer-acme", + "budget_id": "budget-team-ml" + } + ] + } +} +``` + +--- + +## Full Governance Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "client": { + "enforce_auth_on_inference": true + }, + + "governance": { + "auth_config": { + "is_enabled": true, + "admin_username": "env.BIFROST_ADMIN_USERNAME", + "admin_password": "env.BIFROST_ADMIN_PASSWORD" + }, + + "budgets": [ + { + "id": "budget-platform", + "max_limit": 1000.00, + "reset_duration": "1M", + "virtual_key_id": "vk-platform" + } + ], + + "rate_limits": [ + { + "id": "rl-platform", + "request_max_limit": 5000, + "request_reset_duration": "1h", + "token_max_limit": 5000000, + "token_reset_duration": "1h" + } + ], + + "virtual_keys": [ + { + "id": "vk-platform", + "name": "platform-key", + "value": "env.VK_PLATFORM", + "is_active": true, + "rate_limit_id": "rl-platform", + "provider_configs": [ + { + "provider": "openai", + "allowed_models": ["*"], + "key_ids": ["*"], + "weight": 1 + } + ] + } + ], + + "routing_rules": [ + { + "id": "fallback-to-anthropic", + "name": "Fallback on error", + "cel_expression": "true", + "targets": [{ "provider": "openai", "weight": 1.0 }], + "fallbacks": ["anthropic"] + } + ] + }, + + "providers": { + "openai": { + "keys": [{ "name": "openai-primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "anthropic": { + "keys": [{ "name": "anthropic-primary", "value": "env.ANTHROPIC_API_KEY", "models": ["*"], "weight": 1.0 
}] + } + }, + + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost" + } + } +} +``` diff --git a/docs/deployment-guides/config-json/guardrails.mdx b/docs/deployment-guides/config-json/guardrails.mdx new file mode 100644 index 0000000000..f6258ca872 --- /dev/null +++ b/docs/deployment-guides/config-json/guardrails.mdx @@ -0,0 +1,291 @@ +--- +title: "Guardrails" +description: "Configure content moderation and policy enforcement in config.json using guardrails_config" +icon: "shield-halved" +--- + + +Guardrails are an **enterprise-only** feature and require the enterprise Bifrost image. + + +Guardrails are configured under `guardrails_config` in `config.json`. The configuration has two parts: + +- **`guardrail_providers`** — the backend that performs the check. Rules link to providers by `id`. +- **`guardrail_rules`** — CEL expressions that control when and where providers are invoked. + +--- + +## Providers + + + + +Runs entirely in-process with no external dependency. Patterns use RE2 syntax. Supports optional per-pattern flags: `i` (case-insensitive), `m` (multiline), `s` (dot-all). 
+ +```json +{ + "guardrails_config": { + "guardrail_providers": [ + { + "id": 1, + "provider_name": "regex", + "policy_name": "block-secrets", + "enabled": true, + "timeout": 5, + "config": { + "patterns": [ + { "pattern": "sk-[A-Za-z0-9]{20,}", "description": "OpenAI API key" }, + { "pattern": "AKIA[0-9A-Z]{16}", "description": "AWS access key" }, + { "pattern": "gh[ps]_[A-Za-z0-9]{36}", "description": "GitHub token", "flags": "i" } + ], + "mode": "block" + } + } + ] + } +} +``` + + + + +```json +{ + "guardrails_config": { + "guardrail_providers": [ + { + "id": 2, + "provider_name": "bedrock", + "policy_name": "content-filter", + "enabled": true, + "timeout": 15, + "config": { + "guardrail_arn": "arn:aws:bedrock:us-east-1::guardrail/abc123", + "guardrail_version": "DRAFT", + "region": "us-east-1", + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY" + } + } + ] + } +} +``` + + + + +```json +{ + "guardrails_config": { + "guardrail_providers": [ + { + "id": 3, + "provider_name": "azure", + "policy_name": "azure-content-safety", + "enabled": true, + "timeout": 10, + "config": { + "endpoint": "https://your-resource.cognitiveservices.azure.com", + "api_key": "env.AZURE_CONTENT_SAFETY_KEY", + "analyze_enabled": true, + "analyze_severity_threshold": "medium", + "jailbreak_shield_enabled": true, + "indirect_attack_shield_enabled": true, + "copyright_enabled": false, + "text_blocklist_enabled": false, + "blocklist_names": [] + } + } + ] + } +} +``` + +`analyze_severity_threshold` accepts `"low"`, `"medium"`, or `"high"`. 
+ + + + +```json +{ + "guardrails_config": { + "guardrail_providers": [ + { + "id": 4, + "provider_name": "grayswan", + "policy_name": "grayswan-jailbreak", + "enabled": true, + "timeout": 15, + "config": { + "api_key": "env.GRAYSWAN_API_KEY", + "violation_threshold": 0.7, + "reasoning_mode": "standard", + "policy_id": "", + "policy_ids": [], + "rules": {} + } + } + ] + } +} +``` + + + + +### Provider Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique integer ID — referenced by rules via `provider_config_ids` | +| `provider_name` | Yes | Backend: `"regex"`, `"bedrock"`, `"azure"`, `"grayswan"` | +| `policy_name` | Yes | Human-readable policy label | +| `enabled` | Yes | `true` to activate | +| `timeout` | No | Execution timeout in seconds | +| `config` | No | Provider-specific configuration object | + +--- + +## Rules + +Rules are CEL expressions that fire when their condition matches. Available CEL variables: + +| Variable | Type | Description | +|----------|------|-------------| +| `model` | `string` | Model name from the request | +| `provider` | `string` | Provider name (e.g. 
`"openai"`) | +| `headers` | `map` | HTTP request headers | +| `params` | `map` | Query parameters | +| `customer` | `string` | Customer ID | +| `team` | `string` | Team ID | +| `user` | `string` | User ID | + +```json +{ + "guardrails_config": { + "guardrail_rules": [ + { + "id": 101, + "name": "block-secrets-input", + "description": "Block prompts containing credentials", + "enabled": true, + "cel_expression": "true", + "apply_to": "input", + "sampling_rate": 100, + "timeout": 10, + "provider_config_ids": [1] + }, + { + "id": 102, + "name": "content-safety-gpt4o-output", + "enabled": true, + "cel_expression": "model == 'gpt-4o'", + "apply_to": "output", + "sampling_rate": 100, + "timeout": 15, + "provider_config_ids": [3] + }, + { + "id": 103, + "name": "grayswan-openai-partial", + "enabled": true, + "cel_expression": "provider == 'openai'", + "apply_to": "input", + "sampling_rate": 50, + "timeout": 20, + "provider_config_ids": [4] + } + ] + } +} +``` + +### Rule Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique integer ID | +| `name` | Yes | Human-readable name | +| `description` | No | Optional description | +| `enabled` | Yes | `true` to activate | +| `cel_expression` | Yes | CEL boolean expression. `"true"` matches every request | +| `apply_to` | Yes | `"input"`, `"output"`, or `"both"` | +| `sampling_rate` | No | `0`–`100`; percentage of requests to evaluate (default: `100`) | +| `timeout` | No | Rule timeout in seconds | +| `provider_config_ids` | No | `id` values of providers to invoke when this rule matches. 
Multiple providers run in parallel | + +--- + +## Full Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "providers": { + "openai": { + "keys": [{ "name": "primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 1.0 }] + } + }, + + "guardrails_config": { + "guardrail_providers": [ + { + "id": 1, + "provider_name": "regex", + "policy_name": "block-secrets", + "enabled": true, + "timeout": 5, + "config": { + "patterns": [ + { "pattern": "sk-[A-Za-z0-9]{20,}", "description": "OpenAI API key" }, + { "pattern": "AKIA[0-9A-Z]{16}", "description": "AWS access key" } + ], + "mode": "block" + } + }, + { + "id": 2, + "provider_name": "azure", + "policy_name": "content-safety", + "enabled": true, + "timeout": 10, + "config": { + "endpoint": "https://your-resource.cognitiveservices.azure.com", + "api_key": "env.AZURE_CONTENT_SAFETY_KEY", + "analyze_enabled": true, + "analyze_severity_threshold": "medium", + "jailbreak_shield_enabled": true, + "indirect_attack_shield_enabled": false + } + } + ], + "guardrail_rules": [ + { + "id": 101, + "name": "block-secrets-input", + "description": "Block prompts leaking credentials", + "enabled": true, + "cel_expression": "true", + "apply_to": "input", + "sampling_rate": 100, + "timeout": 10, + "provider_config_ids": [1] + }, + { + "id": 102, + "name": "content-safety-both", + "description": "Azure content safety on all traffic", + "enabled": true, + "cel_expression": "true", + "apply_to": "both", + "sampling_rate": 100, + "timeout": 15, + "provider_config_ids": [2] + } + ] + } +} +``` diff --git a/docs/deployment-guides/config-json/plugins.mdx b/docs/deployment-guides/config-json/plugins.mdx new file mode 100644 index 0000000000..847f290e02 --- /dev/null +++ b/docs/deployment-guides/config-json/plugins.mdx @@ -0,0 +1,318 @@ +--- +title: "Plugins" +description: "Configure Bifrost plugins in config.json — semantic cache, OpenTelemetry, Maxim, Datadog, and 
custom plugins" +icon: "puzzle-piece" +--- + + +**The `plugins` array only controls explicitly opt-in plugins**: `semantic_cache`, `otel`, `maxim`, `datadog` (enterprise), and custom plugins. + +**Telemetry, logging, and governance are auto-loaded built-ins** — they are always active and configured via the `client` block and dedicated top-level keys, not the `plugins` array. + + +--- + +## Auto-Loaded Built-ins + +These plugins start automatically. You do **not** add them to the `plugins` array. + +| Plugin | Always active? | How to configure | +|--------|---------------|-----------------| +| **Telemetry** (Prometheus `/metrics`) | Yes, always | `client.prometheus_labels` for custom labels; push gateway via `plugins` entry once DB-backed mode is running | +| **Logging** | When `client.enable_logging: true` and `logs_store` is configured | `client.enable_logging`, `client.disable_content_logging`, `client.logging_headers` | +| **Governance** | Yes, always (OSS) | `client.enforce_auth_on_inference` for VK enforcement; `governance.*` for virtual keys / budgets / routing rules | + +See [Client Configuration](/deployment-guides/config-json/client) and [Governance](/deployment-guides/config-json/governance) for full details. + +--- + +## Plugin Array Structure + +Every entry in the `plugins` array supports these common fields: + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `name` | string | Yes | Plugin name | +| `enabled` | boolean | Yes | Enable or disable this plugin | +| `config` | object | Varies | Plugin-specific configuration | +| `path` | string | No | Path to a custom plugin binary or WASM file | +| `version` | integer | No | 🛑 **DB-Backed Only.** Plugin metadata persisted on `TablePlugin` rather than `PluginConfig`. Ignored in `config.json`. Used in UI/DB workflows to force refresh/reload. 
| +| `placement` | string | No | 🛑 **DB-Backed Only.** Execution metadata (`"pre_builtin"`, `"builtin"`, `"post_builtin"`) persisted on `TablePlugin`. Ignored in `config.json`. Relevant for dynamic plugin ordering in UI/DB mode. | +| `order` | integer | No | 🛑 **DB-Backed Only.** Execution metadata persisted on `TablePlugin`. Ignored in `config.json`. Within a placement group, lower values run earlier. | + + +`name`, `enabled`, `path`, and `config` are the core plugin config fields parsed from `config.json`. `version`, `placement`, and `order` are **not valid `config.json` keys**; they are DB-backed metadata persisted on `TablePlugin` and are only applicable when managing plugins dynamically via the UI or Database. + + +--- + + + + + +### Semantic Cache + +Caches LLM responses by semantic similarity. Returns a cached response when an incoming request is semantically close enough to a previous one. + +Requires a [vector store](/deployment-guides/config-json/storage#vector_store) to be configured. + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `config.dimension` | Yes | — | Embedding dimension. 
Use `1` for hash-based (exact) caching without an embedding provider | +| `config.provider` | No | — | Provider for generating embeddings (required for semantic mode) | +| `config.embedding_model` | No | — | Model for embeddings (required when `provider` is set) | +| `config.threshold` | No | `0.8` | Cosine similarity threshold for a cache hit (0.0–1.0) | +| `config.ttl` | No | `300` | Cache entry TTL in seconds (or a duration string like `"1h"`) | +| `config.cache_by_model` | No | `true` | Include model in cache key | +| `config.cache_by_provider` | No | `true` | Include provider in cache key | +| `config.exclude_system_prompt` | No | `false` | Exclude system prompt from cache key | +| `config.conversation_history_threshold` | No | `3` | Skip caching for requests with more messages than this | +| `config.default_cache_key` | No | — | Default cache key when no `x-bf-cache-key` header is sent | + +**Semantic mode** (embedding-based similarity search): + +```json +{ + "plugins": [ + { + "name": "semantic_cache", + "enabled": true, + "config": { + "provider": "openai", + "embedding_model": "text-embedding-3-small", + "dimension": 1536, + "threshold": 0.85, + "ttl": 300, + "cache_by_model": true, + "cache_by_provider": true + } + } + ] +} +``` + +**Hash mode** (exact-match caching, no embedding provider needed): + +```json +{ + "plugins": [ + { + "name": "semantic_cache", + "enabled": true, + "config": { + "dimension": 1, + "ttl": 1800 + } + } + ] +} +``` + + +You must also configure a `vector_store` in `config.json`. See [Storage — vector_store](/deployment-guides/config-json/storage#vector_store). + + + + + + +### OpenTelemetry (OTel) + +Exports distributed traces to any OTel-compatible collector (Jaeger, Zipkin, Tempo, Datadog via OTLP, etc.). 
+ +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `config.collector_url` | Yes | — | OTLP collector endpoint | +| `config.trace_type` | Yes | — | Trace format: `"genai_extension"`, `"vercel"`, or `"open_inference"` | +| `config.protocol` | Yes | — | `"http"` or `"grpc"` | +| `config.service_name` | No | `"bifrost"` | Service name reported to the collector | +| `config.metrics_enabled` | No | `false` | Enable push-based OTLP metrics export | +| `config.metrics_endpoint` | No | — | OTLP metrics endpoint URL | +| `config.metrics_push_interval` | No | `15` | Metrics push interval in seconds | +| `config.headers` | No | — | Custom headers for the collector (supports `env.` prefix) | +| `config.insecure` | No | `false` | Skip TLS verification | +| `config.tls_ca_cert` | No | — | Path to TLS CA certificate | + +```json +{ + "plugins": [ + { + "name": "otel", + "enabled": true, + "config": { + "collector_url": "http://otel-collector:4318", + "trace_type": "genai_extension", + "protocol": "http", + "service_name": "bifrost-gateway" + } + } + ] +} +``` + +**With authentication headers:** + +```json +{ + "plugins": [ + { + "name": "otel", + "enabled": true, + "config": { + "collector_url": "https://otel.example.com:4318", + "trace_type": "open_inference", + "protocol": "http", + "service_name": "bifrost", + "headers": { + "Authorization": "env.OTEL_AUTH_HEADER" + } + } + } + ] +} +``` + +**With OTLP metrics export:** + +```json +{ + "plugins": [ + { + "name": "otel", + "enabled": true, + "config": { + "collector_url": "http://otel-collector:4318", + "trace_type": "genai_extension", + "protocol": "http", + "metrics_enabled": true, + "metrics_endpoint": "http://otel-collector:4318/v1/metrics", + "metrics_push_interval": 30 + } + } + ] +} +``` + + + + + +### Maxim Observability + +Sends request traces to the [Maxim](https://www.getmaxim.ai) observability platform. 
+ +| Field | Required | Description | +|-------|----------|-------------| +| `config.api_key` | Yes | Maxim API key (use `env.` prefix) | +| `config.log_repo_id` | No | Default Maxim logger repository ID | + +```json +{ + "plugins": [ + { + "name": "maxim", + "enabled": true, + "config": { + "api_key": "env.MAXIM_API_KEY", + "log_repo_id": "your-log-repo-id" + } + } + ] +} +``` + + + + + +### Datadog + + +Datadog is an **enterprise-only** plugin and is silently ignored in OSS builds. + + +Sends APM traces and metrics to a Datadog Agent. + +| Field | Default | Description | +|-------|---------|-------------| +| `config.agent_addr` | `"localhost:8126"` | Datadog Agent address for APM traces | +| `config.service_name` | `"bifrost"` | Service name in Datadog | +| `config.env` | — | Environment tag (e.g. `"production"`, `"staging"`) | +| `config.version` | — | Service version tag | +| `config.enable_traces` | `true` | Enable APM trace collection | +| `config.custom_tags` | `{}` | Additional key/value tags for all traces and metrics | + +```json +{ + "plugins": [ + { + "name": "datadog", + "enabled": true, + "config": { + "agent_addr": "datadog-agent:8126", + "service_name": "bifrost", + "env": "production", + "enable_traces": true, + "custom_tags": { + "team": "platform", + "region": "us-east-1" + } + } + } + ] +} +``` + + + + + +--- + +## Custom / Dynamic Plugins + +Load a custom Go plugin binary or WASM plugin at startup using the `path` field. Custom plugins must implement one of the Bifrost plugin interfaces. 
+ +```json +{ + "plugins": [ + { + "name": "my-custom-auth", + "enabled": true, + "path": "/app/plugins/my-custom-auth.so", + "config": { + "auth_endpoint": "env.AUTH_SERVICE_URL" + } + } + ] +} +``` + +**WASM plugin:** + +```json +{ + "plugins": [ + { + "name": "my-wasm-plugin", + "enabled": true, + "path": "/app/plugins/my-plugin.wasm", + "config": {} + } + ] +} +``` + +See [Writing Go Plugins](/plugins/writing-go-plugin) and [Writing WASM Plugins](/plugins/writing-wasm-plugin) for implementation guides. + +**Placement and ordering (DB-backed only):** + +When creating plugins dynamically via the DB/UI (rather than `config.json`), you can specify their execution order: + +| `placement` | When it runs | +|-------------|-------------| +| `pre_builtin` | Before all built-in plugins | +| `builtin` | Alongside built-in plugins (by `order`) | +| `post_builtin` | After all built-in plugins (default) | + +Within a placement group, lower `order` values run earlier. diff --git a/docs/deployment-guides/config-json/providers.mdx b/docs/deployment-guides/config-json/providers.mdx new file mode 100644 index 0000000000..ca07e0e5f8 --- /dev/null +++ b/docs/deployment-guides/config-json/providers.mdx @@ -0,0 +1,755 @@ +--- +title: "Provider Setup" +description: "Configure LLM providers in config.json — API keys, cloud-native auth, per-provider network settings, and self-hosted endpoints" +icon: "plug" +--- + +All providers are configured under `providers` in `config.json`. Each provider entry contains a `keys` array where every key has a `name`, `value`, `models`, and `weight`, plus optional provider-specific config objects. 
+ +**Supplying credentials:** + +Use the `env.` prefix to reference environment variables — never put API keys directly in `config.json`: + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "name": "primary", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + } + } +} +``` + +--- + +## Common Provider Fields + +Every key object supports these fields: + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Unique name for this key (used in logs and when pinning keys from virtual keys or routing rules) | +| `value` | string | API key value or `env.VAR_NAME` reference | +| `models` | array | Models this key serves. `["*"]` = all models | +| `weight` | float | Load balancing weight. Higher = more traffic | +| `aliases` | object | Map logical name → actual model name for this key | +| `use_for_batch_api` | boolean | Mark key as eligible for batch API calls | + +Per-provider `network_config` options (these apply to all standard providers): + +| Field | Type | Description | +|-------|------|-------------| +| `default_request_timeout_in_seconds` | integer | Per-request timeout | +| `max_retries` | integer | Retry attempts on transient errors | +| `retry_backoff_initial` | integer | Initial backoff in milliseconds | +| `retry_backoff_max` | integer | Maximum backoff in milliseconds | +| `max_conns_per_host` | integer | Max TCP connections to the provider endpoint (default: 5000) | +| `extra_headers` | object | Static headers added to every provider request | +| `stream_idle_timeout_in_seconds` | integer | Idle timeout per stream chunk (default: 60) | +| `insecure_skip_verify` | boolean | Disable TLS verification (last resort only) | +| `ca_cert_pem` | string | PEM-encoded CA for self-signed or private CA endpoints | + +Concurrency and buffering per provider: + +| Field | Type | Description | +|-------|------|-------------| +| `concurrency_and_buffer_size.concurrency` | integer | Max concurrent requests to this provider | +| 
`concurrency_and_buffer_size.buffer_size` | integer | Request queue depth | + +--- + + + + + +### OpenAI + +Supports multiple keys with weighted load balancing. Mark one key with `use_for_batch_api: true` to designate it for the Batch API. + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "name": "openai-primary", + "value": "env.OPENAI_KEY_1", + "models": ["*"], + "weight": 2.0 + }, + { + "name": "openai-secondary", + "value": "env.OPENAI_KEY_2", + "models": ["gpt-4o-mini"], + "weight": 1.0 + }, + { + "name": "openai-batch", + "value": "env.OPENAI_KEY_BATCH", + "models": ["*"], + "weight": 1.0, + "use_for_batch_api": true + } + ], + "network_config": { + "default_request_timeout_in_seconds": 120, + "max_retries": 3, + "retry_backoff_initial": 500, + "retry_backoff_max": 5000 + } + } + } +} +``` + + + + + +### Anthropic + +```json +{ + "providers": { + "anthropic": { + "keys": [ + { + "name": "anthropic-primary", + "value": "env.ANTHROPIC_KEY_1", + "models": ["*"], + "weight": 1.0 + }, + { + "name": "anthropic-secondary", + "value": "env.ANTHROPIC_KEY_2", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 180 + } + } + } +} +``` + +**Override Anthropic beta headers** (optional): + +```json +{ + "providers": { + "anthropic": { + "keys": [ + { + "name": "primary", + "value": "env.ANTHROPIC_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "beta_header_overrides": { + "redact-thinking-": true + } + } + } + } +} +``` + + + + + +### Azure OpenAI + +Azure requires `azure_key_config` on every key with `endpoint` and `api_version`. List your Azure deployment names in `models` — Bifrost routes requests using the model name as the deployment name. If your deployment names differ from the model names you use in requests, add an `aliases` map on the key. 
+ + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-primary", + "value": "env.AZURE_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "env.AZURE_API_VERSION" + } + } + ] + } + } +} +``` + +Set environment variables: + +```bash +export AZURE_API_KEY="your-azure-api-key" +export AZURE_ENDPOINT="https://your-resource.openai.azure.com" +export AZURE_API_VERSION="2024-10-21" +``` + + + + +When `value` is empty or omitted, Bifrost uses `DefaultAzureCredential` — which resolves credentials from Workload Identity, VM managed identity, or `az login`. + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-workload-identity", + "value": "", + "models": ["gpt-4o"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "env.AZURE_API_VERSION" + } + } + ] + } + } +} +``` + + + + +**Deployment name aliases** — when your Azure deployment names differ from the model names in requests, use `aliases`: + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-primary", + "value": "env.AZURE_API_KEY", + "models": ["gpt-4o"], + "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-prod-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "env.AZURE_API_VERSION" + } + } + ] + } + } +} +``` + +**Multi-region failover** (two keys, different regions): + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "eastus", + "value": "env.AZURE_KEY_EAST", + "models": ["gpt-4o"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT_EAST", + "api_version": "env.AZURE_API_VERSION" + } + }, + { + "name": "westus", + "value": "env.AZURE_KEY_WEST", + "models": ["gpt-4o"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT_WEST", + "api_version": "env.AZURE_API_VERSION" + } + } + ] + } + } +} +``` + + + + + 
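 + +The common `network_config` and `concurrency_and_buffer_size` options from Common Provider Fields apply to Azure keys as well. A minimal sketch with illustrative values (the numbers are assumptions, not recommended defaults — tune them for your workload): + +```json +{ +  "providers": { +    "azure": { +      "keys": [ +        { +          "name": "azure-primary", +          "value": "env.AZURE_API_KEY", +          "models": ["gpt-4o"], +          "weight": 1.0, +          "azure_key_config": { +            "endpoint": "env.AZURE_ENDPOINT", +            "api_version": "env.AZURE_API_VERSION" +          } +        } +      ], +      "network_config": { +        "default_request_timeout_in_seconds": 120, +        "max_retries": 3 +      }, +      "concurrency_and_buffer_size": { +        "concurrency": 500, +        "buffer_size": 1000 +      } +    } +  } +} +```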
+### AWS Bedrock + +Bedrock requires `bedrock_key_config` with at minimum a `region`. Three auth modes: + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-static", + "value": "", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-east-1", + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY" + } + } + ] + } + } +} +``` + + + + +When only `region` is set, Bifrost inherits credentials from the AWS SDK default chain — IRSA (IAM Roles for Service Accounts), EC2 instance profile, or `AWS_*` env vars. + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-iam", + "value": "", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-east-1" + } + } + ] + } + } +} +``` + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-assumerole", + "value": "", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-west-2", + "role_arn": "env.AWS_ROLE_ARN", + "external_id": "env.AWS_EXTERNAL_ID", + "session_name": "bifrost-session" + } + } + ] + } + } +} +``` + + + + +**Model aliases** (map logical names to Bedrock inference profile IDs): + +```json +{ + "bedrock_key_config": { + "region": "us-east-1" + }, + "aliases": { + "claude-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-haiku": "us.anthropic.claude-3-5-haiku-20241022-v1:0" + } +} +``` + +**Batch API — S3 configuration:** + +```json +{ + "bedrock_key_config": { + "region": "us-east-1", + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "batch_s3_config": { + "buckets": [ + { + "bucket_name": "my-bedrock-batch-bucket", + "prefix": "batch/", + "is_default": true + } + ] + } + } +} +``` + + + + + +### Google Vertex AI + +Vertex requires `vertex_key_config` with `project_id` and `region`. 
Two auth modes: + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-sa", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_AUTH_CREDENTIALS" + } + } + ] + } + } +} +``` + +`VERTEX_AUTH_CREDENTIALS` should contain the base64-encoded service account JSON. + + + + +When `auth_credentials` is omitted, Bifrost calls `google.FindDefaultCredentials` — which resolves to GKE Workload Identity, GCE metadata server, or `gcloud auth application-default login`. + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-workload-identity", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "my-gcp-project", + "region": "us-central1" + } + } + ] + } + } +} +``` + + + + + + + + +### Standard API-Key Providers + +These providers follow the same simple pattern — one or more keys with weights. Replace the provider name and env var name accordingly. 
+ +```json +{ + "providers": { + "groq": { + "keys": [ + { + "name": "groq-primary", + "value": "env.GROQ_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "gemini": { + "keys": [ + { + "name": "gemini-primary", + "value": "env.GEMINI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "mistral": { + "keys": [ + { + "name": "mistral-primary", + "value": "env.MISTRAL_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "cohere": { + "keys": [{ "name": "cohere-main", "value": "env.COHERE_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "perplexity": { + "keys": [{ "name": "perplexity-main", "value": "env.PERPLEXITY_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "xai": { + "keys": [{ "name": "xai-main", "value": "env.XAI_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "cerebras": { + "keys": [{ "name": "cerebras-main", "value": "env.CEREBRAS_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "openrouter": { + "keys": [{ "name": "openrouter-main", "value": "env.OPENROUTER_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "nebius": { + "keys": [{ "name": "nebius-main", "value": "env.NEBIUS_API_KEY", "models": ["*"], "weight": 1.0 }] + } + } +} +``` + + + + + +### Self-Hosted Providers + +Self-hosted providers point to a URL you operate. No API key is typically required (`"value": ""`). 
+ + + + +```json +{ + "providers": { + "ollama": { + "keys": [ + { + "name": "ollama-local", + "value": "", + "models": ["*"], + "weight": 1.0, + "ollama_key_config": { + "url": "http://localhost:11434" + } + } + ] + } + } +} +``` + +Using an env var for the URL (useful across environments): + +```json +{ + "ollama_key_config": { + "url": "env.OLLAMA_URL" + } +} +``` + + + + +vLLM instances are model-specific — one key per served model: + +```json +{ + "providers": { + "vllm": { + "keys": [ + { + "name": "vllm-llama3-70b", + "value": "", + "models": ["llama-3-70b"], + "weight": 1.0, + "vllm_key_config": { + "url": "http://vllm-server:8000", + "model_name": "meta-llama/Meta-Llama-3-70B-Instruct" + } + }, + { + "name": "vllm-mistral", + "value": "", + "models": ["mistral-7b"], + "weight": 1.0, + "vllm_key_config": { + "url": "http://vllm-mistral:8000", + "model_name": "mistralai/Mistral-7B-Instruct-v0.3" + } + } + ] + } + } +} +``` + + + + +```json +{ + "providers": { + "sgl": { + "keys": [ + { + "name": "sgl-main", + "value": "", + "models": ["*"], + "weight": 1.0, + "sgl_key_config": { + "url": "http://sgl-router:30000" + } + } + ] + } + } +} +``` + + + + +These providers use `aliases` to map logical model names to provider-specific IDs: + +```json +{ + "providers": { + "huggingface": { + "keys": [ + { + "name": "hf-main", + "value": "env.HF_API_KEY", + "models": ["llama-3", "mixtral"], + "weight": 1.0, + "aliases": { + "llama-3": "meta-llama/Meta-Llama-3-8B-Instruct", + "mixtral": "mistralai/Mixtral-8x7B-Instruct-v0.1" + } + } + ] + }, + "replicate": { + "keys": [ + { + "name": "replicate-main", + "value": "env.REPLICATE_API_KEY", + "models": ["llama-3"], + "weight": 1.0, + "aliases": { + "llama-3": "meta/meta-llama-3-70b-instruct" + }, + "replicate_key_config": { + "use_deployments_endpoint": false + } + } + ] + } + } +} +``` + + + + + + + + +--- + +## Proxy Configuration + +Route provider traffic through an HTTP or SOCKS5 proxy: + +```json +{ + "providers": { + 
"openai": { + "keys": [ + { "name": "primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 1.0 } + ], + "proxy_config": { + "type": "http", + "url": "http://proxy.corp.example.com:3128", + "username": "env.PROXY_USER", + "password": "env.PROXY_PASS" + } + } + } +} +``` + +| Field | Type | Options | +|-------|------|---------| +| `proxy_config.type` | string | `"none"`, `"http"`, `"socks5"`, `"environment"` | +| `proxy_config.url` | string | Proxy server URL | +| `proxy_config.username` | string | Proxy auth username | +| `proxy_config.password` | string | Proxy auth password (`env.` supported) | +| `proxy_config.ca_cert_pem` | string | PEM CA for TLS-intercepting proxies | + +Use `"type": "environment"` to pick up `HTTP_PROXY` / `HTTPS_PROXY` env vars automatically. + +--- + +## Multi-Provider Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "providers": { + "openai": { + "keys": [ + { "name": "openai-primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 2.0 } + ] + }, + "anthropic": { + "keys": [ + { "name": "anthropic-primary", "value": "env.ANTHROPIC_API_KEY", "models": ["*"], "weight": 1.0 } + ] + }, + "groq": { + "keys": [ + { "name": "groq-primary", "value": "env.GROQ_API_KEY", "models": ["*"], "weight": 1.0 } + ] + } + } +} +``` + +With three providers and the weights above, traffic is distributed: 50% OpenAI, 25% Anthropic, 25% Groq. If any provider returns an error, Bifrost automatically retries on the next key or provider. 
diff --git a/docs/deployment-guides/config-json/schema-reference.mdx b/docs/deployment-guides/config-json/schema-reference.mdx new file mode 100644 index 0000000000..45b9b826ce --- /dev/null +++ b/docs/deployment-guides/config-json/schema-reference.mdx @@ -0,0 +1,202 @@ +--- +title: "Schema Reference" +description: "All top-level keys available in config.json, their types, and where each is documented" +icon: "brackets-curly" +--- + + +The live schema is published at [`https://www.getbifrost.ai/schema`](https://www.getbifrost.ai/schema). Add `"$schema": "https://www.getbifrost.ai/schema"` to your `config.json` for IDE autocomplete and inline validation. + + +This page is a concise reference for every top-level key in `config.json`. Click the **Guide** links for full field-by-field documentation. + +--- + +## Top-Level Keys + +| Key | Type | Description | Guide | +|-----|------|-------------|-------| +| `$schema` | string | Schema URL for IDE validation. Set to `"https://www.getbifrost.ai/schema"` | — | +| `encryption_key` | string | AES-256 key (derived via Argon2id). Accepts `env.VAR` prefix. 
Also read from `BIFROST_ENCRYPTION_KEY` env var | [Client](/deployment-guides/config-json/client#encryption-key) |
+| `version` | integer | Empty-array semantics for allow-list fields: `2` (default) or `1` for v1.4.x compat | [version](#version) |
+| `client` | object | Worker pool, logging, CORS, auth enforcement, header filtering, MCP, compat shims | [Client](/deployment-guides/config-json/client) |
+| `providers` | object | LLM provider API keys, network settings, concurrency | [Providers](/deployment-guides/config-json/providers) |
+| `governance` | object | Admin auth, virtual keys, budgets, rate limits, routing rules, customers, teams | [Governance](/deployment-guides/config-json/governance) |
+| `guardrails_config` | object | Content moderation providers and CEL-based rules *(enterprise only)* | [Guardrails](/deployment-guides/config-json/guardrails) |
+| `config_store` | object | Configuration database backend — SQLite, PostgreSQL, or disabled (file-only mode) | [Storage](/deployment-guides/config-json/storage#config_store) |
+| `logs_store` | object | Request/response log database — SQLite, PostgreSQL + optional S3/GCS offload | [Storage](/deployment-guides/config-json/storage#logs_store) |
+| `vector_store` | object | Vector database for semantic cache — Weaviate, Redis, Qdrant, Pinecone, Valkey | [Storage](/deployment-guides/config-json/storage#vector_store) |
+| `plugins` | array | Opt-in plugins: `semantic_cache`, `otel`, `maxim`, `datadog`, custom | [Plugins](/deployment-guides/config-json/plugins) |
+| `framework` | object | Model pricing catalog URL and sync interval | [Framework](#framework) |
+| `mcp` | object | MCP server and tool configuration | — |
+| `websocket` | object | WebSocket / Realtime API connection pool tuning | [WebSocket](#websocket) |
+| `auth_config` | object | **Deprecated** — use `governance.auth_config` | [Client](/deployment-guides/config-json/client#authentication) |
+
+---
+
+## `version`
+
+Controls how empty arrays in allow-list fields (`models`, `allowed_models`, `key_ids`, `tools_to_execute`) are interpreted:
+
+| Value | Behaviour |
+|-------|-----------| +| `2` *(default, v1.5.0+)* | Empty array = **deny all**; `["*"]` = allow all | +| `1` *(v1.4.x compat)* | Empty array = **allow all** | + +Omitting `version` uses v2 semantics. Set `"version": 1` only if you are migrating from v1.4.x and need the old behaviour temporarily. + +--- + +## `client` + +Controls the worker pool, logging pipeline, security, and SDK shims. All fields are optional. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `initial_pool_size` | integer | `300` | Pre-allocated goroutines per provider queue | +| `drop_excess_requests` | boolean | `false` | Return HTTP 429 when queue is full | +| `enable_logging` | boolean | `true`* | Persist request/response logs (`*` auto-enabled when `logs_store` is set) | +| `disable_content_logging` | boolean | `false` | Strip message content from logs | +| `log_retention_days` | integer | `365` | Days to retain log entries | +| `logging_headers` | array | `[]` | HTTP headers to capture in log metadata | +| `enforce_auth_on_inference` | boolean | `false` | Require a virtual key on every `/v1/*` request | +| `allow_direct_keys` | boolean | `false` | Allow callers to pass provider API keys directly | +| `allowed_origins` | array | `["*"]` | CORS allowed origins | +| `max_request_body_size_mb` | integer | `100` | Maximum request body in MB | +| `whitelisted_routes` | array | `[]` | Routes that bypass auth middleware | +| `allowed_headers` | array | `[]` | Additional headers permitted for CORS/WebSocket | +| `required_headers` | array | `[]` | Headers that must be present on every request | +| `header_filter_config` | object | — | `allowlist` / `denylist` for `x-bf-eh-*` forwarded headers | +| `prometheus_labels` | array | `[]` | Custom labels for all Prometheus metrics | +| `compat` | object | — | SDK compatibility shims (`should_drop_params`, `convert_text_to_chat`, etc.) 
| +| `mcp_agent_depth` | integer | `10` | Max tool-call recursion depth | +| `mcp_tool_execution_timeout` | integer | `30` | Per-tool execution timeout in seconds | +| `mcp_tool_sync_interval` | integer | `10` | Tool sync interval in minutes (`0` = disabled) | +| `mcp_disable_auto_tool_inject` | boolean | `false` | Disable automatic MCP tool injection | +| `async_job_result_ttl` | integer | `3600` | TTL for async job results in seconds | +| `disable_db_pings_in_health` | boolean | `false` | Exclude DB connectivity from `/health` | +| `routing_chain_max_depth` | integer | `10` | Max routing rule chain evaluation depth | + +Full documentation: [Client Configuration](/deployment-guides/config-json/client). + +--- + +## `providers` + +Keyed by provider name. Each entry contains a `keys` array and optional `network_config`, `concurrency_and_buffer_size`, `proxy_config`. + +Supported provider keys: `openai`, `anthropic`, `azure`, `bedrock`, `vertex`, `gemini`, `mistral`, `groq`, `cohere`, `perplexity`, `xai`, `cerebras`, `openrouter`, `nebius`, `fireworks`, `parasail`, `huggingface`, `replicate`, `ollama`, `vllm`, `sgl`, `elevenlabs`, `runway`. + +Full documentation: [Provider Setup](/deployment-guides/config-json/providers). + +--- + +## `governance` + +Seeds governance resources at startup. All sub-keys are optional arrays. 
+ +| Sub-key | Description | +|---------|-------------| +| `auth_config` | Admin username/password auth for the dashboard | +| `virtual_keys` | Scoped API tokens with provider/model allowlists | +| `budgets` | Spend caps in USD over a rolling window | +| `rate_limits` | Request and token rate limits | +| `customers` | Customer entities (attach budgets/rate limits) | +| `teams` | Team entities (attach to customers, budgets, rate limits) | +| `routing_rules` | CEL-based dynamic provider/model routing | +| `pricing_overrides` | Scoped per-model pricing overrides | +| `model_configs` | Per-model rate limit and budget configurations | + +Full documentation: [Governance](/deployment-guides/config-json/governance). + +--- + +## `guardrails_config` + +Enterprise-only. Two sub-keys: `guardrail_providers` (array) and `guardrail_rules` (array). + +Full documentation: [Guardrails](/deployment-guides/config-json/guardrails). + +--- + +## `config_store`, `logs_store`, `vector_store` + +Storage backends. Each has `enabled` (boolean), `type` (string), and `config` (object). + +| Store | Types | +|-------|-------| +| `config_store` | `"sqlite"`, `"postgres"` | +| `logs_store` | `"sqlite"`, `"postgres"` (+ optional `object_storage`) | +| `vector_store` | `"weaviate"`, `"redis"`, `"qdrant"`, `"pinecone"` (`"redis"` also covers Valkey-compatible endpoints) | + +Full documentation: [Storage](/deployment-guides/config-json/storage). 
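All three stores share the same envelope (`enabled`, `type`, `config`). A minimal combined sketch; the backend choices here are illustrative:

```json
{
  "config_store": { "enabled": true, "type": "sqlite", "config": { "path": "./config.db" } },
  "logs_store": { "enabled": true, "type": "sqlite", "config": { "path": "./logs.db" } },
  "vector_store": { "enabled": false }
}
```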
+ +--- + +## `framework` + +Controls model pricing catalog sync: + +```json +{ + "framework": { + "pricing": { + "pricing_url": "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json", + "pricing_sync_interval": 86400 + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `pricing.pricing_url` | LiteLLM catalog | URL of a model pricing JSON file | +| `pricing.pricing_sync_interval` | `86400` | Sync interval in seconds (minimum: `3600`) | + +--- + +## `websocket` + +Optional tuning for the WebSocket gateway (Responses API WebSocket mode, Realtime API). WebSocket is always enabled. + +```json +{ + "websocket": { + "max_connections_per_user": 100, + "transcript_buffer_size": 100, + "pool": { + "max_idle_per_key": 50, + "max_total_connections": 1000, + "idle_timeout_seconds": 600, + "max_connection_lifetime_seconds": 7200 + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `max_connections_per_user` | `100` | Max concurrent WebSocket connections per user | +| `transcript_buffer_size` | `100` | Transcript entries buffered for Realtime API mid-session fallback | +| `pool.max_idle_per_key` | `50` | Max idle upstream connections per provider/key | +| `pool.max_total_connections` | `1000` | Max total idle upstream connections | +| `pool.idle_timeout_seconds` | `600` | Evict idle connections after this many seconds | +| `pool.max_connection_lifetime_seconds` | `7200` | Max lifetime of any upstream connection | + +--- + +## Minimal Valid Config + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + "providers": { + "openai": { + "keys": [ + { "name": "primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 1.0 } + ] + } + }, + "config_store": { "enabled": false } +} +``` diff --git a/docs/deployment-guides/config-json/storage.mdx b/docs/deployment-guides/config-json/storage.mdx new file mode 
100644 index 0000000000..fd4bdbde97 --- /dev/null +++ b/docs/deployment-guides/config-json/storage.mdx @@ -0,0 +1,540 @@ +--- +title: "Storage" +description: "Configure Bifrost storage backends in config.json — config_store, logs_store, vector_store, and object storage for logs" +icon: "database" +--- + +Bifrost persists two types of data — **config** (providers, virtual keys, governance rules) and **logs** (request/response records). Each has its own store. A **vector store** is required for semantic caching. + +| Store | Purpose | Backends | +|-------|---------|---------| +| `config_store` | Provider configs, virtual keys, governance rules | SQLite, PostgreSQL | +| `logs_store` | Request/response logs shown in UI | SQLite, PostgreSQL + optional S3/GCS offload | +| `vector_store` | Semantic response caching | Weaviate, Redis, Valkey, Qdrant, Pinecone | + + +If you use PostgreSQL for any store, the target database must be **UTF8 encoded**. See [PostgreSQL UTF8 Requirement](/quickstart/gateway/setting-up#postgresql-utf8-requirement). + + +--- + +## config_store + + +When `config_store` is disabled (or absent), all configuration is loaded from `config.json` at startup only — the Web UI is disabled and changes require a restart. See [Two Configuration Modes](/deployment-guides/config-json#two-configuration-modes). + + + + + + +### SQLite (Default) + +Simplest setup — no external database required. Bifrost stores configuration in a local SQLite file. + +```json +{ + "config_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./config.db" + } + } +} +``` + +| Field | Description | +|-------|-------------| +| `config.path` | Path to the SQLite file (relative to app-dir, or absolute) | + + + + + +### PostgreSQL + +Production-grade storage suitable for high-availability and high-throughput deployments. 
+ +```json +{ + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 5, + "max_open_conns": 50 + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `host` | — | PostgreSQL host (supports `env.` prefix) | +| `port` | — | PostgreSQL port (as string) | +| `user` | — | Database user (supports `env.` prefix) | +| `password` | — | Database password (supports `env.` prefix). Leave empty for IAM role auth. | +| `db_name` | — | Database name | +| `ssl_mode` | — | `"disable"`, `"require"`, `"verify-ca"`, `"verify-full"` | +| `max_idle_conns` | `5` | Maximum idle connections in the pool | +| `max_open_conns` | `50` | Maximum open connections to the database | + + + + + +### Disabled (file-only mode) + +Use this when you want Bifrost to read all configuration from `config.json` only — no database, no Web UI. + +```json +{ + "config_store": { + "enabled": false + } +} +``` + +This is the recommended setup for [multinode OSS deployments](/deployment-guides/how-to/multinode) where a shared `config.json` is the single source of truth. 
+ + + + + +--- + +## logs_store + + + + + +### SQLite + +```json +{ + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./logs.db" + } + } +} +``` + + + + + +### PostgreSQL + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 10, + "max_open_conns": 100 + } + } +} +``` + +For high log volumes, increase `max_open_conns`: + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 10, + "max_open_conns": 200 + }, + "retention_days": 90 + } +} +``` + + + + + +```json +{ + "logs_store": { + "enabled": false + } +} +``` + + + + + +### Log Retention + +Set `retention_days` to automatically purge old log entries. `0` disables retention-based cleanup. + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { "...": "..." }, + "retention_days": 90 + } +} +``` + +### Object Storage for Logs + +Offload large request/response payloads from the database to S3 or GCS. The database retains only lightweight index records; payloads are fetched on demand. + + + + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { "...": "..." 
}, + "object_storage": { + "type": "s3", + "bucket": "env.S3_BUCKET", + "prefix": "bifrost", + "compress": true, + "region": "us-east-1", + "access_key_id": "env.S3_ACCESS_KEY_ID", + "secret_access_key": "env.S3_SECRET_ACCESS_KEY" + } + } +} +``` + +**IAM role (instance profile / IRSA)** — omit `access_key_id` and `secret_access_key`: + +```json +{ + "object_storage": { + "type": "s3", + "bucket": "bifrost-logs", + "region": "us-east-1", + "compress": true, + "role_arn": "arn:aws:iam::123456789012:role/BifrostS3Role" + } +} +``` + +| Field | Description | +|-------|-------------| +| `bucket` | S3 bucket name (supports `env.` prefix) | +| `prefix` | Key prefix for stored objects (default: `"bifrost"`) | +| `compress` | Enable gzip compression (default: `false`) | +| `region` | AWS region | +| `access_key_id` | AWS access key ID (omit for default credential chain) | +| `secret_access_key` | AWS secret access key | +| `session_token` | STS temporary credentials session token | +| `role_arn` | IAM role ARN for STS AssumeRole | +| `endpoint` | Custom endpoint for MinIO / Cloudflare R2 | +| `force_path_style` | Use path-style URLs (required for MinIO, default: `false`) | + + + + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { "...": "..." }, + "object_storage": { + "type": "gcs", + "bucket": "bifrost-logs", + "prefix": "bifrost", + "compress": true, + "project_id": "env.GCP_PROJECT_ID", + "credentials_json": "env.GCS_CREDENTIALS_JSON" + } + } +} +``` + +Omit `credentials_json` to use Application Default Credentials (Workload Identity, GCE metadata, `gcloud auth`). 
+ +| Field | Description | +|-------|-------------| +| `project_id` | GCP project ID (supports `env.` prefix) | +| `credentials_json` | Service account JSON or path — omit for ADC | + + + + +```json +{ + "object_storage": { + "type": "s3", + "bucket": "bifrost-logs", + "prefix": "bifrost", + "compress": false, + "region": "us-east-1", + "endpoint": "http://minio.internal:9000", + "access_key_id": "env.MINIO_ACCESS_KEY", + "secret_access_key": "env.MINIO_SECRET_KEY", + "force_path_style": true + } +} +``` + + + + +--- + +## vector_store + +A vector store is required for [semantic caching](/features/semantic-caching). Choose from Weaviate, Redis/Valkey, Qdrant, or Pinecone. + + + + + +```json +{ + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "scheme": "http", + "host": "localhost:8080", + "api_key": "env.WEAVIATE_API_KEY", + "grpc_config": { + "host": "localhost:50051", + "secured": false + } + } + } +} +``` + +| Field | Required | Description | +|-------|----------|-------------| +| `scheme` | Yes | `"http"` or `"https"` | +| `host` | Yes | Weaviate server host and port | +| `api_key` | No | Weaviate API key (supports `env.` prefix) | +| `grpc_config.host` | No | gRPC host for faster vector operations | +| `grpc_config.secured` | No | Use TLS for gRPC connection | + + + + + +```json +{ + "vector_store": { + "enabled": true, + "type": "redis", + "config": { + "addr": "env.REDIS_ADDR", + "password": "env.REDIS_PASSWORD", + "db": 0, + "use_tls": false + } + } +} +``` + +**AWS MemoryDB (cluster mode):** + +```json +{ + "vector_store": { + "enabled": true, + "type": "redis", + "config": { + "addr": "env.MEMORYDB_ENDPOINT", + "password": "env.MEMORYDB_PASSWORD", + "use_tls": true, + "cluster_mode": true + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `addr` | — | Redis/Valkey address `host:port` (supports `env.` prefix) | +| `password` | — | Redis AUTH password (supports `env.` prefix) | +| `db` | 
`0` | Redis database number | +| `use_tls` | `false` | Enable TLS | +| `cluster_mode` | `false` | Enable cluster mode (required for MemoryDB; `db` must be `0`) | +| `pool_size` | — | Maximum socket connections | + + + + + +```json +{ + "vector_store": { + "enabled": true, + "type": "qdrant", + "config": { + "host": "env.QDRANT_HOST", + "port": 6334, + "api_key": "env.QDRANT_API_KEY", + "use_tls": false + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `host` | — | Qdrant server host (supports `env.` prefix) | +| `port` | `6334` | gRPC port | +| `api_key` | — | API key (supports `env.` prefix) | +| `use_tls` | `false` | Enable TLS | + + + + + +Pinecone is external-only. + +```json +{ + "vector_store": { + "enabled": true, + "type": "pinecone", + "config": { + "api_key": "env.PINECONE_API_KEY", + "index_host": "env.PINECONE_INDEX_HOST" + } + } +} +``` + +| Field | Description | +|-------|-------------| +| `api_key` | Pinecone API key (supports `env.` prefix) | +| `index_host` | Index host from Pinecone console (e.g. 
`your-index.svc.us-east1-gcp.pinecone.io`) | + + + + + +--- + +## Mixed Backend Example + +Run the config store on PostgreSQL (for UI) while keeping logs on SQLite (simpler, cheaper for append-heavy workloads): + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require" + } + }, + + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./logs.db" + } + } +} +``` + +--- + +## Full Storage Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 5, + "max_open_conns": 50 + } + }, + + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 10, + "max_open_conns": 100 + }, + "retention_days": 90, + "object_storage": { + "type": "s3", + "bucket": "env.S3_BUCKET", + "region": "us-east-1", + "compress": true, + "access_key_id": "env.S3_ACCESS_KEY_ID", + "secret_access_key": "env.S3_SECRET_ACCESS_KEY" + } + }, + + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "scheme": "http", + "host": "weaviate:8080" + } + } +} +``` diff --git a/docs/deployment-guides/helm/guardrails.mdx b/docs/deployment-guides/helm/guardrails.mdx index 60ec2710d5..4604b426e4 100644 --- a/docs/deployment-guides/helm/guardrails.mdx +++ b/docs/deployment-guides/helm/guardrails.mdx @@ -107,23 
+107,6 @@ bifrost: rules: {} # optional: inline rule map ``` - - - -```yaml -bifrost: - guardrails: - providers: - - id: 5 - provider_name: "patronus-ai" - policy_name: "patronus-safety" - enabled: true - timeout: 20 - config: - api_key: "env.PATRONUS_API_KEY" - environment: "production" # production | development -``` - @@ -276,4 +259,4 @@ helm install bifrost bifrost/bifrost \ --set env[0].name=AZURE_CONTENT_SAFETY_KEY \ --set env[0].valueFrom.secretKeyRef.name=azure-content-safety \ --set env[0].valueFrom.secretKeyRef.key=key -``` +``` \ No newline at end of file diff --git a/docs/deployment-guides/helm/plugins.mdx b/docs/deployment-guides/helm/plugins.mdx index f02303120b..79a4c4f788 100644 --- a/docs/deployment-guides/helm/plugins.mdx +++ b/docs/deployment-guides/helm/plugins.mdx @@ -6,15 +6,15 @@ icon: "puzzle-piece" Plugins are configured under `bifrost.plugins`. Each plugin is independently enabled/disabled. Pre-hooks run in registration order; post-hooks run in reverse order. + +**Telemetry, logging, and governance are auto-loaded built-ins** — they are always active and do not need to be explicitly enabled. Their configuration lives in `bifrost.client.*` and `bifrost.governance.*`, not in the `plugins` block. + +The `plugins` block controls the opt-in plugins: `semanticCache`, `otel`, `datadog`, `maxim`, and custom plugins. 
+ + ```yaml bifrost: plugins: - telemetry: - enabled: true - logging: - enabled: true - governance: - enabled: true semanticCache: enabled: false otel: @@ -24,17 +24,15 @@ bifrost: ``` ```bash -# Enable plugins at install time +# Enable an opt-in plugin at install time helm install bifrost bifrost/bifrost \ --set image.tag=v1.4.11 \ - --set bifrost.plugins.telemetry.enabled=true \ - --set bifrost.plugins.logging.enabled=true \ - --set bifrost.plugins.governance.enabled=true + --set bifrost.plugins.otel.enabled=true # Or upgrade to enable a plugin without touching other values helm upgrade bifrost bifrost/bifrost \ --reuse-values \ - --set bifrost.plugins.otel.enabled=true + --set bifrost.plugins.semanticCache.enabled=true ``` --- @@ -45,39 +43,21 @@ helm upgrade bifrost bifrost/bifrost \ ### Telemetry (Prometheus) -Exposes Prometheus metrics at `GET /metrics`. - -| Parameter | Description | Default | -|-----------|-------------|---------| -| `bifrost.plugins.telemetry.enabled` | Enable Prometheus metrics | `false` | -| `bifrost.plugins.telemetry.config.custom_labels` | Extra labels attached to every metric | `[]` | -| `bifrost.plugins.telemetry.config.push_gateway.enabled` | Push metrics to a Prometheus Push Gateway | `false` | -| `bifrost.plugins.telemetry.config.push_gateway.push_gateway_url` | Push Gateway URL | `""` | -| `bifrost.plugins.telemetry.config.push_gateway.job_name` | Job label | `"bifrost"` | -| `bifrost.plugins.telemetry.config.push_gateway.push_interval` | Push interval in seconds | `15` | + +Telemetry is **always active** — it cannot be disabled. You do not need to set `bifrost.plugins.telemetry.enabled`. + -**Basic setup:** +Exposes Prometheus metrics at `GET /metrics`. 
Custom labels are set via `bifrost.client.prometheusLabels`: ```yaml -# telemetry-values.yaml -image: - tag: "v1.4.11" - bifrost: - plugins: - telemetry: - enabled: true - config: - custom_labels: - - name: "environment" - value: "production" - - name: "region" - value: "us-east-1" + client: + prometheusLabels: + - "environment=production" + - "region=us-east-1" ``` ```bash -helm upgrade bifrost bifrost/bifrost --reuse-values -f telemetry-values.yaml - # Verify metrics are exposed kubectl port-forward svc/bifrost 8080:8080 & curl http://localhost:8080/metrics | head -30 @@ -118,81 +98,60 @@ serviceMonitor: ### Request/Response Logging -Persists full request and response data to the configured log store. + +Logging is **auto-loaded** when `bifrost.client.enableLogging: true` and a log store is configured. You do not need to set `bifrost.plugins.logging.enabled`. + + +Configure logging via the `client` block: | Parameter | Description | Default | |-----------|-------------|---------| -| `bifrost.plugins.logging.enabled` | Enable request/response logging | `false` | -| `bifrost.plugins.logging.config.disable_content_logging` | Strip message body from logs | `false` | -| `bifrost.plugins.logging.config.logging_headers` | HTTP headers to capture in log metadata | `[]` | +| `bifrost.client.enableLogging` | Enable request/response logging | `true` | +| `bifrost.client.disableContentLogging` | Strip message body from logs (HIPAA/PCI) | `false` | +| `bifrost.client.loggingHeaders` | HTTP headers to capture in log metadata | `[]` | ```yaml -# logging-values.yaml -image: - tag: "v1.4.11" - bifrost: - plugins: - logging: - enabled: true - config: - disable_content_logging: false # set true for HIPAA/compliance - logging_headers: - - "x-request-id" - - "x-user-id" - - "x-team-id" + client: + enableLogging: true + disableContentLogging: false # set true for HIPAA/compliance + loggingHeaders: + - "x-request-id" + - "x-user-id" + - "x-team-id" ``` ```bash -helm upgrade bifrost 
bifrost/bifrost --reuse-values -f logging-values.yaml -``` - -**Verify logs are being written:** - -```bash +# Verify logs are being written kubectl port-forward svc/bifrost 8080:8080 & -# Make a test request, then query logs curl -s "http://localhost:8080/api/logs?limit=5" | jq . ``` - -`bifrost.plugins.logging` controls the *plugin* (which hooks into every request). `bifrost.client.enableLogging` / `disableContentLogging` controls the *client-level* defaults. Both must be configured consistently — see the [Client Configuration](/deployment-guides/helm/client) page. - +See [Client Configuration](/deployment-guides/helm/client) for the full reference. -### Governance Plugin +### Governance + + +Governance is **always active** for OSS deployments. You do not need to set `bifrost.plugins.governance.enabled`. + -Enforces budget caps, rate limits, and virtual key policies on every request. Must be enabled alongside `bifrost.governance` resource definitions. +Virtual key enforcement is controlled by the `client` block: | Parameter | Description | Default | |-----------|-------------|---------| -| `bifrost.plugins.governance.enabled` | Enable governance enforcement | `false` | -| `bifrost.plugins.governance.config.is_vk_mandatory` | Reject requests without a virtual key | `false` | -| `bifrost.plugins.governance.config.required_headers` | Additional headers required on every request | `[]` | -| `bifrost.plugins.governance.config.is_enterprise` | Enable enterprise governance features | `false` | +| `bifrost.client.enforceAuthOnInference` | Require a virtual key (`x-bf-vk`) on every inference request | `false` | ```yaml -# governance-plugin-values.yaml -image: - tag: "v1.4.11" - bifrost: - plugins: - governance: - enabled: true - config: - is_vk_mandatory: true # require virtual key on all inference requests - required_headers: [] -``` - -```bash -helm upgrade bifrost bifrost/bifrost --reuse-values -f governance-plugin-values.yaml + client: + enforceAuthOnInference: true # 
require virtual key on all inference requests ``` -See the [Governance](/deployment-guides/helm/governance) page for defining budgets, rate limits, and virtual keys. +Define virtual keys, budgets, rate limits, and routing rules in `bifrost.governance.*`. See the [Governance](/deployment-guides/helm/governance) page. diff --git a/docs/docs.json b/docs/docs.json index d3d600003d..b65a5c5936 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -197,7 +197,7 @@ "icon": "bolt", "pages": [ "features/drop-in-replacement", - "features/fallbacks", + "features/retries-and-fallbacks", "features/litellm-compat", "features/keys-management", "features/async-inference", @@ -389,11 +389,7 @@ { "group": "Platform specific guides", "icon": "swatchbook", - "pages": [ - "deployment-guides/k8s", - "deployment-guides/ecs", - "deployment-guides/fly" - ] + "pages": ["deployment-guides/k8s", "deployment-guides/ecs", "deployment-guides/fly"] }, { "group": "Config as Code", @@ -414,6 +410,20 @@ "deployment-guides/helm/cluster", "deployment-guides/helm/troubleshooting" ] + }, + { + "group": "config.json", + "icon": "file-code", + "pages": [ + "deployment-guides/config-json", + "deployment-guides/config-json/schema-reference", + "deployment-guides/config-json/client", + "deployment-guides/config-json/providers", + "deployment-guides/config-json/storage", + "deployment-guides/config-json/plugins", + "deployment-guides/config-json/governance", + "deployment-guides/config-json/guardrails" + ] } ] }, @@ -683,6 +693,10 @@ ] }, "redirects": [ + { + "source": "/features/fallbacks", + "destination": "/features/retries-and-fallbacks" + }, { "source": "/quickstart/gateway/cli-agents", "destination": "/cli-agents/overview" diff --git a/docs/enterprise/vault-support.mdx b/docs/enterprise/vault-support.mdx deleted file mode 100644 index cd63a630ab..0000000000 --- a/docs/enterprise/vault-support.mdx +++ /dev/null @@ -1,133 +0,0 @@ ---- -title: "Vault Support" -description: "Secure API key management with 
HashiCorp Vault, AWS Secrets Manager, Google Secret Manager, and Azure Key Vault integration. Store and retrieve sensitive credentials using enterprise-grade secret management." -icon: "vault" ---- - -Bifrost's vault support enables seamless integration with enterprise-grade secret management systems, allowing you to connect to existing vaults and automatically sync virtual keys and provider API keys directly onto the Bifrost platform. - -## Overview - -The vault integration provides: - -- **Automated Key Synchronization**: Connect to your existing vault infrastructure and sync all API keys automatically -- **Periodic Key Management**: Regular synchronization ensures deprecated and archived keys are properly managed -- **Multi-Vault Support**: Compatible with HashiCorp Vault, AWS Secrets Manager, Google Secret Manager, and Azure Key Vault -- **Zero-Downtime Operations**: Keys are synced without interrupting your running services - -## Supported Vault Systems - - - - Centralized secret management for self-hosted deployments. - - - Cloud-native secret storage on AWS. - - - Secure key storage on Google Cloud Platform. - - - Key management for Microsoft Azure environments. - - - -## Key Synchronization - -### Automatic Sync Process - -Bifrost automatically synchronizes keys from your vault at regular intervals: - -1. **Discovery**: Scans the configured vault paths for API keys and virtual keys -2. **Validation**: Verifies key format and accessibility -3. **Sync**: Updates Bifrost's internal key store with new and modified keys -4. **Deprecation**: Identifies and archives keys that have been removed from the vault -5. 
**Notification**: Logs sync status and any issues encountered - -### Sync Configuration - -Configure synchronization behavior to match your operational requirements: - -```json -{ - "vault": { - "sync_interval": "300s", - "sync_paths": [ - "bifrost/provider-keys/*", - "bifrost/virtual-keys/*" - ], - "auto_deprecate": true, - "backup_deprecated_keys": true - } -} -``` - -#### Configuration Options - -| Option | Description | Default | -|--------|-------------|---------| -| `sync_interval` | Time between sync operations | `300s` | -| `sync_paths` | Vault paths to monitor for keys | `["bifrost/*"]` | -| `auto_deprecate` | Automatically deprecate removed keys | `true` | -| `backup_deprecated_keys` | Backup keys before deprecation | `true` | - -## Key Management Lifecycle - -### Key States - -Keys in Bifrost can have the following states: - -- **Active**: Currently in use and available for requests -- **Deprecated**: Marked for removal but still functional -- **Archived**: Removed from active use but retained for audit purposes -- **Expired**: Keys that have exceeded their validity period - -### Deprecation Process - -When keys are removed from the vault: - -1. **Detection**: Next sync cycle identifies missing keys -2. **Grace Period**: Keys enter deprecated state with configurable grace period -3. **Notification**: Administrators are notified of pending deprecation -4. 
**Archive**: Keys are moved to archived state after grace period expires - -```json -{ - "vault": { - "deprecation": { - "grace_period": "24h", - "notify_admins": true, - "retain_archived": "90d" - } - } -} -``` - -## Security Considerations - -### Authentication - -- **Vault Tokens**: Use time-limited tokens with minimal required permissions -- **IAM Roles**: Leverage cloud provider IAM roles for secure authentication -- **Certificate-based Auth**: Support for mutual TLS authentication where available - -### Encryption - -- **Transit Encryption**: All communication with vault systems uses TLS -- **At-Rest Encryption**: Keys are encrypted in Bifrost's internal storage -- **Key Rotation**: Automatic detection and handling of rotated vault credentials - -### Audit Trail - -Complete audit logging for all vault operations: - -```json -{ - "timestamp": "2024-01-15T10:30:00Z", - "operation": "key_sync", - "vault_type": "hashicorp", - "keys_synced": 15, - "keys_deprecated": 2, - "status": "success" -} -``` diff --git a/docs/features/fallbacks.mdx b/docs/features/fallbacks.mdx deleted file mode 100644 index 23a53c976e..0000000000 --- a/docs/features/fallbacks.mdx +++ /dev/null @@ -1,187 +0,0 @@ ---- -title: "Fallbacks" -description: "Automatic failover between AI providers and models. When your primary provider fails, Bifrost seamlessly switches to backup providers without interrupting your application." -icon: "list-check" ---- - -## Automatic Provider Failover - -Fallbacks provide automatic failover when your primary AI provider experiences issues. Whether it's rate limiting, outages, or model unavailability, Bifrost automatically tries backup providers in the order you specify until one succeeds. - -When a fallback is triggered, Bifrost treats it as a completely new request - all configured plugins (caching, governance, logging, etc.) run again for the fallback provider, ensuring consistent behavior across all providers. 
- -## How Fallbacks Work - -When you configure fallbacks, Bifrost follows this process: - -1. **Primary Attempt**: Tries your main provider/model first -2. **Automatic Detection**: If the primary fails (network error, rate limit, model unavailable), Bifrost detects the failure -3. **Sequential Fallbacks**: Tries each fallback provider in order until one succeeds -4. **Success Response**: Returns the response from the first successful provider -5. **Complete Failure**: If all providers fail, returns the original error from the primary provider - -Each fallback attempt is treated as a fresh request, so all your configured plugins (semantic caching, governance rules, monitoring) apply to whichever provider ultimately handles the request. - -## Implementation Examples - - - - -```bash -# Chat completion with multiple fallbacks -curl -X POST http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "openai/gpt-4o-mini", - "messages": [ - { - "role": "user", - "content": "Explain quantum computing in simple terms" - } - ], - "fallbacks": [ - "anthropic/claude-3-5-sonnet-20241022", - "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" - ], - "max_tokens": 1000, - "temperature": 0.7 - }' -``` - -**Response (from whichever provider succeeded):** -```json -{ - "id": "chatcmpl-123", - "object": "chat.completion", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Quantum computing is like having a super-powered calculator..." 
- }, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 12, - "completion_tokens": 150, - "total_tokens": 162 - }, - "extra_fields": { - "provider": "anthropic", - "latency": 1.2 - } -} -``` - - - - - -```go -package main - -import ( - "context" - "fmt" - "github.com/maximhq/bifrost" - "github.com/maximhq/bifrost/core/schemas" -) - -func chatWithFallbacks(client *bifrost.Bifrost) { - ctx := context.Background() - - // Chat request with multiple fallbacks - response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: []schemas.ChatMessage{ - { - Role: schemas.ChatMessageRoleUser, - Content: schemas.ChatMessageContent{ - ContentStr: bifrost.Ptr("Explain quantum computing in simple terms"), - }, - }, - }, - // Fallback chain: OpenAI → Anthropic → Bedrock - Fallbacks: []schemas.Fallback{ - { - Provider: schemas.Anthropic, - Model: "claude-3-5-sonnet-20241022", - }, - { - Provider: schemas.Bedrock, - Model: "anthropic.claude-3-sonnet-20240229-v1:0", - }, - }, - Params: &schemas.ChatParameters{ - MaxCompletionTokens: bifrost.Ptr(1000), - Temperature: bifrost.Ptr(0.7), - }, - }) - - if err != nil { - fmt.Printf("All providers failed: %v\n", err) - return - } - - // Success! 
Response came from whichever provider worked - fmt.Printf("Response from %s: %s\n", - response.ExtraFields.Provider, - *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr) -} -``` - - - - - -## Real-World Scenarios - -**Scenario 1: Rate Limiting** -- Primary: OpenAI hits rate limit → Fallback: Anthropic succeeds -- Your application continues without interruption - -**Scenario 2: Model Unavailability** -- Primary: Specific model unavailable → Fallback: Different provider with similar model -- Seamless transition to equivalent capability - -**Scenario 3: Provider Outage** -- Primary: Provider experiencing downtime → Fallback: Alternative provider -- Business continuity maintained - -**Scenario 4: Cost Optimization** -- Primary: Premium model for quality → Fallback: Cost-effective alternative if budget exceeded -- Governance rules can trigger fallbacks based on usage - -## Fallback Behavior Details - -**What Triggers Fallbacks:** -- Network connectivity issues -- Provider API errors (500, 502, 503, 504) -- Rate limiting (429 errors) -- Model unavailability -- Request timeouts -- Authentication failures - -**What Preserves Original Error:** -- Request validation errors (malformed requests) -- Plugin-enforced blocks (governance violations) -- Certain provider-specific errors marked as non-retryable - -**Plugin Execution:** -When a fallback is triggered, the fallback request is treated as completely new: -- Semantic cache checks run again (different provider might have cached responses) -- Governance rules apply to the new provider -- Logging captures the fallback attempt -- All configured plugins execute fresh for the fallback provider - -**Plugin Fallback Control:** -Plugins can control whether fallbacks should be triggered based on their specific logic. 
For example: -- A custom plugin might prevent fallbacks for certain types of errors -- Security plugins might disable fallbacks for compliance reasons - -When a plugin determines that fallbacks should not be attempted, it can prevent the fallback mechanism entirely, ensuring the original error is returned immediately. - -This ensures consistent behavior regardless of which provider ultimately handles your request, while giving plugins full control over the fallback decision process. And you can always know which provider handled your request via `extra_fields`. diff --git a/docs/features/observability/default.mdx b/docs/features/observability/default.mdx index 4572aec2eb..ad60c254c4 100644 --- a/docs/features/observability/default.mdx +++ b/docs/features/observability/default.mdx @@ -24,12 +24,45 @@ Bifrost traces comprehensive information for every request, without any changes - **Input Messages**: Complete conversation history and user prompts - **Model Parameters**: Temperature, max tokens, tools, and all other parameters - **Provider Context**: Which provider and model handled the request +- **Prompt Tracking**: When the [Prompts plugin](/features/prompt-repository/prompts-plugin) is active, the log captures the selected prompt name, version number, and ID for full traceability ### **Response Data** - **Output Messages**: AI responses, tool calls, and function results - **Performance Metrics**: Latency and token usage - **Status Information**: Success or error details +### **Retry & Key Selection** v1.5.0-prerelease4+ + +When Bifrost retries a request (rate-limit or network error) the following fields are recorded: + +| Field | Meaning | +|-------|---------| +| `selected_key_id` / `selected_key_name` | The API key that **successfully** served the request. `null` when all attempts failed — use `attempt_trail` to see which keys were tried. | +| `number_of_retries` | Total number of attempts minus one. 
**Does not indicate which key was used on each attempt.** |
+| `attempt_trail` | Ordered array of every attempt, with the key used and the failure reason. `fail_reason` is `null` on the final attempt only when that attempt succeeded. |
+
+**Example `attempt_trail`** — two rate-limit rotations, then success on a third key:
+
+```json
+"attempt_trail": [
+  { "attempt": 0, "key_id": "key-a", "key_name": "Key A", "fail_reason": "rate_limit_error" },
+  { "attempt": 1, "key_id": "key-b", "key_name": "Key B", "fail_reason": "rate_limit_error" },
+  { "attempt": 2, "key_id": "key-c", "key_name": "Key C", "fail_reason": null }
+]
+```
+
+Network-error retries reuse the same key; only rate-limit errors rotate to a different key:
+
+```json
+"attempt_trail": [
+  { "attempt": 0, "key_id": "key-a", "key_name": "Key A", "fail_reason": "network_error" },
+  { "attempt": 1, "key_id": "key-a", "key_name": "Key A", "fail_reason": "rate_limit_error" },
+  { "attempt": 2, "key_id": "key-b", "key_name": "Key B", "fail_reason": null }
+]
+```
+
+`attempt_trail` is `null` / absent when the request succeeded on the first try (no retries).
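The relationship between the summary fields and `attempt_trail` can be sketched as follows — a minimal illustration of the semantics described above, not Bifrost's actual implementation. Field names mirror the JSON examples, with an empty string standing in for JSON `null`:

```go
package main

import "fmt"

// Attempt mirrors one entry of the attempt_trail array shown above.
// An empty FailReason stands in for JSON null (a successful final attempt).
type Attempt struct {
	KeyID      string
	FailReason string
}

// summarize derives the summary log fields from an attempt trail:
// number_of_retries is the total number of attempts minus one, and
// selected_key_id is the key of the final attempt only if it succeeded.
func summarize(trail []Attempt) (retries int, selectedKeyID string) {
	if len(trail) == 0 {
		// attempt_trail is null/absent when the first try succeeded: zero retries.
		return 0, ""
	}
	retries = len(trail) - 1
	last := trail[len(trail)-1]
	if last.FailReason == "" {
		selectedKeyID = last.KeyID // populated only when the request succeeded
	}
	return retries, selectedKeyID
}

func main() {
	// The first example trail above: two rate-limit rotations, success on key-c.
	trail := []Attempt{
		{KeyID: "key-a", FailReason: "rate_limit_error"},
		{KeyID: "key-b", FailReason: "rate_limit_error"},
		{KeyID: "key-c", FailReason: ""},
	}
	r, k := summarize(trail)
	fmt.Println(r, k) // 2 key-c
}
```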
+ ### **Custom Metadata** - **Logging Headers**: Capture configured request headers (e.g., `X-Tenant-ID`) into log metadata - **Ad-hoc Headers**: Any `x-bf-lh-*` prefixed header is automatically captured into metadata diff --git a/docs/features/observability/prometheus.mdx b/docs/features/observability/prometheus.mdx index 6c6df6a24f..3512783e8d 100644 --- a/docs/features/observability/prometheus.mdx +++ b/docs/features/observability/prometheus.mdx @@ -178,6 +178,7 @@ The following metrics are available from both the `/metrics` endpoint and Push G | `bifrost_cache_hits_total` | Counter | Cache hits by type | | `bifrost_stream_first_token_latency_seconds` | Histogram | Time to first token (streaming) | | `bifrost_stream_inter_token_latency_seconds` | Histogram | Inter-token latency (streaming) | +| `bifrost_key_rotation_events_total` | Counter | Per-attempt retry/rotation events with key identifiers (see below) v1.5.0-prerelease4+ | ### Default Labels @@ -187,12 +188,42 @@ All Bifrost metrics include these labels: - `model` - Model identifier - `method` - Request type (chat, completion, embedding, etc.) - `virtual_key_id` / `virtual_key_name` - Virtual key identifiers -- `selected_key_id` / `selected_key_name` - Actual key used -- `number_of_retries` - Retry count +- `selected_key_id` / `selected_key_name` - API key that successfully served the request (`""` when all attempts failed) +- `number_of_retries` - Total attempts minus one (across all keys) - `fallback_index` - Fallback position - `team_id` / `team_name` - Team identifiers (if governance enabled) - `customer_id` / `customer_name` - Customer identifiers (if governance enabled) + + **v1.5.0-prerelease4+**: `selected_key_id` / `selected_key_name` are only populated when the request succeeds. On final errors both are empty — use `bifrost_key_rotation_events_total` or the `attempt_trail` log field to see which keys were tried. 
+ + +### Key Rotation Events v1.5.0-prerelease4+ + +`bifrost_key_rotation_events_total` is incremented once per **failed attempt** (not per request), giving you time-series visibility into retry pressure: + +| Label | Values | Description | +|-------|--------|-------------| +| `provider` | e.g. `openai` | LLM provider | +| `requested_model` | e.g. `gpt-4o` | Model as requested (before any alias resolution) | +| `key_id` | UUID | The provider API key that failed on this attempt | +| `key_name` | string | Human-readable name of the provider API key | +| `fail_reason` | error type string | Provider error type (e.g. `rate_limit_error`, `network_error`) | + +**Example queries:** + +```promql +# Rate-limit events per provider over time +sum by (provider, fail_reason) ( + rate(bifrost_key_rotation_events_total[5m]) +) + +# Which specific keys are hitting rate limits most often +topk(5, sum by (provider, key_name, fail_reason) ( + rate(bifrost_key_rotation_events_total{fail_reason="rate_limit_error"}[1h]) +)) +``` + --- ## Push Gateway Setup diff --git a/docs/features/prompt-repository/playground.mdx b/docs/features/prompt-repository/playground.mdx index 30c2b0df9a..879006c7d2 100644 --- a/docs/features/prompt-repository/playground.mdx +++ b/docs/features/prompt-repository/playground.mdx @@ -85,7 +85,12 @@ The playground uses a simple **three-panel layout**: |------|---------| | **Sidebar (left)** | Browse prompts, manage folders, and organize items | | **Playground (center)** | Build and test your prompt messages | -| **Settings (right)** | Configure provider, model, API key, variables, and parameters | +| **Settings (right)** | Configure provider, model, API key, variables, parameters, and deployments | + +The settings panel is organized into collapsible sections: + +- **Configuration** — Provider, model, API key, variables, and model parameters +- **Deployments** — Prompt deployment strategies and traffic routing (enterprise) ![Workspace 
Layout](../../media/prompt-repo-layout.png) diff --git a/docs/features/prompt-repository/prompts-plugin.mdx b/docs/features/prompt-repository/prompts-plugin.mdx index 50fd6def32..a06d817245 100644 --- a/docs/features/prompt-repository/prompts-plugin.mdx +++ b/docs/features/prompt-repository/prompts-plugin.mdx @@ -29,13 +29,13 @@ The **Prompts** plugin connects the [Prompt Repository](/features/prompt-reposit ```mermaid flowchart TB Client([Client]) --> Gateway[Bifrost HTTP] - Gateway --> PreHook["HTTP transport pre-hook:
copy bf-prompt-id / bf-prompt-version to context"] + Gateway --> PreHook["HTTP transport pre-hook:
copy x-bf-prompt-id / x-bf-prompt-version to context"] PreHook --> PreLLM["PreLLM hook:
resolve version, merge params,
prepend template messages"] PreLLM --> Provider[Provider] ``` -1. **Transport (HTTP):** Incoming headers `bf-prompt-id` and `bf-prompt-version` are copied onto the Bifrost context (header name matching is case-insensitive). -2. **Resolve:** The plugin looks up the prompt and the requested version. If **`bf-prompt-version` is omitted**, the prompt’s **latest committed version** is used. +1. **Transport (HTTP):** Incoming headers `x-bf-prompt-id` and `x-bf-prompt-version` are copied onto the Bifrost context (header name matching is case-insensitive). +2. **Resolve:** The plugin looks up the prompt and the requested version. If **`x-bf-prompt-version` is omitted**, the prompt’s **latest committed version** is used. 3. **Parameters:** Version `model` parameters are merged into the request; any field already set on the request wins. 4. **Messages:** Messages from the committed version are **prepended** to `messages` (chat) or `input` (responses). Your request body adds the user turn(s) after the template. @@ -47,8 +47,8 @@ If the prompt ID is missing, the plugin does nothing and the request passes thro | Header | Required | Description | |--------|----------|-------------| -| `bf-prompt-id` | Yes, to enable injection | UUID of the prompt in the repository. | -| `bf-prompt-version` | No | **Integer version number** (e.g. `3` for v3). If omitted, the **latest** committed version for that prompt is used. | +| `x-bf-prompt-id` | Yes, to enable injection | UUID of the prompt in the repository. | +| `x-bf-prompt-version` | No | **Integer version number** (e.g. `3` for v3). If omitted, the **latest** committed version for that prompt is used. | Invalid or unknown IDs / versions are logged as warnings; the request is **not** failed by the plugin (it proceeds without template injection). @@ -61,7 +61,7 @@ Use the same JSON body as a normal chat request. 
Only the headers select the tem ```bash curl -X POST http://localhost:8080/v1/chat/completions \ -H "Content-Type: application/json" \ - -H "bf-prompt-id: YOUR-PROMPT-UUID" \ + -H "x-bf-prompt-id: YOUR-PROMPT-UUID" \ -H "x-bf-vk: sk-bf-your-virtual-key" \ -d '{ "model": "openai/gpt-5.4", @@ -76,13 +76,13 @@ curl -X POST http://localhost:8080/v1/chat/completions \ ![Commit Version with Stream enabled in the playground](../../media/prompt-plugin-version-commit.png) -When you commit a version from the playground, **Stream** is saved in that version’s model parameters. The example `curl` above does not set `"stream": true` in the JSON body, but if the committed version was saved with streaming enabled (as in the screenshot), the merged parameters still include `stream: true`, so the request is handled as **streaming** even though the client did not send `stream` explicitly. +When you commit a version from the playground, the model parameters (temperature, max tokens, etc.) are saved with it. These parameters are merged into the outgoing request, with client-supplied values taking precedence. ![LLM log for the same request showing Type: Chat Stream](../../media/prompt-plugin-llm-log.png) -In **Logs**, that run shows **Type: Chat Stream** and the full conversation: the committed **system** template, your **user** message from the request body, and the assistant reply. +In **Logs**, that run shows the full conversation: the committed **system** template, your **user** message from the request body, and the assistant reply. The log also displays the **Selected Prompt** name and version number for easy traceability. -The provider receives the **stored** messages from the prompt version, checks if the request is streaming or non-streaming, applies the additional model parameters from the request and prepends the messages from the prompt version followed by your user message. 
+The provider receives the merged model parameters from both the prompt version and the client request, with the messages from the committed version prepended before the client’s messages. --- @@ -91,8 +91,8 @@ The provider receives the **stored** messages from the prompt version, checks if ```bash curl -X POST http://localhost:8080/v1/responses \ -H "Content-Type: application/json" \ - -H "bf-prompt-id: YOUR-PROMPT-UUID" \ - -H "bf-prompt-version: 4" \ + -H "x-bf-prompt-id: YOUR-PROMPT-UUID" \ + -H "x-bf-prompt-version: 4" \ -H "x-bf-vk: sk-bf-your-virtual-key" \ -d '{ "model": "openai/gpt-5-nano-2025-08-07", @@ -104,7 +104,7 @@ curl -X POST http://localhost:8080/v1/responses \ ## Streaming -If the committed version’s **model parameters** include `"stream": true`, the plugin may set streaming on the HTTP transport so behavior matches the saved version. Client-side `stream` flags still interact with the merged parameters as usual. +Streaming is controlled entirely by the client request. If you want streaming, set `"stream": true` in the request body. The plugin merges model parameters from the committed version (request values take precedence), but does **not** override the transport-level streaming mode. --- @@ -118,12 +118,26 @@ The plugin keeps an in-memory cache of prompts and versions (loaded with a small For embedded Bifrost (Go SDK), register the plugin with `prompts.Init` and a **config store** that implements the prompt tables API. The default resolver reads the same logical keys from `BifrostContext`: -- `prompts.PromptIDKey` (`bf-prompt-id`) -- `prompts.PromptVersionKey` (`bf-prompt-version`) +- `prompts.PromptIDKey` (`x-bf-prompt-id`) +- `prompts.PromptVersionKey` (`x-bf-prompt-version`) Set them on the context you pass to `ChatCompletion` / `Responses` if you are not going through the HTTP transport hooks. 
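For illustration, a minimal sketch of what that can look like in embedded mode. The context-setting call and request variable here are assumptions for the sketch, not the exact Bifrost API — verify the real signatures against the `prompts` package:

```go
// Sketch only — assumes BifrostContext exposes a setter for these logical keys.
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
ctx.SetValue(prompts.PromptIDKey, "YOUR-PROMPT-UUID") // equivalent of x-bf-prompt-id
ctx.SetValue(prompts.PromptVersionKey, "3")           // equivalent of x-bf-prompt-version; omit to use latest
resp, err := client.ChatCompletionRequest(ctx, req)   // req is your normal chat request
```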
-For advanced routing (for example, choosing a prompt from governance metadata), implement `prompts.PromptResolver` in `plugins/prompts/main.go` and use **`prompts.InitWithResolver`**. +For advanced routing (for example, choosing a prompt from governance metadata), implement `prompts.PromptResolver` and use **`prompts.InitWithResolver`**. The interface is: + +```go +type PromptResolver interface { + Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (promptID string, versionNumber int, err error) +} +``` + +Return an empty `promptID` to skip injection for a request. Return `versionNumber == 0` to use the prompt's **latest** committed version; any positive integer selects that specific version. + +After injection, the plugin sets the following context keys (read by the logging plugin to populate log fields): + +- `schemas.BifrostContextKeySelectedPromptID` — UUID of the applied prompt +- `schemas.BifrostContextKeySelectedPromptName` — Display name of the prompt +- `schemas.BifrostContextKeySelectedPromptVersion` — Version number as a string (e.g. `"3"`) --- diff --git a/docs/features/retries-and-fallbacks.mdx b/docs/features/retries-and-fallbacks.mdx new file mode 100644 index 0000000000..4d42afafde --- /dev/null +++ b/docs/features/retries-and-fallbacks.mdx @@ -0,0 +1,389 @@ +--- +title: "Retries & Fallbacks" +description: "Automatic retry with exponential backoff and provider failover. Retries handle transient errors within a provider; fallbacks switch to a different provider when all retries are exhausted." +icon: "list-check" +--- + +## Overview + +Bifrost provides two complementary layers of resilience: + +- **Retries** — When a provider returns a transient error (network issue, rate limit, 5xx), Bifrost automatically retries the same request against the same provider with exponential backoff. On rate-limit errors, it can also rotate to a different API key from your pool. 
+- **Fallbacks** — When the primary provider fails after exhausting all retries, Bifrost moves on to the next provider in your fallback chain. Each fallback provider gets its own full retry budget. + +Together, they let you build LLM-powered applications that stay up through rate limits, transient outages, and even full provider failures — with no changes required in your application code. + +--- + +## Retries + +### How retries work + +When a request fails with a retryable error, Bifrost: + +1. Waits using **exponential backoff with jitter** before the next attempt +2. Retries the request against the same provider +3. On **rate-limit errors** (`429`): rotates to a different API key from the pool (if multiple keys are configured) so fresh capacity is used +4. On **network/server errors** (`5xx`, DNS, connection refused): reuses the same key — these are transient server issues, not per-key capacity problems +5. Continues until the request succeeds or `max_retries` is exhausted + +### Backoff formula + +``` +backoff = min(retry_backoff_initial × 2^attempt, retry_backoff_max) × jitter(0.8–1.2) +``` + +With the defaults of `retry_backoff_initial = 500ms` and `retry_backoff_max = 5000ms`: + +| Attempt | Base backoff | With jitter (approx.) | +|---------|-------------|----------------------| +| 1st retry | 500 ms | 400–600 ms | +| 2nd retry | 1000 ms | 800 ms–1.2 s | +| 3rd retry | 2000 ms | 1.6–2.4 s | +| 4th retry | 4000 ms | 3.2–4.8 s | +| 5th+ retry | 5000 ms (capped) | 4–5 s | + +### What triggers a retry + +| Condition | Retried? | Key rotation? 
| +|-----------|----------|---------------| +| Network error (DNS, connection refused) | Yes | No — same key reused | +| `5xx` server errors (500, 502, 503, 504) | Yes | No — same key reused | +| Rate limit (`429` or rate-limit message pattern) | Yes | Yes — next key from pool | +| Request validation error | No | — | +| Plugin-enforced block | No | — | +| Cancelled request | No | — | + +### Configuring retries + +Retries are configured per-provider in `network_config`. The defaults are `max_retries: 0` (no retries), `retry_backoff_initial: 500` ms, and `retry_backoff_max: 5000` ms. + + + + +Navigate to **Providers**, select a provider, and open the **Network Config** section. + +Set: +- **Max Retries** — number of additional attempts after the first failure (e.g. `3`) +- **Retry Backoff Initial** — starting backoff in milliseconds (e.g. `500`) +- **Retry Backoff Max** — maximum backoff cap in milliseconds (e.g. `5000`) + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "name": "openai-key-1", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "max_retries": 3, + "retry_backoff_initial": 500, + "retry_backoff_max": 5000 + } +}' +``` + + + + +```go +func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch provider { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + MaxRetries: 3, + RetryBackoffInitial: 500 * time.Millisecond, + RetryBackoffMax: 5 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { "name": "openai-key-1", "value": "env.OPENAI_KEY_1", "models": ["*"], "weight": 1.0 }, + { "name": 
"openai-key-2", "value": "env.OPENAI_KEY_2", "models": ["*"], "weight": 1.0 }, + { "name": "openai-key-3", "value": "env.OPENAI_KEY_3", "models": ["*"], "weight": 1.0 } + ], + "network_config": { + "max_retries": 3, + "retry_backoff_initial": 500, + "retry_backoff_max": 5000 + } + } + } +} +``` + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `max_retries` | integer | `0` | Number of additional attempts after the first failure | +| `retry_backoff_initial` | integer (ms) | `500` | Starting backoff duration in milliseconds | +| `retry_backoff_max` | integer (ms) | `5000` | Maximum backoff cap in milliseconds | + + + + +### Key rotation on rate limits + + +Key rotation on retries requires **v1.5.0-prerelease4 or later**. + + +When you configure multiple API keys for a provider, Bifrost automatically rotates to a fresh key when a rate-limit error is encountered — so retries are not wasted repeating a request with a key that has already hit its limit. + +```json +{ + "providers": { + "openai": { + "keys": [ + { "name": "openai-key-1", "value": "env.OPENAI_KEY_1", "models": ["*"], "weight": 1.0 }, + { "name": "openai-key-2", "value": "env.OPENAI_KEY_2", "models": ["*"], "weight": 1.0 }, + { "name": "openai-key-3", "value": "env.OPENAI_KEY_3", "models": ["*"], "weight": 1.0 } + ], + "network_config": { + "max_retries": 5 + } + } + } +} +``` + +With 3 keys and `max_retries: 5`, Bifrost cycles through all three keys twice before giving up. Once all keys in the pool have been tried, it resets and starts a fresh weighted round. + + +Key rotation on rate limits only applies when `max_retries > 0` and more than one key is configured for the provider. With a single key, all retries reuse that key. + + +--- + +## Fallbacks + +Fallbacks provide automatic failover to a different provider when the primary fails after exhausting all its retries. Each fallback is tried in order until one succeeds. + +### How fallbacks work + +1. 
**Primary attempt**: Tries your configured provider with its full retry budget +2. **Fallback decision**: If the primary fails (and the error is retryable at the provider level), Bifrost moves to the first fallback +3. **Sequential fallbacks**: Each fallback provider also gets its own full retry budget +4. **First success wins**: Returns the response from the first provider that succeeds +5. **All fail**: Returns the original error from the primary provider. Exception: if a plugin on a fallback provider sets `AllowFallbacks = false` on the error (e.g. a security or compliance plugin that should halt the chain regardless of remaining fallbacks), Bifrost stops immediately and returns that fallback's error rather than continuing to the next provider or returning the primary error. + +Each fallback is treated as a completely fresh request — all configured plugins (semantic caching, governance, logging) run again for the fallback provider. + +### Implementation + + + + +Pass a `fallbacks` array in the request body. Each entry specifies a `provider/model` string: + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [ + { + "role": "user", + "content": "Explain quantum computing in simple terms" + } + ], + "fallbacks": [ + "anthropic/claude-3-5-sonnet-20241022", + "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" + ], + "max_tokens": 1000, + "temperature": 0.7 + }' +``` + +The response `extra_fields.provider` tells you which provider actually served the request: + +```json +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Quantum computing is like having a super-powered calculator..." 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 12, + "completion_tokens": 150, + "total_tokens": 162 + }, + "extra_fields": { + "provider": "anthropic", + "latency": 1.2 + } +} +``` + + + + +```go +package main + +import ( + "context" + "fmt" + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/core/schemas" +) + +func chatWithFallbacks(client *bifrost.Bifrost) { + ctx := context.Background() + + response, err := client.ChatCompletionRequest( + schemas.NewBifrostContext(ctx, schemas.NoDeadline), + &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Explain quantum computing in simple terms"), + }, + }, + }, + // Fallback chain: OpenAI → Anthropic → Bedrock + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-5-sonnet-20241022"}, + {Provider: schemas.Bedrock, Model: "anthropic.claude-3-sonnet-20240229-v1:0"}, + }, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(1000), + Temperature: bifrost.Ptr(0.7), + }, + }, + ) + + if err != nil { + fmt.Printf("All providers failed: %v\n", err) + return + } + + fmt.Printf("Response from %s: %s\n", + response.ExtraFields.Provider, + *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr) +} +``` + + + + +--- + +## How retries and fallbacks work together + +The two mechanisms form a nested resilience loop. Retries run inside each provider attempt; fallbacks run across providers once retries are exhausted. 
+ +```mermaid +sequenceDiagram + participant App + participant Bifrost + participant Primary as Primary Provider + participant FB1 as Fallback 1 + participant FB2 as Fallback 2 + + App->>Bifrost: Request (primary + fallbacks) + + rect rgb(220, 235, 250) + note over Bifrost,Primary: Primary provider attempt (with retries) + Bifrost->>Primary: Attempt 1 + Primary-->>Bifrost: 429 Rate Limit + note over Bifrost: Backoff + rotate key + Bifrost->>Primary: Attempt 2 (different key) + Primary-->>Bifrost: 503 Unavailable + note over Bifrost: Backoff + Bifrost->>Primary: Attempt 3 + Primary-->>Bifrost: 503 Unavailable + note over Bifrost: max_retries exhausted + end + + rect rgb(235, 250, 220) + note over Bifrost,FB1: Fallback 1 attempt (with its own retries) + Bifrost->>FB1: Attempt 1 + FB1-->>Bifrost: 500 Server Error + note over Bifrost: Backoff + Bifrost->>FB1: Attempt 2 + FB1-->>Bifrost: ✓ Success + end + + Bifrost-->>App: Response (from Fallback 1) +``` + +**Key point:** each provider in the chain — primary and every fallback — gets its own full `max_retries` budget. A primary configured with `max_retries: 3` and two fallbacks each also configured with `max_retries: 3` means up to 12 total attempts before giving up. + + +The retry budget is set per-provider in `network_config`. If your fallback providers have different retry configurations, each will use their own settings. + + +--- + +## Real-world scenarios + +**Scenario 1: Rate limiting with key rotation** + +OpenAI key 1 hits its rate limit. Bifrost rotates to key 2 on the next retry — no fallback needed, the request succeeds within the same provider. + +**Scenario 2: Provider outage** + +OpenAI is experiencing downtime (returning `503`). Bifrost retries with the same key (transient server issue), exhausts `max_retries`, then fails over to Anthropic. Anthropic succeeds on the first attempt. + +**Scenario 3: Cascading failure** + +Both primary and first fallback are down. 
Bifrost works through each provider's retry budget sequentially until the second fallback succeeds. + +**Scenario 4: Cost-sensitive fallback** + +Primary: a premium model for quality. Fallback: a cost-effective alternative. Governance rules can trigger a budget-exceeded error on the primary, which cascades into the fallback chain. + +--- + +## Plugin execution + +When a fallback is triggered, the fallback request is treated as completely new: + +- Semantic cache checks run again (the fallback provider may have a cached response) +- Governance rules apply to the new provider +- Logging captures the fallback attempt separately +- All configured plugins execute fresh for each provider in the chain + +**Plugin fallback control:** Plugins can prevent fallbacks from being triggered for specific error types. For example, a security plugin might disable fallbacks for compliance reasons. When a plugin sets `AllowFallbacks = false` on the error, the fallback chain is skipped entirely and the original error is returned immediately. 
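The halt-the-chain semantics can be illustrated with a short Go sketch. The `providerError` type and function signatures below are invented for illustration — only the `AllowFallbacks` flag and its behavior (stop immediately, return that error) come from the text:

```go
package main

import "fmt"

// Illustrative error shape: every field name except AllowFallbacks is an
// assumption, not Bifrost's actual schema.
type providerError struct {
	Message        string
	AllowFallbacks *bool // nil or true → the fallback chain may proceed
}

// runChain walks the primary and its fallbacks in order, honoring
// AllowFallbacks = false by halting the chain immediately.
func runChain(attempts []func() (*string, *providerError)) (string, *providerError) {
	var firstErr *providerError
	for _, attempt := range attempts {
		resp, err := attempt()
		if err == nil {
			return *resp, nil // first success wins
		}
		if firstErr == nil {
			firstErr = err // remember the primary's error
		}
		// A plugin that sets AllowFallbacks = false stops the chain:
		// its error is returned instead of continuing to the next provider.
		if err.AllowFallbacks != nil && !*err.AllowFallbacks {
			return "", err
		}
	}
	return "", firstErr // all failed: return the primary's error
}

func main() {
	no := false
	blocked := func() (*string, *providerError) {
		return nil, &providerError{Message: "compliance block", AllowFallbacks: &no}
	}
	healthy := func() (*string, *providerError) {
		s := "hello"
		return &s, nil
	}

	// The healthy fallback never runs: the compliance error halts the chain.
	_, err := runChain([]func() (*string, *providerError){blocked, healthy})
	fmt.Println(err.Message) // compliance block
}
```

Without the flag set, `runChain` would continue past the failed primary and return the fallback's success instead.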
+ +--- + +## Next steps + +- **[Keys Management](./keys-management)** — Configure multiple API keys per provider to enable key rotation on retries +- **[Governance](./governance/virtual-keys)** — Use virtual keys and routing rules to control which providers are used +- **[Observability](./observability/default)** — Track retry counts and fallback usage in your logs diff --git a/docs/features/telemetry.mdx b/docs/features/telemetry.mdx index 29741cd309..f816e30b4f 100644 --- a/docs/features/telemetry.mdx +++ b/docs/features/telemetry.mdx @@ -65,8 +65,8 @@ Base Labels: - `routing_engines_used`: Comma-separated routing engines used ("routing-rule", "governance", "loadbalancing") - `routing_rule_id`: Routing rule ID that matched the request - `routing_rule_name`: Routing rule name that matched the request -- `selected_key_id`: Selected key ID -- `selected_key_name`: Selected key name +- `selected_key_id`: ID of the key that successfully served the request (`null` on final errors) +- `selected_key_name`: Name of the key that successfully served the request (`null` on final errors) - `number_of_retries`: Number of retries - `fallback_index`: Fallback index (0 for first attempt, 1 for second attempt, etc.) - custom labels: Custom labels configured in the Bifrost configuration diff --git a/docs/integrations/vaults/aws-secrets-manager.mdx b/docs/integrations/vaults/aws-secrets-manager.mdx deleted file mode 100644 index 1913b486c2..0000000000 --- a/docs/integrations/vaults/aws-secrets-manager.mdx +++ /dev/null @@ -1,39 +0,0 @@ ---- -title: "AWS Secrets Manager" -description: "AWS Secrets Manager integration for secret management in Bifrost. Cloud-native API key storage and automatic synchronization with AWS Secrets Manager." 
-icon: "aws" ---- - -Bifrost integrates with [AWS Secrets Manager](https://aws.amazon.com/secrets-manager/) for cloud-native secret storage, allowing you to store provider API keys and virtual keys in your AWS environment and automatically sync them into Bifrost. - -## Configuration - -Add a `vault` block to your Bifrost configuration to connect to AWS Secrets Manager: - -```json -{ - "vault": { - "type": "aws_secrets_manager", - "region": "us-east-1", - "access_key_id": "${AWS_ACCESS_KEY_ID}", - "secret_access_key": "${AWS_SECRET_ACCESS_KEY}", - "sync_interval": "300s" - } -} -``` - -## Configuration Fields - -| Field | Type | Description | -|-------|------|-------------| -| `type` | `string` | Must be set to `"aws_secrets_manager"` to use AWS Secrets Manager. | -| `region` | `string` | The AWS region where your secrets are stored (e.g., `"us-east-1"`). | -| `access_key_id` | `string` | AWS access key ID for authentication. Supports environment variable interpolation via `${AWS_ACCESS_KEY_ID}`. | -| `secret_access_key` | `string` | AWS secret access key for authentication. Supports environment variable interpolation via `${AWS_SECRET_ACCESS_KEY}`. | -| `sync_interval` | `string` | How often Bifrost syncs keys from AWS Secrets Manager. Accepts duration strings such as `"300s"`, `"5m"`, or `"1h"`. | - - - The `sync_interval` field controls how frequently Bifrost polls your vault for key changes. Lower intervals detect changes faster but increase load on your vault server. See the [Vault Support](/enterprise/vault-support) page for full sync configuration options including `sync_paths` and `auto_deprecate`. - - -For key synchronization, deprecation management, and security configuration, see [Vault Support](/enterprise/vault-support). 
diff --git a/docs/integrations/vaults/azure-key-vault.mdx b/docs/integrations/vaults/azure-key-vault.mdx deleted file mode 100644 index 50fe336737..0000000000 --- a/docs/integrations/vaults/azure-key-vault.mdx +++ /dev/null @@ -1,41 +0,0 @@ ---- -title: "Azure Key Vault" -description: "Azure Key Vault integration for secret management in Bifrost. Secure API key storage and automatic synchronization with Microsoft Azure Key Vault." -icon: "microsoft" ---- - -Bifrost integrates with [Azure Key Vault](https://azure.microsoft.com/en-us/products/key-vault) for secret management in Microsoft cloud environments, allowing you to store provider API keys and virtual keys in your Azure infrastructure and automatically sync them into Bifrost. - -## Configuration - -Add a `vault` block to your Bifrost configuration to connect to Azure Key Vault: - -```json -{ - "vault": { - "type": "azure_key_vault", - "vault_url": "https://your-keyvault.vault.azure.net/", - "client_id": "${AZURE_CLIENT_ID}", - "client_secret": "${AZURE_CLIENT_SECRET}", - "tenant_id": "${AZURE_TENANT_ID}", - "sync_interval": "300s" - } -} -``` - -## Configuration Fields - -| Field | Type | Description | -|-------|------|-------------| -| `type` | `string` | Must be set to `"azure_key_vault"` to use Azure Key Vault. | -| `vault_url` | `string` | The full URL of your Azure Key Vault instance (e.g., `"https://your-keyvault.vault.azure.net/"`). | -| `client_id` | `string` | Azure AD application (client) ID for authentication. Supports environment variable interpolation via `${AZURE_CLIENT_ID}`. | -| `client_secret` | `string` | Azure AD client secret for authentication. Supports environment variable interpolation via `${AZURE_CLIENT_SECRET}`. | -| `tenant_id` | `string` | Azure AD tenant ID for authentication. Supports environment variable interpolation via `${AZURE_TENANT_ID}`. | -| `sync_interval` | `string` | How often Bifrost syncs keys from Azure Key Vault. 
Accepts duration strings such as `"300s"`, `"5m"`, or `"1h"`. | - - - The `sync_interval` field controls how frequently Bifrost polls your vault for key changes. Lower intervals detect changes faster but increase load on your vault server. See the [Vault Support](/enterprise/vault-support) page for full sync configuration options including `sync_paths` and `auto_deprecate`. - - -For key synchronization, deprecation management, and security configuration, see [Vault Support](/enterprise/vault-support). diff --git a/docs/integrations/vaults/google-secret-manager.mdx b/docs/integrations/vaults/google-secret-manager.mdx deleted file mode 100644 index f630509512..0000000000 --- a/docs/integrations/vaults/google-secret-manager.mdx +++ /dev/null @@ -1,37 +0,0 @@ ---- -title: "Google Secret Manager" -description: "Google Secret Manager integration for secret management in Bifrost. Secure API key storage and automatic synchronization with Google Cloud's Secret Manager." -icon: "google" ---- - -Bifrost integrates with [Google Cloud Secret Manager](https://cloud.google.com/secret-manager) for secure key storage, allowing you to store provider API keys and virtual keys in your Google Cloud environment and automatically sync them into Bifrost. - -## Configuration - -Add a `vault` block to your Bifrost configuration to connect to Google Secret Manager: - -```json -{ - "vault": { - "type": "google_secret_manager", - "project_id": "your-project-id", - "credentials_file": "/path/to/service-account.json", - "sync_interval": "300s" - } -} -``` - -## Configuration Fields - -| Field | Type | Description | -|-------|------|-------------| -| `type` | `string` | Must be set to `"google_secret_manager"` to use Google Secret Manager. | -| `project_id` | `string` | The Google Cloud project ID where your secrets are stored. | -| `credentials_file` | `string` | Path to the Google Cloud service account JSON credentials file used for authentication. 
| -| `sync_interval` | `string` | How often Bifrost syncs keys from Google Secret Manager. Accepts duration strings such as `"300s"`, `"5m"`, or `"1h"`. | - - - The `sync_interval` field controls how frequently Bifrost polls your vault for key changes. Lower intervals detect changes faster but increase load on your vault server. See the [Vault Support](/enterprise/vault-support) page for full sync configuration options including `sync_paths` and `auto_deprecate`. - - -For key synchronization, deprecation management, and security configuration, see [Vault Support](/enterprise/vault-support). diff --git a/docs/integrations/vaults/hashicorp-vault.mdx b/docs/integrations/vaults/hashicorp-vault.mdx deleted file mode 100644 index e91e2e80eb..0000000000 --- a/docs/integrations/vaults/hashicorp-vault.mdx +++ /dev/null @@ -1,39 +0,0 @@ ---- -title: "HashiCorp Vault" -description: "HashiCorp Vault integration for secret management in Bifrost. Centralized API key storage and automatic synchronization with your HashiCorp Vault instance." -icon: "vault" ---- - -Bifrost integrates with [HashiCorp Vault](https://www.vaultproject.io/) for centralized secret management, allowing you to store provider API keys and virtual keys in your existing Vault infrastructure and automatically sync them into Bifrost. - -## Configuration - -Add a `vault` block to your Bifrost configuration to connect to your HashiCorp Vault instance: - -```json -{ - "vault": { - "type": "hashicorp", - "address": "https://vault.company.com:8200", - "token": "${VAULT_TOKEN}", - "mount": "secret", - "sync_interval": "300s" - } -} -``` - -## Configuration Fields - -| Field | Type | Description | -|-------|------|-------------| -| `type` | `string` | Must be set to `"hashicorp"` to use HashiCorp Vault. | -| `address` | `string` | The full URL of your HashiCorp Vault server, including the port. | -| `token` | `string` | Authentication token for Vault access. 
Supports environment variable interpolation via `${VAULT_TOKEN}`. | -| `mount` | `string` | The secrets engine mount path in Vault (e.g., `"secret"` for the default KV secrets engine). | -| `sync_interval` | `string` | How often Bifrost syncs keys from Vault. Accepts duration strings such as `"300s"`, `"5m"`, or `"1h"`. | - - - The `sync_interval` field controls how frequently Bifrost polls your vault for key changes. Lower intervals detect changes faster but increase load on your vault server. See the [Vault Support](/enterprise/vault-support) page for full sync configuration options including `sync_paths` and `auto_deprecate`. - - -For key synchronization, deprecation management, and security configuration, see [Vault Support](/enterprise/vault-support). diff --git a/docs/media/user-provisioning/zitadel-add-role.png b/docs/media/user-provisioning/zitadel-add-role.png deleted file mode 100644 index f00212f6d2..0000000000 Binary files a/docs/media/user-provisioning/zitadel-add-role.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-add-user-select-key.png b/docs/media/user-provisioning/zitadel-add-user-select-key.png deleted file mode 100644 index ba8d8e52a9..0000000000 Binary files a/docs/media/user-provisioning/zitadel-add-user-select-key.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-auth-method.png b/docs/media/user-provisioning/zitadel-create-app-auth-method.png deleted file mode 100644 index c27a6e5772..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-auth-method.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-namne.png b/docs/media/user-provisioning/zitadel-create-app-namne.png deleted file mode 100644 index 7e220ce193..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-namne.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-uri.png 
b/docs/media/user-provisioning/zitadel-create-app-uri.png deleted file mode 100644 index 8796e77ec5..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-uri.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-role-assignemnt.png b/docs/media/user-provisioning/zitadel-role-assignemnt.png deleted file mode 100644 index 5f233eb436..0000000000 Binary files a/docs/media/user-provisioning/zitadel-role-assignemnt.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-select-project.png b/docs/media/user-provisioning/zitadel-select-project.png deleted file mode 100644 index 824c48dd83..0000000000 Binary files a/docs/media/user-provisioning/zitadel-select-project.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-token-config.png b/docs/media/user-provisioning/zitadel-token-config.png deleted file mode 100644 index 1354a62830..0000000000 Binary files a/docs/media/user-provisioning/zitadel-token-config.png and /dev/null differ diff --git a/docs/migration-guides/v1.5.0.mdx b/docs/migration-guides/v1.5.0.mdx index 7d1933aaac..5d31c01b97 100644 --- a/docs/migration-guides/v1.5.0.mdx +++ b/docs/migration-guides/v1.5.0.mdx @@ -462,6 +462,61 @@ result.ResolvedModel // actual model identifier used by the provider --- +## Breaking Change 12: `selected_key_id` Cleared on Terminal Retry Failures + +With the introduction of multi-key retry rotation, `selected_key_id` (and `selected_key_name`) in the request context are **cleared when all retry attempts fail**. Previously, these fields always reflected the key that was selected for the request, even on error. + +The `attempt_trail` is now the authoritative record of every key tried and why each attempt failed. 
+ +### What changed + +| Field | Before | After | +|---|---|---| +| `selected_key_id` | Always set, even on error | Empty string when all retries exhausted | +| `selected_key_name` | Always set, even on error | Empty string when all retries exhausted | +| `attempt_trail` | Not present | Array of `{ key_id, key_name, fail_reason }` per attempt | + +### Impact on logging and telemetry plugins + +The built-in **logging plugin** writes `selected_key_id` and `selected_key_name` directly to each log record. For multi-key requests that exhaust all retries, both fields will be empty in the stored log entry. The `attempt_trail` column captures the full per-attempt key history and is the correct field to use for failure attribution. + +The built-in **telemetry plugin** emits `selected_key_id` and `selected_key_name` as span attributes. For exhausted-retry failures these attributes will be empty strings on the error span. The `attempt_trail` span attribute contains the full rotation history. + +If you run a custom plugin or downstream log consumer that filters or groups by `selected_key_id` to track which key caused a failure, you must update it to handle the empty-string case and read from `attempt_trail` when attribution is needed. 
+ +### How to update + +**If you read `selected_key_id` from plugin context to attribute failed requests:** + +**Before:** +```go +keyID, _ := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string) +// keyID was always populated, even on error +``` + +**After:** +```go +// Populated on success (or for single-key / pinned / sticky flows on error): +keyID, _ := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string) + +// For full attribution across all retry attempts (including failures): +if trail, ok := ctx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok { + for _, record := range trail { + // record.KeyID, record.KeyName, record.FailReason + } +} +``` + +**If you consume `selected_key_id` from the logging REST API:** + +The `selected_key_id` field on a `LogEntry` may now be an empty string when the request failed after exhausting all retries. Use `attempt_trail` for the full per-attempt key history. + + +Single-key, pinned (`x-bf-key-id` / `x-bf-key-name`), and session-sticky requests are unaffected — they never rotate keys, so `selected_key_id` remains populated on failure for those flows. + + +--- + ## Opting Out: `version: 1` Compatibility Mode If you are not ready to adopt the new deny-by-default semantics, you can add a single field to `config.json` to restore v1.4.x behavior for all allow-list fields loaded from that file: @@ -548,6 +603,10 @@ Replace `ExtraFields.ModelRequested` with `ExtraFields.OriginalModelRequested` ( Replace `.Model` with `.RequestedModel` (and optionally `.ResolvedModel`) on any `StreamAccumulatorResult` usage. + + +If your code reads `selected_key_id` / `selected_key_name` from the request context or log entries to attribute failed requests, add a null/empty check and fall back to `attempt_trail` for the full per-attempt key history. 
+ --- diff --git a/docs/openapi/openapi.json b/docs/openapi/openapi.json index 2a8e0d7a89..59f420fd09 100644 --- a/docs/openapi/openapi.json +++ b/docs/openapi/openapi.json @@ -28815,6 +28815,10 @@ "type": "integer", "description": "Number of members in the team" }, + "virtual_key_count": { + "type": "integer", + "description": "Number of virtual keys assigned to the team" + }, "created_at": { "type": "string", "format": "date-time" @@ -28977,6 +28981,10 @@ "type": "integer", "description": "Number of members in the team" }, + "virtual_key_count": { + "type": "integer", + "description": "Number of virtual keys assigned to the team" + }, "created_at": { "type": "string", "format": "date-time" @@ -41604,6 +41612,12 @@ "in": "header", "name": "x-api-key", "description": "API key authentication via the `x-api-key` header.\nVirtual keys (prefixed with `sk-bf-`) can also be passed here.\n" + }, + "GoogleApiKeyAuth": { + "type": "apiKey", + "in": "header", + "name": "x-goog-api-key", + "description": "Google API key authentication via the `x-goog-api-key` header.\nVirtual keys (prefixed with `sk-bf-`) can also be passed here.\n" } }, "parameters": { @@ -53884,6 +53898,8 @@ }, "virtual_keys": { "type": "array", + "nullable": true, + "description": "Virtual keys assigned to this team. This field may be omitted or returned as null in some responses (for example, when a team is embedded inside a virtual-key response) to avoid nested `virtual_keys` recursion.\n", "items": { "$ref": "#/components/schemas/VirtualKey" } diff --git a/docs/openapi/schemas/management/governance.yaml b/docs/openapi/schemas/management/governance.yaml index bd04aecee7..f057682620 100644 --- a/docs/openapi/schemas/management/governance.yaml +++ b/docs/openapi/schemas/management/governance.yaml @@ -376,6 +376,11 @@ Team: $ref: '#/Budget' virtual_keys: type: array + nullable: true + description: > + Virtual keys assigned to this team. 
This field may be omitted or returned as null in some + responses (for example, when a team is embedded inside a virtual-key response) to avoid + nested `virtual_keys` recursion. items: $ref: '#/VirtualKey' profile: diff --git a/docs/openapi/schemas/management/users.yaml b/docs/openapi/schemas/management/users.yaml index 46db148f3e..aa9cb504d1 100644 --- a/docs/openapi/schemas/management/users.yaml +++ b/docs/openapi/schemas/management/users.yaml @@ -155,6 +155,9 @@ TeamObject: member_count: type: integer description: Number of members in the team + virtual_key_count: + type: integer + description: Number of virtual keys assigned to the team created_at: type: string format: date-time diff --git a/docs/providers/supported-providers/azure.mdx b/docs/providers/supported-providers/azure.mdx index 46f41a8840..56a576191c 100644 --- a/docs/providers/supported-providers/azure.mdx +++ b/docs/providers/supported-providers/azure.mdx @@ -37,6 +37,544 @@ Azure is a cloud provider offering access to OpenAI and Anthropic models through **Azure-specific**: Batch operations and Text Completions are not supported by Azure OpenAI Service. Responses API uses preview API version and is available for both OpenAI and Anthropic models. +--- + +## Setup & Configuration + +Azure requires an endpoint URL, deployment mappings, and authentication configuration. Four authentication methods are supported. + + +The `aliases` field (mapping model names to Azure deployment IDs) requires **v1.5.0-prerelease2 or later**. On v1.4.x, use `deployments` inside `azure_key_config` instead — see the [v1.5.0 Migration Guide](/migration-guides/v1.5.0#breaking-change-9-provider-deployments-removed-migrate-to-aliases) for details. + + +### 1. Managed Identity + +Leave `value` and all Entra ID fields empty. Bifrost calls `azidentity.NewDefaultAzureCredential(nil)`, which auto-detects the system-assigned or user-assigned managed identity on Azure VMs, App Service, AKS, and Azure Container Instances. 
No credentials need to be stored or rotated. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Default Credential"** +4. Set **Endpoint**: Your Azure OpenAI resource URL (e.g., `https://your-org.openai.azure.com`) +5. Set **API Version** (Optional): e.g., `2024-10-21` +6. Configure **Aliases**: Map model names to deployment IDs (e.g., `gpt-4o` → `my-gpt4o-deployment`) +7. Save + +Ensure a managed identity with Cognitive Services permissions is attached to the Azure resource running Bifrost. + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "azure"}' + +# Step 2: Create a key (Managed Identity — leave value empty) +curl -X POST http://localhost:8080/api/providers/azure/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "azure-managed-identity", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. 
+- Replace the top-level `aliases` with `"deployments"` inside `azure_key_config`: +```json +"azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21", + "deployments": { + "gpt-4o": "my-gpt4o-deployment" + } +} +``` + + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-managed-identity", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `azure_key_config` instead of the top-level `aliases` field. + + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, // Leave empty — Bifrost uses the managed identity + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + }, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar(os.Getenv("AZURE_ENDPOINT")), + APIVersion: schemas.NewEnvVar("2024-10-21"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 2. Default Credential Chain (DefaultAzureCredential) + +Leave `value` and all Entra ID fields empty. Bifrost calls `azidentity.NewDefaultAzureCredential(nil)` — the same function as the Managed Identity section — which tries credential sources in this order: + +1. Environment variables (`AZURE_CLIENT_ID` + `AZURE_CLIENT_SECRET` + `AZURE_TENANT_ID`, or certificate/username variants) +2. Workload Identity (AKS with Workload Identity Federation) +3. Managed Identity (Azure VMs, App Service, AKS, Container Instances) +4. Azure CLI (`az login`) +5. 
Azure Developer CLI (`azd auth login`) + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Default Credential"** +4. Set **Endpoint**: Your Azure OpenAI resource URL (e.g., `https://your-org.openai.azure.com`) +5. Set **API Version** (Optional): e.g., `2024-10-21` +6. Configure **Aliases**: Map model names to deployment IDs +7. Save + +Ensure the appropriate credentials are available in the environment where Bifrost runs — set `AZURE_CLIENT_ID`, `AZURE_CLIENT_SECRET`, `AZURE_TENANT_ID` env vars, or run `az login`. + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "azure"}' + +# Step 2: Create a key (Default Credential Chain — leave value empty) +curl -X POST http://localhost:8080/api/providers/azure/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "azure-default-chain", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Replace the top-level `aliases` with `"deployments"` inside `azure_key_config`. + + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-default-chain", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `azure_key_config` instead of the top-level `aliases` field. 
+ + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, // Leave empty — Bifrost uses DefaultAzureCredential + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "my-gpt4o-deployment", + }, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar(os.Getenv("AZURE_ENDPOINT")), + APIVersion: schemas.NewEnvVar("2024-10-21"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 3. Azure Entra ID (Service Principal) + +Set `client_id`, `client_secret`, and `tenant_id` to authenticate with a Service Principal. This takes priority over API key and managed identity. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Entra ID (Service Principal)"** +4. Set **Client ID**: Your Azure Entra ID client ID +5. Set **Client Secret**: Your Azure Entra ID client secret +6. Set **Tenant ID**: Your Azure Entra ID tenant ID +7. Set **Endpoint**: Your Azure OpenAI resource URL +8. Set **API Version** (Optional): e.g., `2024-08-01-preview` +9. Configure **Aliases**: Map model names to deployment IDs +10. 
Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "azure"}' + +# Step 2: Create a key (Service Principal) +curl -X POST http://localhost:8080/api/providers/azure/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "azure-entra-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + "claude-3-5-sonnet": "my-claude-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "client_id": "env.AZURE_CLIENT_ID", + "client_secret": "env.AZURE_CLIENT_SECRET", + "tenant_id": "env.AZURE_TENANT_ID", + "scopes": ["https://cognitiveservices.azure.com/.default"], + "api_version": "2024-08-01-preview" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Move the model mappings from `aliases` into `azure_key_config.deployments`. + + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-entra-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + "claude-3-5-sonnet": "my-claude-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "client_id": "env.AZURE_CLIENT_ID", + "client_secret": "env.AZURE_CLIENT_SECRET", + "tenant_id": "env.AZURE_TENANT_ID", + "scopes": ["https://cognitiveservices.azure.com/.default"], + "api_version": "2024-08-01-preview" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `azure_key_config` instead of the top-level `aliases` field. 
+ + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, // Leave empty for Service Principal auth + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + "claude-3-5-sonnet": "my-claude-deployment", + }, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar(os.Getenv("AZURE_ENDPOINT")), + ClientID: schemas.NewEnvVar(os.Getenv("AZURE_CLIENT_ID")), + ClientSecret: schemas.NewEnvVar(os.Getenv("AZURE_CLIENT_SECRET")), + TenantID: schemas.NewEnvVar(os.Getenv("AZURE_TENANT_ID")), + Scopes: []string{"https://cognitiveservices.azure.com/.default"}, + APIVersion: schemas.NewEnvVar("2024-08-01-preview"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +**Required Azure roles:** +- OpenAI models: `Cognitive Services OpenAI User` +- Anthropic models: `Cognitive Services AI Services User` + +### 4. Direct Authentication (API Key) + +Provide the Azure API key in the `value` field. Use this for simple setups without managed identity or Service Principal. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"API Key"** +4. Set **API Key**: Your Azure API key +5. Set **Endpoint**: Your Azure OpenAI resource URL +6. Set **API Version** (Optional): e.g., `2024-10-21` +7. Configure **Aliases**: Map model names to deployment IDs +8. 
Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "azure"}' + +# Step 2: Create a key (API Key auth) +curl -X POST http://localhost:8080/api/providers/azure/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "azure-api-key", + "value": "env.AZURE_API_KEY", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Move the model mappings from `aliases` into `azure_key_config.deployments`. + + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-api-key", + "value": "env.AZURE_API_KEY", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `azure_key_config` instead of the top-level `aliases` field. 
+
+
+
+
+
+
+```go
+func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
+    switch provider {
+    case schemas.Azure:
+        return []schemas.Key{
+            {
+                Value:  *schemas.NewEnvVar("env.AZURE_API_KEY"),
+                Models: []string{"*"},
+                Weight: 1.0,
+                Aliases: schemas.KeyAliases{
+                    "gpt-4o":      "my-gpt4o-deployment",
+                    "gpt-4o-mini": "my-mini-deployment",
+                },
+                AzureKeyConfig: &schemas.AzureKeyConfig{
+                    Endpoint:   *schemas.NewEnvVar(os.Getenv("AZURE_ENDPOINT")),
+                    APIVersion: schemas.NewEnvVar("2024-10-21"),
+                },
+            },
+        }, nil
+    }
+    return nil, fmt.Errorf("provider %s not supported", provider)
+}
+```
+
+
+
+
+
+
+**Authentication precedence:** (1) Entra ID if `client_id`, `client_secret`, and `tenant_id` are all set; (2) API key if `value` is non-empty; (3) DefaultAzureCredential (managed identity) if neither is provided.
+
+
+**`azure_key_config` fields:**
+
+| Field | Required | Default | Description |
+|-------|----------|---------|-------------|
+| `endpoint` | Yes | — | Azure OpenAI resource endpoint URL |
+| `api_version` | No | `2024-10-21` | Azure API version |
+| `client_id` | No | — | Entra ID client ID (Service Principal auth) |
+| `client_secret` | No | — | Entra ID client secret (Service Principal auth) |
+| `tenant_id` | No | — | Entra ID tenant ID (Service Principal auth) |
+| `scopes` | No | `["https://cognitiveservices.azure.com/.default"]` | OAuth scopes for token requests |
+
+**Key-level fields:**
+
+| Field | Required | Description |
+|-------|----------|-------------|
+| `aliases` | No | Map model names to Azure deployment IDs (v1.5.0-prerelease2+) |
+| `value` | No | Azure API key (leave empty for Entra ID or managed identity) |
+| `models` | Yes | Models this key can serve; use `["*"]` to allow all |
+
+---
+
 ## Beta Headers

 For Anthropic models on Azure, Bifrost validates `anthropic-beta` headers and drops unsupported headers from the request. Azure supports most Anthropic beta features.
@@ -577,17 +1115,4 @@ Azure routes video generation to OpenAI's Sora models via the Azure OpenAI-compa ## Setup & Configuration -Azure requires endpoint URLs, deployment mappings, and API version configuration. For detailed instructions on setting up Azure authentication, see the quickstart guides: - - - - -See **[Provider-Specific Authentication - Azure](../../quickstart/gateway/provider-configuration#azure)** in the Gateway Quickstart for configuration steps using Web UI, API, or config.json. - - - - -See **[Provider-Specific Authentication - Azure](../../quickstart/go-sdk/provider-configuration#azure)** in the Go SDK Quickstart for programmatic configuration examples. - - - +See the [Setup & Configuration](#setup--configuration) section at the top of this page for authentication instructions and full configuration examples. diff --git a/docs/providers/supported-providers/bedrock.mdx b/docs/providers/supported-providers/bedrock.mdx index 52ff79ca3f..a6bd44bcb7 100644 --- a/docs/providers/supported-providers/bedrock.mdx +++ b/docs/providers/supported-providers/bedrock.mdx @@ -52,6 +52,479 @@ AWS Bedrock supports multiple model families (Claude, Nova, Mistral, Llama, Cohe **Limitations**: Images must be in base64 or data URI format (remote URLs not supported). Text completion streaming is not supported. +--- + +## Setup & Configuration + +Bifrost signs every Bedrock request with AWS Signature Version 4 (SigV4). Four credential methods are supported — choose the one that matches your deployment environment. + + +The `aliases` field (mapping model names to inference profile IDs, ARNs, or deployment identifiers) requires **v1.5.0-prerelease2 or later**. On v1.4.x, use `deployments` inside `bedrock_key_config` instead — see the [v1.5.0 Migration Guide](/migration-guides/v1.5.0#breaking-change-9-provider-deployments-removed-migrate-to-aliases) for details. + + +### 1. Explicit Credentials + +Provide `access_key` and `secret_key` directly. 
Optionally include `session_token` for temporary credentials. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"AWS Bedrock"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Explicit Credentials"** +4. Set **Access Key**: Your AWS access key ID +5. Set **Secret Key**: Your AWS secret access key +6. Set **Session Token** (Optional): For temporary/assumed credentials +7. Set **Region**: e.g., `us-east-1` +8. Configure **Aliases**: Map model names to inference profile IDs +9. Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "bedrock"}' + +# Step 2: Create a key (Explicit Credentials) +curl -X POST http://localhost:8080/api/providers/bedrock/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "bedrock-key", + "models": ["*"], + "weight": 1.0, + "aliases": { + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + }, + "bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "session_token": "env.AWS_SESSION_TOKEN", + "region": "us-east-1" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Replace the top-level `aliases` with `"deployments"` inside `bedrock_key_config`: +```json +"bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "region": "us-east-1", + "deployments": { + "claude-3-5-sonnet": "arn:aws:bedrock:us-east-1::foundation-model/..." 
+ } +} +``` + + + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-key", + "models": ["*"], + "weight": 1.0, + "aliases": { + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + }, + "bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "session_token": "env.AWS_SESSION_TOKEN", + "region": "us-east-1" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `bedrock_key_config` instead of the top-level `aliases` field. + + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Bedrock: + return []schemas.Key{ + { + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + }, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + AccessKey: *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"), + SecretKey: *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"), + SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"), + Region: schemas.NewEnvVar("us-east-1"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 2. IAM Role (IRSA / Instance Profile) + +Leave `access_key` and `secret_key` empty. Bifrost calls AWS SDK's `config.LoadDefaultConfig()`, which runs the full credential chain. This section covers IAM-role-based sources: EKS IRSA (`AWS_WEB_IDENTITY_TOKEN_FILE` + `AWS_ROLE_ARN`), ECS task role, and EC2 instance profile (IMDS). No static credentials are stored anywhere. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"AWS Bedrock"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"IAM Role (Inherited)"** +4. Set **Region**: e.g., `us-east-1` +5. Configure **Aliases** if needed +6. 
Save + +Ensure your workload has an IAM role with Bedrock permissions attached (via IRSA, ECS task role, or EC2 instance profile). + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "bedrock"}' + +# Step 2: Create a key (IAM Role — no credentials) +curl -X POST http://localhost:8080/api/providers/bedrock/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "bedrock-iam", + "models": ["*"], + "weight": 1.0, + "aliases": { + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + }, + "bedrock_key_config": { + "region": "us-east-1" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Replace the top-level `aliases` with `"deployments"` inside `bedrock_key_config`. + + + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-iam", + "models": ["*"], + "weight": 1.0, + "aliases": { + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + }, + "bedrock_key_config": { + "region": "us-east-1" + } + } + ] + } + } +} +``` + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Bedrock: + return []schemas.Key{ + { + Models: []string{"*"}, + Weight: 1.0, + // Leave Value empty — Bifrost uses the IAM role bound to the workload + Aliases: schemas.KeyAliases{ + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + }, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + // Leave AccessKey and SecretKey empty — resolved from IRSA/instance profile + Region: schemas.NewEnvVar("us-east-1"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 3. 
Default Credential Chain
+
+Leave `access_key` and `secret_key` empty. Bifrost calls the same `config.LoadDefaultConfig()` as the IAM Role section — the full AWS SDK v2 credential chain runs. Use this when credentials come from environment variables or shared credential files rather than an attached IAM role.
+
+Full AWS SDK v2 credential resolution order:
+1. Environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN`)
+2. Web Identity Token File / IRSA (`AWS_WEB_IDENTITY_TOKEN_FILE` + `AWS_ROLE_ARN`)
+3. Shared credentials and config files (`~/.aws/credentials`, `~/.aws/config`)
+4. ECS container credentials
+5. EC2 instance metadata / IMDS
+
+
+
+
+
+1. Navigate to **"Model Providers"** → **"Configurations"** → **"AWS Bedrock"**
+2. Click **"Add Key"** (or edit an existing key)
+3. Under **Authentication Method**, select **"IAM Role (Inherited)"** — this covers both assigned IAM roles and the default credential chain (env vars, `~/.aws/credentials`)
+4. Set **Region**: e.g., `us-east-1`
+5. Configure **Aliases** if needed
+6. Save
+
+Ensure `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` are set in the environment where Bifrost runs, or that `~/.aws/credentials` is configured.
+
+
+
+
+
+```bash
+# Step 1: Create the provider
+curl -X POST http://localhost:8080/api/providers \
+  -H "Content-Type: application/json" \
+  -d '{"provider": "bedrock"}'
+
+# Step 2: Create a key (Default Credential Chain — no credentials)
+curl -X POST http://localhost:8080/api/providers/bedrock/keys \
+  -H "Content-Type: application/json" \
+  -d '{
+    "name": "bedrock-default-chain",
+    "models": ["*"],
+    "weight": 1.0,
+    "bedrock_key_config": {
+      "region": "us-east-1"
+    }
+  }'
+```
+
+
+**On v1.4.x**, pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint.
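+
+As a sketch of that combined v1.4.x shape (same placeholder values as the Step 2 request above; the exact body is version-specific, so treat this as illustrative):
+
+```json
+{
+  "provider": "bedrock",
+  "keys": [
+    {
+      "name": "bedrock-default-chain",
+      "models": ["*"],
+      "weight": 1.0,
+      "bedrock_key_config": {
+        "region": "us-east-1"
+      }
+    }
+  ]
+}
+```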
+ + + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-default-chain", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-east-1" + } + } + ] + } + } +} +``` + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Bedrock: + return []schemas.Key{ + { + Models: []string{"*"}, + Weight: 1.0, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + // Leave AccessKey and SecretKey empty — resolved from env vars or ~/.aws/credentials + Region: schemas.NewEnvVar("us-east-1"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 4. STS AssumeRole + +Set `role_arn` to assume an IAM role before signing requests. STS AssumeRole can layer on top of either explicit credentials or the default credential chain. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"AWS Bedrock"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"IAM Role (Inherited)"** or **"Explicit Credentials"** — AssumeRole fields are available on both tabs +4. Set **Region**: e.g., `us-east-1` +5. Set **Assume Role ARN**: The IAM role ARN (e.g., `arn:aws:iam::123456789012:role/BedrockRole`) +6. Set **External ID** (Optional): Required when the role's trust policy demands it +7. Set **Session Name** (Optional): Identifies the session in CloudTrail (default: `bifrost-session`) +8. 
Save
+
+
+
+
+
+```bash
+# Step 1: Create the provider
+curl -X POST http://localhost:8080/api/providers \
+  -H "Content-Type: application/json" \
+  -d '{"provider": "bedrock"}'
+
+# Step 2: Create a key (STS AssumeRole)
+curl -X POST http://localhost:8080/api/providers/bedrock/keys \
+  -H "Content-Type: application/json" \
+  -d '{
+    "name": "bedrock-assume-role",
+    "models": ["*"],
+    "weight": 1.0,
+    "bedrock_key_config": {
+      "role_arn": "env.AWS_ROLE_ARN",
+      "external_id": "env.AWS_EXTERNAL_ID",
+      "session_name": "bifrost-session",
+      "region": "us-east-1"
+    }
+  }'
+```
+
+
+**On v1.4.x**, pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint.
+
+
+
+
+
+
+```json
+{
+  "providers": {
+    "bedrock": {
+      "keys": [
+        {
+          "name": "bedrock-assume-role",
+          "models": ["*"],
+          "weight": 1.0,
+          "bedrock_key_config": {
+            "role_arn": "env.AWS_ROLE_ARN",
+            "external_id": "env.AWS_EXTERNAL_ID",
+            "session_name": "bifrost-session",
+            "region": "us-east-1"
+          }
+        }
+      ]
+    }
+  }
+}
+```
+
+
+
+
+```go
+func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
+    switch provider {
+    case schemas.Bedrock:
+        return []schemas.Key{
+            {
+                Models: []string{"*"},
+                Weight: 1.0,
+                BedrockKeyConfig: &schemas.BedrockKeyConfig{
+                    RoleARN:         schemas.NewEnvVar("env.AWS_ROLE_ARN"),
+                    ExternalID:      schemas.NewEnvVar("env.AWS_EXTERNAL_ID"), // optional
+                    RoleSessionName: schemas.NewEnvVar("bifrost-session"),     // optional
+                    Region:          schemas.NewEnvVar("us-east-1"),
+                },
+            },
+        }, nil
+    }
+    return nil, fmt.Errorf("provider %s not supported", provider)
+}
+```
+
+
+
+
+
+
+AssumeRole requires a valid source identity — it works when credentials are available via explicit `access_key`/`secret_key` or via the default credential chain. If no credentials are available from either source, AssumeRole will fail.
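+
+Whichever method supplies credentials, the resolved AWS identity also needs permission to invoke Bedrock models. A minimal IAM policy sketch (the two `bedrock:Invoke*` actions are standard AWS IAM actions; scope `Resource` to specific model or inference-profile ARNs in production):
+
+```json
+{
+  "Version": "2012-10-17",
+  "Statement": [
+    {
+      "Effect": "Allow",
+      "Action": [
+        "bedrock:InvokeModel",
+        "bedrock:InvokeModelWithResponseStream"
+      ],
+      "Resource": "*"
+    }
+  ]
+}
+```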
+ + +**`bedrock_key_config` fields:** + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `region` | Yes | — | AWS region (e.g., `us-east-1`) | +| `access_key` | No | — | AWS access key ID | +| `secret_key` | No | — | AWS secret access key | +| `session_token` | No | — | AWS session token (for temporary credentials) | +| `arn` | No | — | ARN prefix for constructing inference profile URLs (see [Inference Profiles](#inference-profiles--arn-configuration)) | +| `role_arn` | No | — | IAM role ARN for STS AssumeRole | +| `external_id` | No | — | External ID for AssumeRole (when required by trust policy) | +| `session_name` | No | `bifrost-session` | Session name for AssumeRole CloudTrail logs | + +**Key-level fields:** + +| Field | Required | Description | +|-------|----------|-------------| +| `aliases` | No | Map model names to inference profile IDs or Bedrock model IDs (v1.5.0-prerelease2+) | +| `models` | Yes | Models this key can serve; use `["*"]` to allow all | + +--- + ## Beta Headers For Claude models on Bedrock, Bifrost validates `anthropic-beta` headers and drops unsupported headers from the request. @@ -1271,7 +1744,7 @@ Set `role_arn` to assume an IAM role before signing requests. AssumeRole require | `session_name` | No | `bifrost-session` | Identifies the session in CloudTrail logs | -## Setup & Configuration +## Inference Profiles & ARN Configuration ### How to Use ARNs and Application Inference Profiles @@ -1321,21 +1794,6 @@ When using AWS Bedrock inference profiles or application inference profiles, you } ``` -For detailed instructions on setting up AWS Bedrock authentication including credentials, IAM roles, regions, and deployment mapping, see the quickstart guides: - - - - -See **[Provider-Specific Authentication - AWS Bedrock](../../quickstart/gateway/provider-configuration#aws-bedrock)** in the Gateway Quickstart for configuration steps using Web UI, API, or config.json. 
- - - - -See **[Provider-Specific Authentication - AWS Bedrock](../../quickstart/go-sdk/provider-configuration#aws-bedrock)** in the Go SDK Quickstart for programmatic configuration examples. - - - - ### Endpoints - **Runtime API**: `bedrock-runtime.{region}.amazonaws.com/model/{path}` diff --git a/docs/providers/supported-providers/vertex.mdx b/docs/providers/supported-providers/vertex.mdx index 4ea108d733..1d19b05984 100644 --- a/docs/providers/supported-providers/vertex.mdx +++ b/docs/providers/supported-providers/vertex.mdx @@ -7,6 +7,7 @@ icon: "v" ## Overview Vertex AI is Google's unified ML platform providing access to Google's Gemini models, Anthropic Claude models, and other third-party LLMs through a single API. Bifrost performs conversions including: + - **Multi-model support** - Unified interface for Gemini, Anthropic, and third-party models - **OAuth2 authentication** - Service account credentials with automatic token refresh - **Project and region management** - Automatic endpoint construction from GCP project/region @@ -17,28 +18,411 @@ Vertex AI is Google's unified ML platform providing access to Google's Gemini mo ### Supported Operations -| Operation | Non-Streaming | Streaming | Endpoint | -|-----------|---------------|-----------|----------| -| Chat Completions | ✅ | ✅ | `/generate` | -| Responses API | ✅ | ✅ | `/messages` | -| Embeddings | ✅ | - | `/embeddings` | -| Image Generation | ✅ | - | `/generateContent` or `/predict` (Imagen) | -| Image Edit | ✅ | - | `/generateContent` or `/predict` (Imagen) | -| Video Generation | ✅ | - | `/predictLongRunning` (Veo models only) | -| Image Variation | ❌ | - | Not supported | -| List Models | ✅ | - | `/models` | -| Text Completions | ❌ | ❌ | - | -| Speech (TTS) | ❌ | ❌ | - | -| Transcriptions (STT) | ❌ | ❌ | - | -| Files | ❌ | ❌ | - | -| Batch | ❌ | ❌ | - | +| Operation | Non-Streaming | Streaming | Endpoint | +| -------------------- | ------------- | --------- | 
----------------------------------------- | +| Chat Completions | ✅ | ✅ | `/generate` | +| Responses API | ✅ | ✅ | `/messages` | +| Embeddings | ✅ | - | `/embeddings` | +| Image Generation | ✅ | - | `/generateContent` or `/predict` (Imagen) | +| Image Edit | ✅ | - | `/generateContent` or `/predict` (Imagen) | +| Video Generation | ✅ | - | `/predictLongRunning` (Veo models only) | +| Image Variation | ❌ | - | Not supported | +| List Models | ✅ | - | `/models` | +| Text Completions | ❌ | ❌ | - | +| Speech (TTS) | ❌ | ❌ | - | +| Transcriptions (STT) | ❌ | ❌ | - | +| Files | ❌ | ❌ | - | +| Batch | ❌ | ❌ | - | **Unsupported Operations** (❌): Text Completions, Speech, Transcriptions, Files, and Batch are not supported by Vertex AI. These return `UnsupportedOperationError`. **Vertex-specific**: Endpoints vary by model type. Responses API available for both Gemini and Anthropic models. + + + +--- + +## Setup & Configuration + +Vertex AI requires Google Cloud project configuration and authentication credentials. Three authentication methods are supported. + + + The `aliases` field (mapping model names to fine-tuned model IDs or endpoint + identifiers) requires **v1.5.0-prerelease2 or later**. On v1.4.x, use + `deployments` inside `vertex_key_config` instead — see the [v1.5.0 Migration + Guide](/migration-guides/v1.5.0#breaking-change-9-provider-deployments-removed-migrate-to-aliases) + for details. + + +### 1. Service Account JSON (Recommended for Production) + +Provide a credential JSON string in `auth_credentials`. The JSON must contain a `type` field. Supported types: `service_account` (most common), `impersonated_service_account`, `authorized_user`, `external_account`, `external_account_authorized_user`. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Google Vertex"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Service Account (JSON)"** +4. 
Set **Project ID**: Your Google Cloud project ID +5. Set **Region**: e.g., `us-central1` +6. Set **Auth Credentials**: Paste your service account JSON or reference an env var (e.g., `env.VERTEX_CREDENTIALS`) +7. Configure **Aliases**: Map model names to fine-tuned model IDs (if using fine-tuned models) +8. Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "vertex"}' + +# Step 2: Create a key (Service Account JSON) +curl -X POST http://localhost:8080/api/providers/vertex/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "vertex-sa-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Use `deployments` inside `vertex_key_config` instead of the top-level `aliases` field for fine-tuned model mappings. + + + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-sa-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + } + ] + } + } +} +``` + + + On **v1.4.x**, use `deployments` inside `vertex_key_config` instead of the + top-level `aliases` field for fine-tuned model mappings. 
+ + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Vertex: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, // Leave empty when using service account credentials + Models: []string{"*"}, + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), + Region: *schemas.NewEnvVar("us-central1"), + AuthCredentials: *schemas.NewEnvVar("env.VERTEX_CREDENTIALS"), // full service account JSON + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 2. Application Default Credentials + +Leave `auth_credentials` empty. Bifrost calls `google.FindDefaultCredentials()` — Google's ADC library — which resolves credentials in this order: + +1. `GOOGLE_APPLICATION_CREDENTIALS` env var (path to a JSON credential file) +2. Application default credential file (`~/.config/gcloud/application_default_credentials.json`, written by `gcloud auth application-default login`) +3. GCE/GKE/Cloud Run/App Engine metadata server (attached service account or Workload Identity) + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Google Vertex"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Service Account (Attached)"** +4. Set **Project ID**: Your Google Cloud project ID +5. Set **Region**: e.g., `us-central1` +6. Configure **Aliases** if needed +7. Save + +Ensure `GOOGLE_APPLICATION_CREDENTIALS` is set in your environment, or that Workload Identity / gcloud is configured. 
+ + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "vertex"}' + +# Step 2: Create a key (Application Default Credentials) +curl -X POST http://localhost:8080/api/providers/vertex/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "vertex-adc-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "" + } + }' +``` + + +**On v1.4.x**, pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. + + + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-adc-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "" + } + } + ] + } + } +} +``` + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Vertex: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, + Models: []string{"*"}, + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), + Region: *schemas.NewEnvVar("us-central1"), + // Leave AuthCredentials empty — uses Application Default Credentials + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 3. API Key (Gemini and Fine-Tuned Models Only) + +Set `value` to your Vertex API key. API key authentication is supported only for Gemini models and fine-tuned Gemini models. For Anthropic models on Vertex, use Service Account or Application Default Credentials. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Google Vertex"** +2. 
Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"API Key"** +4. Set **API Key**: Your Vertex AI API key +5. Set **Project ID**: Your Google Cloud project ID +6. Set **Region**: e.g., `us-central1` +7. Set **Project Number** (Optional): Required only when using fine-tuned models +8. Configure **Aliases**: Map short names to fine-tuned model IDs (e.g., `my-model` → `123456789`) +9. Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "vertex"}' + +# Step 2: Create a key (API Key — Gemini + fine-tuned models) +curl -X POST http://localhost:8080/api/providers/vertex/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "vertex-api-key", + "value": "env.VERTEX_API_KEY", + "models": ["gemini-pro", "gemini-2.0-flash", "my-fine-tuned-model"], + "weight": 1.0, + "aliases": { + "my-fine-tuned-model": "123456789" + }, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", + "region": "us-central1" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. 
+- Replace the top-level `aliases` with `"deployments"` inside `vertex_key_config`: +```json +"vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "deployments": { + "my-fine-tuned-model": "123456789" + } +} +``` + + + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-api-key", + "value": "env.VERTEX_API_KEY", + "models": ["gemini-pro", "gemini-2.0-flash", "my-fine-tuned-model"], + "weight": 1.0, + "aliases": { + "my-fine-tuned-model": "123456789" + }, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", + "region": "us-central1" + } + } + ] + } + } +} +``` + + + On **v1.4.x**, use `deployments` inside `vertex_key_config` instead of the + top-level `aliases` field. + + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Vertex: + return []schemas.Key{ + { + Value: *schemas.NewEnvVar("env.VERTEX_API_KEY"), // only when using Gemini or fine-tuned models + Models: []string{"gemini-pro", "gemini-2.0-flash", "my-fine-tuned-model"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "my-fine-tuned-model": "123456789", + }, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), + ProjectNumber: *schemas.NewEnvVar("env.VERTEX_PROJECT_NUMBER"), // required for fine-tuned models + Region: *schemas.NewEnvVar("us-central1"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + + + Vertex AI support for fine-tuned models is currently in beta. Requests to + non-Gemini fine-tuned models may fail, so please test and report any issues. 
+**`vertex_key_config` fields:** + +| Field | Required | Description | +| ------------------ | -------- | ------------------------------------------------------ | +| `project_id` | Yes | Google Cloud project ID | +| `region` | Yes | GCP region (e.g., `us-central1`, `eu-west1`, `global`) | +| `auth_credentials` | No | Service account JSON string (leave empty for ADC) | +| `project_number` | No | GCP project number (required for fine-tuned models) | + +**Key-level fields:** + +| Field | Required | Description | +| --------- | -------- | ----------------------------------------------------------------------------------------- | +| `value` | No | Vertex API key (Gemini and fine-tuned models only; leave empty for Service Account / ADC) | +| `aliases` | No | Map model names to fine-tuned model IDs or endpoint identifiers (v1.5.0-prerelease2+) | +| `models` | Yes | Models this key can serve; use `["*"]` to allow all | + +--- + ## Beta Headers For Anthropic models on Vertex AI, Bifrost validates `anthropic-beta` headers and drops unsupported headers from the request. @@ -50,7 +434,10 @@ For Anthropic models on Vertex AI, Bifrost validates `anthropic-beta` headers an You can override these defaults per provider via the **Beta Headers** tab in provider configuration or via [`beta_header_overrides`](/quickstart/gateway/provider-configuration#beta-header-overrides). See the full support matrix in the [Anthropic provider docs](/providers/supported-providers/anthropic#beta-headers). 
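As a sketch of how a client forwards a beta header through the gateway, the snippet below builds (but does not send) a chat request carrying an `anthropic-beta` header. The gateway URL, model name, and header value here are illustrative assumptions, not values taken from this page:

```go
package main

import (
	"fmt"
	"net/http"
	"strings"
)

// buildRequest constructs a gateway chat request that forwards an
// `anthropic-beta` header. Bifrost validates the header against the
// support matrix and drops unsupported entries before calling Vertex.
func buildRequest(gatewayURL, model, betaHeader string) (*http.Request, error) {
	body := fmt.Sprintf(`{"model": %q, "messages": [{"role": "user", "content": "Hello"}]}`, model)
	req, err := http.NewRequest("POST", gatewayURL+"/v1/chat/completions", strings.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("anthropic-beta", betaHeader)
	return req, nil
}

func main() {
	req, err := buildRequest("http://localhost:8080", "vertex/claude-sonnet-4", "prompt-caching-2024-07-31")
	if err != nil {
		panic(err)
	}
	fmt.Println(req.Method, req.URL.Path, req.Header.Get("anthropic-beta"))
}
```

If the header names a feature Vertex does not support for the model, it is stripped server-side rather than rejected, so the request still succeeds.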
- Vertex AI Beta Headers configuration tab showing supported and unsupported Anthropic beta features with override options + Vertex AI Beta Headers configuration tab showing supported and unsupported Anthropic beta features with override options --- @@ -61,9 +448,9 @@ You can override these defaults per provider via the **Beta Headers** tab in pro ### Core Parameter Mapping -| Parameter | Vertex Handling | Notes | -|-----------|---|-------| -| `model` | Maps to Vertex model ID | Region-specific endpoint constructed automatically | +| Parameter | Vertex Handling | Notes | +| ---------------- | ------------------------- | ---------------------------------------------------- | +| `model` | Maps to Vertex model ID | Region-specific endpoint constructed automatically | | All other params | Model-specific conversion | Converted per underlying provider (Gemini/Anthropic) | ### Key Configuration @@ -81,6 +468,7 @@ The key configuration for Vertex requires Google Cloud credentials: ``` **Configuration Details**: + - `project_id` - GCP project ID (required) - `region` - GCP region for API endpoints (required) - Examples: `us-central1`, `us-west1`, `eu-west1`, `global` @@ -89,8 +477,9 @@ The key configuration for Vertex requires Google Cloud credentials: ### Authentication Methods 1. **Service Account JSON** (recommended for production) + ```json - {"auth_credentials": "{full-service-account-json}"} + { "auth_credentials": "{full-service-account-json}" } ``` 2. 
**Application Default Credentials** (for local development) @@ -137,12 +526,12 @@ Refer to [Anthropic documentation](/providers/supported-providers/anthropic) for The region determines the API endpoint: -| Region | Endpoint | Purpose | -|--------|----------|---------| -| `us-central1` | `us-central1-aiplatform.googleapis.com` | US Central | -| `us-west1` | `us-west1-aiplatform.googleapis.com` | US West | -| `eu-west1` | `eu-west1-aiplatform.googleapis.com` | Europe West | -| `global` | `aiplatform.googleapis.com` | Global (no region prefix) | +| Region | Endpoint | Purpose | +| ------------- | --------------------------------------- | ------------------------- | +| `us-central1` | `us-central1-aiplatform.googleapis.com` | US Central | +| `us-west1` | `us-west1-aiplatform.googleapis.com` | US West | +| `eu-west1` | `eu-west1-aiplatform.googleapis.com` | Europe West | +| `global` | `aiplatform.googleapis.com` | Global (no region prefix) | Availability varies by region. Check [GCP documentation](https://cloud.google.com/vertex-ai/docs/general/locations) for model availability. 
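The endpoint mapping in the table above can be sketched as a small helper. This is illustrative only; the actual construction lives in the provider code:

```go
package main

import "fmt"

// vertexEndpoint derives the API host from a region, mirroring the
// table above: regional endpoints get a region prefix, while the
// `global` region uses the bare host with no prefix.
func vertexEndpoint(region string) string {
	if region == "global" {
		return "aiplatform.googleapis.com"
	}
	return region + "-aiplatform.googleapis.com"
}

func main() {
	for _, r := range []string{"us-central1", "eu-west1", "global"} {
		fmt.Printf("%s -> %s\n", r, vertexEndpoint(r))
	}
}
```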
@@ -163,12 +552,12 @@ The Responses API is available for both Anthropic (Claude) and Gemini models on ### Core Parameter Mapping -| Parameter | Vertex Handling | Notes | -|-----------|---|-------| -| `instructions` | Becomes system message | Model-specific conversion | -| `input` | Converted to messages | String or array support | -| `max_output_tokens` | Model-specific field mapping | Gemini vs Anthropic conversion | -| All other params | Model-specific conversion | Converted per underlying provider | +| Parameter | Vertex Handling | Notes | +| ------------------- | ---------------------------- | --------------------------------- | +| `instructions` | Becomes system message | Model-specific conversion | +| `input` | Converted to messages | String or array support | +| `max_output_tokens` | Model-specific field mapping | Gemini vs Anthropic conversion | +| All other params | Model-specific conversion | Converted per underlying provider | ### Gemini Models @@ -177,6 +566,7 @@ For Gemini models, conversion follows Gemini's Responses API format. 
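For illustration, a minimal Responses API request body for a Gemini model might look like the following. The `vertex/` model prefix and the field values are assumptions for this example:

```json
{
  "model": "vertex/gemini-2.0-flash",
  "instructions": "You are a concise assistant.",
  "input": "Summarize the benefits of regional endpoints.",
  "max_output_tokens": 256
}
```

Here `instructions` is converted to a system message and `max_output_tokens` is mapped to the Gemini-specific field, per the parameter table above.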
### Anthropic Models (Claude) For Anthropic models, conversion follows Anthropic's message format: + - `instructions` becomes system message - `reasoning` mapped to `thinking` structure @@ -234,9 +624,9 @@ Embeddings are supported for Gemini and other models that support embedding gene ### Core Parameters -| Parameter | Vertex Mapping | Notes | -|-----------|---|-------| -| `input` | `instances[].content` | Text to embed | +| Parameter | Vertex Mapping | Notes | +| ------------ | --------------------------------- | -------------------- | +| `input` | `instances[].content` | Text to embed | | `dimensions` | `parameters.outputDimensionality` | Optional output size | ### Advanced Parameters @@ -287,11 +677,11 @@ resp, err := client.EmbeddingRequest(schemas.NewBifrostContext(ctx, schemas.NoDe #### Embedding Parameters -| Parameter | Type | Description | -|-----------|------|-------------| -| `task_type` | string | Task type hint: `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`, `SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING` (optional) | -| `title` | string | Optional title to help model produce better embeddings (used with task_type) | -| `autoTruncate` | boolean | Auto-truncate input to max tokens (defaults to true) | +| Parameter | Type | Description | +| -------------- | ------- | ------------------------------------------------------------------------------------------------------------------------- | +| `task_type` | string | Task type hint: `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`, `SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING` (optional) | +| `title` | string | Optional title to help model produce better embeddings (used with task_type) | +| `autoTruncate` | boolean | Auto-truncate input to max tokens (defaults to true) | ### Task Type Effects @@ -322,6 +712,7 @@ Embeddings response includes vectors and truncation information: ``` **Response Fields**: + - `values` - Embedding vector as floats - `statistics.token_count` - Input token count - 
`statistics.truncated` - Whether input was truncated due to length @@ -336,11 +727,11 @@ Image Generation is supported for Gemini and Imagen on Vertex AI. The provider a ### Core Parameter Mapping -| Parameter | Vertex Handling | Notes | -|-----------|---|-------| -| `model` | Mapped to deployment/model identifier | Model type detected automatically | -| `prompt` | Model-specific conversion | Converted per underlying provider (Gemini/Imagen) | -| All other params | Model-specific conversion | Converted per underlying provider | +| Parameter | Vertex Handling | Notes | +| ---------------- | ------------------------------------- | ------------------------------------------------- | +| `model` | Mapped to deployment/model identifier | Model type detected automatically | +| `prompt` | Model-specific conversion | Converted per underlying provider (Gemini/Imagen) | +| All other params | Model-specific conversion | Converted per underlying provider | ### Model Type Detection @@ -418,29 +809,27 @@ Image generation streaming is not supported by Vertex AI. # 5. Image Edit - -Requests use **multipart/form-data**, not JSON. - +Requests use **multipart/form-data**, not JSON. Image Edit is supported for Gemini and Imagen models on Vertex AI. The provider automatically routes to the appropriate format based on the model type. 
**Request Parameters** -| Parameter | Type | Required | Notes | -|-----------|------|----------|-------| -| `model` | string | ✅ | Model identifier (must be Gemini or Imagen model) | -| `prompt` | string | ✅ | Text description of the edit | -| `image[]` | binary | ✅ | Image file(s) to edit (supports multiple images) | -| `mask` | binary | ❌ | Mask image file | -| `type` | string | ❌ | Edit type: `"inpainting"`, `"outpainting"`, `"inpaint_removal"`, `"bgswap"` (Imagen only) | -| `n` | int | ❌ | Number of images to generate (1-10) | -| `output_format` | string | ❌ | Output format: `"png"`, `"webp"`, `"jpeg"` | -| `output_compression` | int | ❌ | Compression level (0-100%) | -| `seed` | int | ❌ | Seed for reproducibility (via `ExtraParams["seed"]`) | -| `negative_prompt` | string | ❌ | Negative prompt (via `ExtraParams["negativePrompt"]`) | -| `maskMode` | string | ❌ | Mask mode (via `ExtraParams["maskMode"]`, Imagen only): `"MASK_MODE_USER_PROVIDED"`, `"MASK_MODE_BACKGROUND"`, `"MASK_MODE_FOREGROUND"`, `"MASK_MODE_SEMANTIC"` | -| `dilation` | float | ❌ | Mask dilation (via `ExtraParams["dilation"]`, Imagen only): Range [0, 1] | -| `maskClasses` | int[] | ❌ | Mask classes (via `ExtraParams["maskClasses"]`, Imagen only): For `MASK_MODE_SEMANTIC` | +| Parameter | Type | Required | Notes | +| -------------------- | ------ | -------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `model` | string | ✅ | Model identifier (must be Gemini or Imagen model) | +| `prompt` | string | ✅ | Text description of the edit | +| `image[]` | binary | ✅ | Image file(s) to edit (supports multiple images) | +| `mask` | binary | ❌ | Mask image file | +| `type` | string | ❌ | Edit type: `"inpainting"`, `"outpainting"`, `"inpaint_removal"`, `"bgswap"` (Imagen only) | +| `n` | int | ❌ | Number of images to generate (1-10) | +| `output_format` | string | ❌ | Output format: 
`"png"`, `"webp"`, `"jpeg"` | +| `output_compression` | int | ❌ | Compression level (0-100%) | +| `seed` | int | ❌ | Seed for reproducibility (via `ExtraParams["seed"]`) | +| `negative_prompt` | string | ❌ | Negative prompt (via `ExtraParams["negativePrompt"]`) | +| `maskMode` | string | ❌ | Mask mode (via `ExtraParams["maskMode"]`, Imagen only): `"MASK_MODE_USER_PROVIDED"`, `"MASK_MODE_BACKGROUND"`, `"MASK_MODE_FOREGROUND"`, `"MASK_MODE_SEMANTIC"` | +| `dilation` | float | ❌ | Mask dilation (via `ExtraParams["dilation"]`, Imagen only): Range [0, 1] | +| `maskClasses` | int[] | ❌ | Mask classes (via `ExtraParams["maskClasses"]`, Imagen only): For `MASK_MODE_SEMANTIC` | --- @@ -454,6 +843,7 @@ Vertex uses the same conversion functions as Gemini: **Model Validation**: Only Gemini and Imagen models are supported. Other models return `ConfigurationError`. **Request Body Processing**: + - All request bodies are converted to `map[string]interface{}` for Vertex API compatibility - The `region` field is removed before sending to Vertex API - For Gemini models, unsupported fields are stripped via `stripVertexGeminiUnsupportedFields()` (removes `id` from function_call and function_response) @@ -466,6 +856,7 @@ Vertex uses the same conversion functions as Gemini: **Endpoint Selection** The provider automatically selects the endpoint based on model type: + - **Gemini models**: `/v1/projects/{projectID}/locations/{region}/publishers/google/models/{model}:generateContent` - **Imagen models**: `/v1/projects/{projectID}/locations/{region}/publishers/google/models/{model}:predict` @@ -509,7 +900,9 @@ Lists models available in the specified project and region with metadata and dep ## Custom vs Non-Custom Models -**Important**: Vertex AI's List Models API **only returns custom fine-tuned models** that have been deployed to your project. It does NOT return standard foundation models (Gemini, Claude, etc.). 
+ **Important**: Vertex AI's List Models API **only returns custom fine-tuned + models** that have been deployed to your project. It does NOT return standard + foundation models (Gemini, Claude, etc.). To provide a complete model listing experience, Bifrost performs **multi-pass model discovery**: @@ -622,53 +1015,49 @@ Model listing is paginated automatically. If more than 100 models exist, `next_p ## Caveats -**Severity**: High -**Behavior**: Both project_id and region required for all operations -**Impact**: Request fails without valid GCP project/region configuration -**Code**: `vertex.go:127-138` + **Severity**: High **Behavior**: Both project_id and region required for all + operations **Impact**: Request fails without valid GCP project/region + configuration **Code**: `vertex.go:127-138` -**Severity**: Medium -**Behavior**: Tokens cached and automatically refreshed when expired -**Impact**: First request slightly slower due to auth; cached for subsequent requests -**Code**: `vertex.go:34-55` + **Severity**: Medium **Behavior**: Tokens cached and automatically refreshed + when expired **Impact**: First request slightly slower due to auth; cached for + subsequent requests **Code**: `vertex.go:34-55` -**Severity**: Medium -**Behavior**: Automatic detection of Anthropic vs Gemini models -**Impact**: Different conversion logic applied transparently -**Code**: `vertex.go` chat/responses endpoints + **Severity**: Medium **Behavior**: Automatic detection of Anthropic vs Gemini + models **Impact**: Different conversion logic applied transparently **Code**: + `vertex.go` chat/responses endpoints -**Severity**: Low -**Behavior**: Responses API automatically routes to Anthropic or Gemini implementation based on model -**Impact**: Different conversion logic applied transparently per model -**Code**: `vertex.go:836-1080` + **Severity**: Low **Behavior**: Responses API automatically routes to + Anthropic or Gemini implementation based on model **Impact**: Different + 
conversion logic applied transparently per model **Code**: + `vertex.go:836-1080` -**Severity**: Low -**Behavior**: `anthropic_version` always set to `vertex-2023-10-16` for Claude -**Impact**: Cannot override Anthropic version for Claude on Vertex -**Code**: `utils.go:33, 71` + **Severity**: Low **Behavior**: `anthropic_version` always set to + `vertex-2023-10-16` for Claude **Impact**: Cannot override Anthropic version + for Claude on Vertex **Code**: `utils.go:33, 71` -**Severity**: Low -**Behavior**: Vertex returns float64 embeddings, and Bifrost preserves that precision in normalized embedding responses -**Impact**: No precision loss in the `/v1/embeddings` response path -**Code**: `embedding.go:84-91` + **Severity**: Low **Behavior**: Vertex returns float64 embeddings, and Bifrost + preserves that precision in normalized embedding responses **Impact**: No + precision loss in the `/v1/embeddings` response path **Code**: + `embedding.go:84-91` -**Severity**: High -**Behavior**: Vertex AI's List Models API only returns custom fine-tuned models, NOT foundation models -**Impact**: Bifrost performs three-pass discovery to include foundation models from aliases and the key-level `models` allowlist -**Why**: This is a Vertex AI API limitation - foundation models must be explicitly configured -**Code**: `models.go:76-217` + **Severity**: High **Behavior**: Vertex AI's List Models API only returns + custom fine-tuned models, NOT foundation models **Impact**: Bifrost performs + three-pass discovery to include foundation models from aliases and the + key-level `models` allowlist **Why**: This is a Vertex AI API limitation - + foundation models must be explicitly configured **Code**: `models.go:76-217` --- @@ -683,40 +1072,22 @@ Model listing is paginated automatically. 
If more than 100 models exist, `next_p **Note**: For `global` region, endpoint is `https://aiplatform.googleapis.com/v1/projects/{project}/locations/global/{resource}` -## Setup & Configuration - -Vertex AI requires project configuration, region selection, and Google Cloud authentication. For detailed instructions on setting up Vertex AI, see the quickstart guides: - - - - -See **[Provider-Specific Authentication - Google Vertex](../../quickstart/gateway/provider-configuration#google-vertex)** in the Gateway Quickstart for configuration steps using Web UI, API, or config.json. - - - - -See **[Provider-Specific Authentication - Google Vertex](../../quickstart/go-sdk/provider-configuration#google-vertex)** in the Go SDK Quickstart for programmatic configuration examples. - - - - ---- - ## Video Generation Vertex AI routes video generation through Gemini's Veo models using the `predictLongRunning` endpoint. All parameters are identical to [Gemini Video Generation](/providers/supported-providers/gemini#video-generation). -Only Veo models are supported (e.g., `veo-2.0-generate-001`). Passing a non-Veo model name returns a configuration error. + Only Veo models are supported (e.g., `veo-2.0-generate-001`). Passing a + non-Veo model name returns a configuration error. 
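For illustration, a minimal generation request body might look like the following. The `vertex/` model prefix and the prompt are assumptions for this example:

```json
{
  "model": "vertex/veo-2.0-generate-001",
  "prompt": "A timelapse of clouds drifting over a mountain ridge"
}
```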
**Supported Operations** -| Operation | Supported | Notes | -|-----------|-----------|-------| -| Generate | ✅ | `POST /v1/videos` | -| Retrieve | ✅ | `GET /v1/videos/{id}` | -| Download | ✅ | `GET /v1/videos/{id}/content` | -| Delete | ❌ | Not supported | -| List | ❌ | Not supported | -| Remix | ❌ | Not supported | +| Operation | Supported | Notes | +| --------- | --------- | ----------------------------- | +| Generate | ✅ | `POST /v1/videos` | +| Retrieve | ✅ | `GET /v1/videos/{id}` | +| Download | ✅ | `GET /v1/videos/{id}/content` | +| Delete | ❌ | Not supported | +| List | ❌ | Not supported | +| Remix | ❌ | Not supported | diff --git a/docs/quickstart/gateway/provider-configuration.mdx b/docs/quickstart/gateway/provider-configuration.mdx index 597d3b2a1e..db259873f0 100644 --- a/docs/quickstart/gateway/provider-configuration.mdx +++ b/docs/quickstart/gateway/provider-configuration.mdx @@ -84,60 +84,56 @@ curl --location 'http://localhost:8080/api/providers' \ - -Each key in a provider needs to have a unique name. - - +Each key in a provider needs to have a unique name. 
```json { - "providers": { - "openai": { - "keys": [ - { - "name": "openai-key", - "value": "env.OPENAI_API_KEY", - "models": ["*"], - "weight": 1.0 - } - ] - }, - "anthropic": { - "keys": [ - { - "name": "anthropic-key", - "value": "env.ANTHROPIC_API_KEY", - "models": ["*"], - "weight": 1.0 - } - ] - }, - "vllm-local": { - "keys": [ - { - "name": "vllm-key", - "value": "dummy", - "models": ["*"], - "weight": 1.0 - } - ], - "network_config": { - "base_url": "http://vllm-endpoint:8000", - "default_request_timeout_in_seconds": 60 - }, - "custom_provider_config": { - "base_provider_type": "openai", - "allowed_requests": { - "chat_completion": true, - "chat_completion_stream": true - } - } + "providers": { + "openai": { + "keys": [ + { + "name": "openai-key", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "anthropic": { + "keys": [ + { + "name": "anthropic-key", + "value": "env.ANTHROPIC_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "vllm-local": { + "keys": [ + { + "name": "vllm-key", + "value": "dummy", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "base_url": "http://vllm-endpoint:8000", + "default_request_timeout_in_seconds": 60 + }, + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": true } + } } + } } ``` - @@ -147,7 +143,12 @@ Each key in a provider needs to have a unique name. -**Air-gapped or self-signed certificates:** If your custom provider uses HTTPS with a self-signed or internal CA certificate, add `"insecure_skip_verify": true` or `"ca_cert_pem": "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----"` to `network_config`. See [Custom Providers - TLS](../../providers/custom-providers#tls-for-self-signed-or-internal-certificates) for details. 
+ **Air-gapped or self-signed certificates:** If your custom provider uses HTTPS + with a self-signed or internal CA certificate, add `"insecure_skip_verify": + true` or `"ca_cert_pem": "-----BEGIN CERTIFICATE-----\n...\n-----END + CERTIFICATE-----"` to `network_config`. See [Custom Providers - + TLS](../../providers/custom-providers#tls-for-self-signed-or-internal-certificates) + for details. ## Making Requests @@ -179,6 +180,7 @@ export COHERE_API_KEY="your-cohere-api-key" ``` **Environment Variable Handling:** + - Use `"value": "env.VARIABLE_NAME"` to reference environment variables - Use `"value": "sk-proj-xxxxxxxxx"` to pass keys directly - All sensitive data is automatically redacted in GET requests and UI responses for security @@ -218,7 +220,7 @@ curl --location 'http://localhost:8080/api/providers' \ }, { "name": "openai-key-2", - "value": "env.OPENAI_API_KEY_2", + "value": "env.OPENAI_API_KEY_2", "models": ["*"], "weight": 0.3 } @@ -232,24 +234,24 @@ curl --location 'http://localhost:8080/api/providers' \ ```json { - "providers": { - "openai": { - "keys": [ - { - "name": "openai-key-1", - "value": "env.OPENAI_API_KEY_1", - "models": ["*"], - "weight": 0.7 - }, - { - "name": "openai-key-2", - "value": "env.OPENAI_API_KEY_2", - "models": ["*"], - "weight": 0.3 - } - ] + "providers": { + "openai": { + "keys": [ + { + "name": "openai-key-1", + "value": "env.OPENAI_API_KEY_1", + "models": ["*"], + "weight": 0.7 + }, + { + "name": "openai-key-2", + "value": "env.OPENAI_API_KEY_2", + "models": ["*"], + "weight": 0.3 } + ] } + } } ``` @@ -304,24 +306,24 @@ curl --location 'http://localhost:8080/api/providers' \ ```json { - "providers": { - "openai": { - "keys": [ - { - "name": "openai-key-1", - "value": "env.OPENAI_API_KEY", - "models": ["gpt-4o", "gpt-4o-mini"], - "weight": 1.0 - }, - { - "name": "openai-key-2", - "value": "env.OPENAI_API_KEY_PREMIUM", - "models": ["o1-preview", "o1-mini"], - "weight": 1.0 - } - ] + "providers": { + "openai": { + "keys": [ + { + 
"name": "openai-key-1", + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0 + }, + { + "name": "openai-key-2", + "value": "env.OPENAI_API_KEY_PREMIUM", + "models": ["o1-preview", "o1-mini"], + "weight": 1.0 } + ] } + } } ``` @@ -337,7 +339,6 @@ Override the default API endpoint for a provider. This is useful for connecting -