diff --git a/cli/go.mod b/cli/go.mod index 2c1c930bc2..c6a77eaf26 100644 --- a/cli/go.mod +++ b/cli/go.mod @@ -46,7 +46,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/arch v0.23.0 // indirect - golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect + golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.33.0 // indirect ) diff --git a/cli/go.sum b/cli/go.sum index e6e613043a..bbd592a88f 100644 --- a/cli/go.sum +++ b/cli/go.sum @@ -89,8 +89,8 @@ github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8u github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= diff --git a/core/internal/llmtests/provider_feature_support_test.go b/core/internal/llmtests/provider_feature_support_test.go new file mode 100644 index 0000000000..ec38a7d9fb --- /dev/null +++ b/core/internal/llmtests/provider_feature_support_test.go @@ -0,0 +1,1207 @@ +package llmtests + +import ( + "testing" + "time" + + "github.com/maximhq/bifrost/core/providers/anthropic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestProviderToolValidation verifies that unsupported tools are rejected per provider +func TestProviderToolValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider schemas.ModelProvider + tools []schemas.ResponsesTool + expectErr bool + errSubstr string + }{ + // ── Anthropic (supports everything) ── + { + name: "Anthropic/web_search_allowed", + provider: schemas.Anthropic, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, + }, + { + name: "Anthropic/web_fetch_allowed", + provider: schemas.Anthropic, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, + }, + { + name: "Anthropic/code_interpreter_allowed", + provider: schemas.Anthropic, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeCodeInterpreter}}, + }, + { + name: "Anthropic/mcp_allowed", + provider: schemas.Anthropic, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMCP}}, + }, + { + name: "Anthropic/computer_use_allowed", + provider: schemas.Anthropic, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeComputerUsePreview}}, + }, + + // ── Vertex (web_search yes, web_fetch/code_exec/MCP no) ── + { + name: "Vertex/web_search_allowed", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, + }, + { + name: "Vertex/web_fetch_rejected", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, + expectErr: true, + errSubstr: "web_fetch", + }, + { + name: "Vertex/code_interpreter_rejected", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeCodeInterpreter}}, + expectErr: true, + errSubstr: "code_interpreter", + }, + { + name: "Vertex/mcp_rejected", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMCP}}, + expectErr: true, + errSubstr: "mcp", + }, + { + name: "Vertex/computer_use_allowed", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeComputerUsePreview}}, + }, + { + name: "Vertex/bash_allowed", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeLocalShell}}, + }, + { + name: "Vertex/memory_allowed", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMemory}}, + }, + { + name: "Vertex/tool_search_allowed", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeToolSearch}}, + }, + + // ── Bedrock (no web_search, web_fetch, code_exec, MCP) ── + { + name: "Bedrock/web_search_rejected", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, + expectErr: true, + errSubstr: "web_search", + }, + { + name: "Bedrock/web_fetch_rejected", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, + expectErr: true, + errSubstr: "web_fetch", + }, + { + name: "Bedrock/code_interpreter_rejected", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeCodeInterpreter}}, + expectErr: true, + errSubstr: "code_interpreter", + }, + { + name: "Bedrock/mcp_rejected", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMCP}}, + expectErr: true, + errSubstr: "mcp", + }, + { + name: "Bedrock/computer_use_allowed", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeComputerUsePreview}}, + }, + { + name: "Bedrock/bash_allowed", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeLocalShell}}, + }, + { + name: "Bedrock/memory_allowed", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMemory}}, + }, + + // ── Azure (supports everything like Anthropic) ── + { + name: "Azure/web_search_allowed", + provider: schemas.Azure, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, + }, + { + name: "Azure/web_fetch_allowed", + provider: schemas.Azure, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, + }, + { + name: "Azure/code_interpreter_allowed", + provider: schemas.Azure, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeCodeInterpreter}}, + }, + { + name: "Azure/mcp_allowed", + provider: schemas.Azure, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMCP}}, + }, + + // ── Function/custom tools always allowed ── + { + name: "Bedrock/function_tool_allowed", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeFunction}}, + }, + { + name: "Vertex/custom_tool_allowed", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeCustom}}, + }, + + // ── FileSearch and ImageGeneration (OpenAI-only, rejected on all Anthropic providers) ── + { + name: "Anthropic/file_search_rejected", + provider: schemas.Anthropic, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeFileSearch}}, + expectErr: true, + errSubstr: "file_search", + }, + { + name: "Vertex/file_search_rejected", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeFileSearch}}, + expectErr: true, + errSubstr: "file_search", + }, + { + name: "Bedrock/image_generation_rejected", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeImageGeneration}}, + expectErr: true, + errSubstr: "image_generation", + }, + { + name: "Azure/image_generation_rejected", + provider: schemas.Azure, + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeImageGeneration}}, + expectErr: true, + errSubstr: "image_generation", + }, + + // ── Mixed tools: first unsupported tool triggers error ── + { + name: "Vertex/mixed_supported_and_unsupported", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{ + {Type: schemas.ResponsesToolTypeWebSearch}, // allowed + {Type: schemas.ResponsesToolTypeFunction}, // allowed + {Type: schemas.ResponsesToolTypeCodeInterpreter}, // rejected + }, + expectErr: true, + errSubstr: "code_interpreter", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := anthropic.ValidateToolsForProvider(tt.tools, tt.provider) + if tt.expectErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errSubstr) + assert.Contains(t, err.Error(), string(tt.provider)) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestProviderWebSearchVersionSelection verifies that the correct web_search version +// is selected based on model and provider +func TestProviderWebSearchVersionSelection(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider schemas.ModelProvider + model string + expectedToolType string + }{ + // Anthropic 4.6 model → dynamic filtering version + { + name: "Anthropic/4.6_model_gets_dynamic_filtering", + provider: schemas.Anthropic, + model: "claude-opus-4-6", + expectedToolType: "web_search_20260209", + }, + { + name: "Anthropic/sonnet_4-6_gets_dynamic_filtering", + provider: schemas.Anthropic, + model: "claude-sonnet-4-6", + expectedToolType: "web_search_20260209", + }, + // Anthropic non-4.6 model → basic version + { + name: "Anthropic/4.5_model_gets_basic", + provider: schemas.Anthropic, + model: "claude-opus-4-5-20251101", + expectedToolType: "web_search_20250305", + }, + // Vertex 4.6 model → forced to basic (no dynamic filtering on Vertex) + { + name: "Vertex/4.6_model_forced_to_basic", + provider: schemas.Vertex, + model: "claude-opus-4-6", + expectedToolType: "web_search_20250305", + }, + { + name: "Vertex/sonnet_4-6_forced_to_basic", + provider: schemas.Vertex, + model: "claude-sonnet-4-6", + expectedToolType: "web_search_20250305", + }, + // Vertex non-4.6 model → basic version + { + name: "Vertex/4.5_model_gets_basic", + provider: schemas.Vertex, + model: "claude-sonnet-4-5-20250929", + expectedToolType: "web_search_20250305", + }, + // Azure 4.6 model → dynamic filtering (Azure supports it) + { + name: "Azure/4.6_model_gets_dynamic_filtering", + provider: schemas.Azure, + model: "claude-opus-4-6", + expectedToolType: "web_search_20260209", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := schemas.NewBifrostContext(nil, time.Time{}) + bifrostReq := &schemas.BifrostResponsesRequest{ + Provider: tt.provider, + Model: tt.model, + Input: []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("What is the weather?"), + }, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + { + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}, + }, + }, + }, + } + + result, err := anthropic.ToAnthropicResponsesRequest(ctx, bifrostReq) + require.NoError(t, err) + require.NotNil(t, result) + require.NotEmpty(t, result.Tools) + + // Find the web search tool in the result + found := false + for _, tool := range result.Tools { + if tool.Type != nil && tool.Name == "web_search" { + assert.Equal(t, tt.expectedToolType, string(*tool.Type), + "expected tool type %s but got %s for provider=%s model=%s", + tt.expectedToolType, string(*tool.Type), tt.provider, tt.model) + found = true + break + } + } + require.True(t, found, "web_search tool should be present in converted request") + }) + } +} + +// TestProviderWebFetchVersionSelection verifies that the correct web_fetch version +// is selected based on model and provider +func TestProviderWebFetchVersionSelection(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider schemas.ModelProvider + model string + expectedToolType string + }{ + { + name: "Anthropic/4.6_model_gets_latest", + provider: schemas.Anthropic, + model: "claude-opus-4-6", + expectedToolType: "web_fetch_20260309", + }, + { + name: "Anthropic/4.5_model_gets_basic", + provider: schemas.Anthropic, + model: "claude-opus-4-5-20251101", + expectedToolType: "web_fetch_20250910", + }, + { + name: "Azure/4.6_model_gets_latest", + provider: schemas.Azure, + model: "claude-sonnet-4-6", + expectedToolType: "web_fetch_20260309", + }, + { + name: "Azure/4.5_model_gets_basic", + provider: schemas.Azure, + model: "claude-sonnet-4-5-20250929", + expectedToolType: "web_fetch_20250910", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := schemas.NewBifrostContext(nil, time.Time{}) + bifrostReq := &schemas.BifrostResponsesRequest{ + Provider: tt.provider, + Model: tt.model, + Input: []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("Fetch https://example.com"), + }, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + { + Type: schemas.ResponsesToolTypeWebFetch, + ResponsesToolWebFetch: &schemas.ResponsesToolWebFetch{}, + }, + }, + }, + } + + result, err := anthropic.ToAnthropicResponsesRequest(ctx, bifrostReq) + require.NoError(t, err) + require.NotNil(t, result) + require.NotEmpty(t, result.Tools) + + found := false + for _, tool := range result.Tools { + if tool.Type != nil && tool.Name == "web_fetch" { + assert.Equal(t, tt.expectedToolType, string(*tool.Type)) + found = true + break + } + } + require.True(t, found, "web_fetch tool should be present in converted request") + }) + } +} + +// TestProviderBetaHeaderInjection verifies that the correct beta headers +// are added (or omitted) based on provider +func TestProviderBetaHeaderInjection(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider schemas.ModelProvider + setupReq func() *anthropic.AnthropicMessageRequest + expectHeaders []string + unexpectHeaders []string + }{ + // ── Structured outputs header ── + { + name: "Anthropic/structured_outputs_header_added", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + strict := true + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "test", Strict: &strict}}, + } + }, + expectHeaders: []string{"structured-outputs-2025-11-13"}, + }, + { + name: "Vertex/structured_outputs_header_skipped", + provider: schemas.Vertex, + setupReq: func() *anthropic.AnthropicMessageRequest { + strict := true + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "test", Strict: &strict}}, + } + }, + unexpectHeaders: []string{"structured-outputs-2025-11-13"}, + }, + { + name: "Bedrock/structured_outputs_header_added", + provider: schemas.Bedrock, + setupReq: func() *anthropic.AnthropicMessageRequest { + strict := true + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "test", Strict: &strict}}, + } + }, + expectHeaders: []string{"structured-outputs-2025-11-13"}, + }, + + // ── MCP header ── + { + name: "Anthropic/mcp_header_added", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + MCPServers: []anthropic.AnthropicMCPServer{{URL: "http://example.com"}}, + } + }, + expectHeaders: []string{"mcp-client-2025-04-04"}, + }, + { + name: "Vertex/mcp_header_skipped", + provider: schemas.Vertex, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + MCPServers: []anthropic.AnthropicMCPServer{{URL: "http://example.com"}}, + } + }, + unexpectHeaders: []string{"mcp-client-2025-04-04"}, + }, + { + name: "Bedrock/mcp_header_skipped", + provider: schemas.Bedrock, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + MCPServers: []anthropic.AnthropicMCPServer{{URL: "http://example.com"}}, + } + }, + unexpectHeaders: []string{"mcp-client-2025-04-04"}, + }, + { + name: "Azure/mcp_header_added", + provider: schemas.Azure, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + MCPServers: []anthropic.AnthropicMCPServer{{URL: "http://example.com"}}, + } + }, + expectHeaders: []string{"mcp-client-2025-04-04"}, + }, + + // ── Compaction header (supported on all providers) ── + { + name: "Anthropic/compaction_header_added", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + ContextManagement: &anthropic.ContextManagement{ + Edits: []anthropic.ContextManagementEdit{{Type: anthropic.ContextManagementEditTypeCompact}}, + }, + } + }, + expectHeaders: []string{"compact-2026-01-12"}, + }, + { + name: "Vertex/compaction_header_added", + provider: schemas.Vertex, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + ContextManagement: &anthropic.ContextManagement{ + Edits: []anthropic.ContextManagementEdit{{Type: anthropic.ContextManagementEditTypeCompact}}, + }, + } + }, + expectHeaders: []string{"compact-2026-01-12"}, + }, + { + name: "Bedrock/compaction_header_added", + provider: schemas.Bedrock, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + ContextManagement: &anthropic.ContextManagement{ + Edits: []anthropic.ContextManagementEdit{{Type: anthropic.ContextManagementEditTypeCompact}}, + }, + } + }, + expectHeaders: []string{"compact-2026-01-12"}, + }, + + // ── Context editing header (supported on all providers) ── + { + name: "Anthropic/context_editing_header_added", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + ContextManagement: &anthropic.ContextManagement{ + Edits: []anthropic.ContextManagementEdit{{Type: anthropic.ContextManagementEditTypeClearToolUses}}, + }, + } + }, + expectHeaders: []string{"context-management-2025-06-27"}, + }, + { + name: "Vertex/context_editing_header_added", + provider: schemas.Vertex, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + ContextManagement: &anthropic.ContextManagement{ + Edits: []anthropic.ContextManagementEdit{{Type: anthropic.ContextManagementEditTypeClearToolUses}}, + }, + } + }, + expectHeaders: []string{"context-management-2025-06-27"}, + }, + + // ── Prompt caching scope header ── + { + name: "Anthropic/prompt_caching_scope_added", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + scope := "global" + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{ + {Name: "test", CacheControl: &schemas.CacheControl{Type: "ephemeral", Scope: &scope}}, + }, + } + }, + expectHeaders: []string{"prompt-caching-scope-2026-01-05"}, + }, + { + name: "Vertex/prompt_caching_scope_skipped", + provider: schemas.Vertex, + setupReq: func() *anthropic.AnthropicMessageRequest { + scope := "global" + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{ + {Name: "test", CacheControl: &schemas.CacheControl{Type: "ephemeral", Scope: &scope}}, + }, + } + }, + unexpectHeaders: []string{"prompt-caching-scope-2026-01-05"}, + }, + { + name: "Bedrock/prompt_caching_scope_skipped", + provider: schemas.Bedrock, + setupReq: func() *anthropic.AnthropicMessageRequest { + scope := "global" + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{ + {Name: "test", CacheControl: &schemas.CacheControl{Type: "ephemeral", Scope: &scope}}, + }, + } + }, + unexpectHeaders: []string{"prompt-caching-scope-2026-01-05"}, + }, + + // ── Computer use version-specific beta headers ── + { + name: "Anthropic/computer_20251124_gets_correct_beta", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + toolType := anthropic.AnthropicToolTypeComputer20251124 + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "computer", Type: &toolType}}, + } + }, + expectHeaders: []string{"computer-use-2025-11-24"}, + unexpectHeaders: []string{"computer-use-2025-01-24"}, + }, + { + name: "Anthropic/computer_20250124_gets_correct_beta", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + toolType := anthropic.AnthropicToolTypeComputer20250124 + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "computer", Type: &toolType}}, + } + }, + expectHeaders: []string{"computer-use-2025-01-24"}, + unexpectHeaders: []string{"computer-use-2025-11-24"}, + }, + { + name: "Vertex/computer_20251124_gets_correct_beta", + provider: schemas.Vertex, + setupReq: func() *anthropic.AnthropicMessageRequest { + toolType := anthropic.AnthropicToolTypeComputer20251124 + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "computer", Type: &toolType}}, + } + }, + expectHeaders: []string{"computer-use-2025-11-24"}, + }, + { + name: "Bedrock/computer_20250124_gets_correct_beta", + provider: schemas.Bedrock, + setupReq: func() *anthropic.AnthropicMessageRequest { + toolType := anthropic.AnthropicToolTypeComputer20250124 + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "computer", Type: &toolType}}, + } + }, + expectHeaders: []string{"computer-use-2025-01-24"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := schemas.NewBifrostContext(nil, time.Time{}) + req := tt.setupReq() + + anthropic.AddMissingBetaHeadersToContext(ctx, req, tt.provider) + + var headers []string + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { + headers = extraHeaders["anthropic-beta"] + } + + for _, expected := range tt.expectHeaders { + found := false + for _, h := range headers { + if h == expected { + found = true + break + } + } + assert.True(t, found, "expected beta header %q for provider %s, got headers: %v", expected, tt.provider, headers) + } + + for _, unexpected := range tt.unexpectHeaders { + for _, h := range headers { + assert.NotEqual(t, unexpected, h, "unexpected beta header %q should NOT be present for provider %s", unexpected, tt.provider) + } + } + }) + } +} + +// TestProviderAnthropicRequestPipeline exercises the full Vertex/Bedrock/Anthropic/Azure +// request preparation pipeline end-to-end: validate tools → convert request → inject beta headers. +// This catches regressions in the provider-specific paths (e.g., missing tool version remapping +// or unsupported beta headers leaking through). +func TestProviderAnthropicRequestPipeline(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider schemas.ModelProvider + model string + tools []schemas.ResponsesTool + expectConversionErr bool + errSubstr string + expectedWebSearchType string // expected web_search tool type after conversion + expectedBetaHeaders []string + unexpectedBetaHeaders []string + }{ + // ── Vertex: web_search with filters → basic version, no dynamic headers ── + { + name: "Vertex/web_search_4.6_gets_basic_version_no_dynamic_headers", + provider: schemas.Vertex, + model: "claude-opus-4-6", + tools: []schemas.ResponsesTool{ + { + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + UserLocation: &schemas.ResponsesToolWebSearchUserLocation{ + Type: schemas.Ptr("approximate"), + Country: schemas.Ptr("US"), + }, + }, + }, + }, + expectedWebSearchType: "web_search_20250305", // Vertex does NOT get dynamic filtering + expectedBetaHeaders: nil, // no beta headers for basic web search + unexpectedBetaHeaders: []string{"structured-outputs-2025-11-13", "mcp-client-2025-04-04", "prompt-caching-scope-2026-01-05"}, + }, + { + name: "Vertex/web_search_with_compaction_gets_compaction_header", + provider: schemas.Vertex, + model: "claude-sonnet-4-6", + tools: []schemas.ResponsesTool{ + {Type: schemas.ResponsesToolTypeWebSearch, ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}}, + }, + expectedWebSearchType: "web_search_20250305", + }, + // ── Vertex: web_fetch rejected ── + { + name: "Vertex/web_fetch_rejected_in_pipeline", + provider: schemas.Vertex, + tools: []schemas.ResponsesTool{ + {Type: schemas.ResponsesToolTypeWebFetch, ResponsesToolWebFetch: &schemas.ResponsesToolWebFetch{}}, + }, + expectConversionErr: true, + errSubstr: "web_fetch", + }, + // ── Anthropic: web_search 4.6 gets dynamic filtering ── + { + name: "Anthropic/web_search_4.6_gets_dynamic_version", + provider: schemas.Anthropic, + model: "claude-opus-4-6", + tools: []schemas.ResponsesTool{ + { + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{ + Filters: &schemas.ResponsesToolWebSearchFilters{ + AllowedDomains: []string{"example.com"}, + }, + }, + }, + }, + expectedWebSearchType: "web_search_20260209", + }, + { + name: "Anthropic/web_search_4.5_gets_basic_version", + provider: schemas.Anthropic, + model: "claude-opus-4-5-20251101", + tools: []schemas.ResponsesTool{ + {Type: schemas.ResponsesToolTypeWebSearch, ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}}, + }, + expectedWebSearchType: "web_search_20250305", + }, + // ── Azure: web_search 4.6 gets dynamic filtering (same as Anthropic) ── + { + name: "Azure/web_search_4.6_gets_dynamic_version", + provider: schemas.Azure, + model: "claude-sonnet-4-6", + tools: []schemas.ResponsesTool{ + {Type: schemas.ResponsesToolTypeWebSearch, ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}}, + }, + expectedWebSearchType: "web_search_20260209", + }, + // ── Bedrock: web_search rejected ── + { + name: "Bedrock/web_search_rejected_in_pipeline", + provider: schemas.Bedrock, + tools: []schemas.ResponsesTool{ + {Type: schemas.ResponsesToolTypeWebSearch, ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}}, + }, + expectConversionErr: true, + errSubstr: "web_search", + }, + // ── Bedrock: computer_use with structured outputs → correct headers ── + { + name: "Bedrock/computer_use_with_structured_outputs_headers", + provider: schemas.Bedrock, + model: "claude-sonnet-4-6", + tools: []schemas.ResponsesTool{ + { + Type: schemas.ResponsesToolTypeComputerUsePreview, + ResponsesToolComputerUsePreview: &schemas.ResponsesToolComputerUsePreview{ + DisplayWidth: 1024, DisplayHeight: 768, + }, + }, + }, + expectedBetaHeaders: []string{"computer-use-2025-11-24"}, + unexpectedBetaHeaders: []string{"mcp-client-2025-04-04", "prompt-caching-scope-2026-01-05"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := schemas.NewBifrostContext(nil, time.Time{}) + model := tt.model + if model == "" { + model = "claude-sonnet-4-5" + } + + // Step 1: Validate tools for provider + if valErr := anthropic.ValidateToolsForProvider(tt.tools, tt.provider); valErr != nil { + if tt.expectConversionErr { + assert.Contains(t, valErr.Error(), tt.errSubstr) + return + } + t.Fatalf("unexpected validation error: %v", valErr) + } + + // Step 2: Convert bifrost request → anthropic request + bifrostReq := &schemas.BifrostResponsesRequest{ + Provider: tt.provider, + Model: model, + Input: []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("Test query"), + }, + Params: &schemas.ResponsesParameters{ + Tools: tt.tools, + }, + } + + result, err := anthropic.ToAnthropicResponsesRequest(ctx, bifrostReq) + if tt.expectConversionErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errSubstr) + return + } + require.NoError(t, err) + require.NotNil(t, result) + + // Step 3: Verify web_search tool type if expected + if tt.expectedWebSearchType != "" { + found := false + for _, tool := range result.Tools { + if tool.Name == "web_search" { + require.NotNil(t, tool.Type) + assert.Equal(t, tt.expectedWebSearchType, string(*tool.Type), + "wrong web_search type for provider=%s model=%s", tt.provider, model) + found = true + break + } + } + require.True(t, found, "web_search tool should be present in converted request") + } + + // Step 4: Run beta header injection + anthropic.AddMissingBetaHeadersToContext(ctx, result, tt.provider) + + // Step 5: Verify beta headers + var headers []string + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { + headers = extraHeaders["anthropic-beta"] + } + + for _, expected := range tt.expectedBetaHeaders { + found := false + for _, h := range headers { + if h == expected { + found = true + break + } + } + assert.True(t, found, "expected beta header %q not found in %v for provider=%s", expected, headers, tt.provider) + } + + for _, unexpected := range tt.unexpectedBetaHeaders { + for _, h := range headers { + assert.NotEqual(t, unexpected, h, "unexpected beta header %q should NOT be present for provider=%s", unexpected, tt.provider) + } + } + }) + } +} + +// TestProviderFeatureMapCompleteness ensures every provider in the map has consistent settings +func TestProviderFeatureMapCompleteness(t *testing.T) { + t.Parallel() + + // Verify all four major providers are in the map + for _, provider := range []schemas.ModelProvider{schemas.Anthropic, schemas.Vertex, schemas.Bedrock, schemas.Azure} { + features, ok := anthropic.ProviderFeatures[provider] + assert.True(t, ok, "provider %s should be in ProviderFeatures map", provider) + + // Anthropic and Azure should support everything + if provider == schemas.Anthropic || provider == schemas.Azure { + assert.True(t, features.WebSearch, "%s should support WebSearch", provider) + assert.True(t, features.WebSearchDynamic, "%s should support WebSearchDynamic", provider) + assert.True(t, features.WebFetch, "%s should support WebFetch", provider) + assert.True(t, features.CodeExecution, "%s should support CodeExecution", provider) + assert.True(t, features.MCP, "%s should support MCP", provider) + assert.True(t, features.StructuredOutputs, "%s should support StructuredOutputs", provider) + assert.True(t, features.FilesAPI, "%s should support FilesAPI", provider) + } + + // Vertex specifics + if provider == schemas.Vertex { + assert.True(t, features.WebSearch, "Vertex should support basic WebSearch") + assert.False(t, features.WebSearchDynamic, "Vertex should NOT support WebSearchDynamic") + assert.False(t, features.WebFetch, "Vertex should NOT support WebFetch") + assert.False(t, features.CodeExecution, "Vertex should NOT support CodeExecution") + assert.False(t, features.MCP, "Vertex should NOT support MCP") + assert.False(t, features.StructuredOutputs, "Vertex should NOT support StructuredOutputs") + assert.True(t, features.Compaction, "Vertex should support Compaction") + assert.True(t, features.ContextEditing, "Vertex should support ContextEditing") + } + + // Bedrock specifics + if provider == schemas.Bedrock { + assert.False(t, features.WebSearch, "Bedrock should NOT support WebSearch") + assert.False(t, features.WebFetch, "Bedrock should NOT support WebFetch") + assert.False(t, features.CodeExecution, "Bedrock should NOT support CodeExecution") + assert.False(t, features.MCP, "Bedrock should NOT support MCP") + assert.True(t, features.StructuredOutputs, "Bedrock should support StructuredOutputs") + assert.True(t, features.Compaction, "Bedrock should support Compaction") + assert.True(t, features.ComputerUse, "Bedrock should support ComputerUse") + } + + // All providers should support client-side tools + assert.True(t, features.ComputerUse, "%s should support ComputerUse", provider) + assert.True(t, features.Bash, "%s should support Bash", provider) + assert.True(t, features.Memory, "%s should support Memory", provider) + assert.True(t, features.TextEditor, "%s should support TextEditor", provider) + assert.True(t, features.ToolSearch, "%s should support ToolSearch", provider) + } +} + +// TestComputerUseVersionAndBetaHeaderEndToEnd verifies the full pipeline: +// bifrost tool → anthropic tool version → correct beta header +func TestComputerUseVersionAndBetaHeaderEndToEnd(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider schemas.ModelProvider + model string + expectedToolType string + expectedBetaHeader string + }{ + { + name: "Anthropic/opus_4.6_gets_20251124_and_matching_beta", + provider: schemas.Anthropic, + model: "claude-opus-4-6", + expectedToolType: "computer_20251124", + expectedBetaHeader: "computer-use-2025-11-24", + }, + { + name: "Anthropic/sonnet_4.6_gets_20251124_and_matching_beta", + provider: schemas.Anthropic, + model: "claude-sonnet-4-6", + expectedToolType: "computer_20251124", + expectedBetaHeader: "computer-use-2025-11-24", + }, + { + name: "Anthropic/opus_4.5_gets_20251124_and_matching_beta", + provider: schemas.Anthropic, + model: "claude-opus-4-5-20251101", + expectedToolType: "computer_20251124", + expectedBetaHeader: "computer-use-2025-11-24", + }, + { + name: "Anthropic/sonnet_4.5_gets_20250124_and_matching_beta", + provider: schemas.Anthropic, + model: "claude-sonnet-4-5-20250929", + expectedToolType: "computer_20250124", + expectedBetaHeader: "computer-use-2025-01-24", + }, + { + name: "Anthropic/sonnet_4_gets_20250124_and_matching_beta", + provider: schemas.Anthropic, + model: "claude-sonnet-4-20250514", + expectedToolType: "computer_20250124", + expectedBetaHeader: "computer-use-2025-01-24", + }, + { + name: "Vertex/opus_4.6_gets_20251124_and_matching_beta", + provider: schemas.Vertex, + model: "claude-opus-4-6", + expectedToolType: "computer_20251124", + expectedBetaHeader: "computer-use-2025-11-24", + }, + { + name: "Bedrock/sonnet_4_gets_20250124_and_matching_beta", + provider: schemas.Bedrock, + model: "claude-sonnet-4-20250514", + expectedToolType: "computer_20250124", + expectedBetaHeader: "computer-use-2025-01-24", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := schemas.NewBifrostContext(nil, time.Time{}) + + // Step 1: Convert bifrost tool → anthropic tool (selects version based on model) + bifrostReq := &schemas.BifrostResponsesRequest{ + Provider: tt.provider, + Model: tt.model, + Input: []schemas.ResponsesMessage{ + CreateBasicResponsesMessage("Take a screenshot"), + }, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + { + Type: schemas.ResponsesToolTypeComputerUsePreview, + ResponsesToolComputerUsePreview: &schemas.ResponsesToolComputerUsePreview{ + DisplayWidth: 1024, + DisplayHeight: 768, + }, + }, + }, + }, + } + + result, err := anthropic.ToAnthropicResponsesRequest(ctx, bifrostReq) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify correct tool version was selected + var computerTool *anthropic.AnthropicTool + for i, tool := range result.Tools { + if tool.Name == "computer" { + computerTool = &result.Tools[i] + break + } + } + require.NotNil(t, computerTool, "computer tool should be present") + require.NotNil(t, computerTool.Type) + assert.Equal(t, tt.expectedToolType, string(*computerTool.Type), + "wrong tool version for model=%s provider=%s", tt.model, tt.provider) + + // Step 2: Run beta header injection on the converted request + anthropic.AddMissingBetaHeadersToContext(ctx, result, tt.provider) + + // Step 3: Verify correct beta header was added + var headers []string + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { + headers = extraHeaders["anthropic-beta"] + } + + found := false + for _, h := range headers { + if h == tt.expectedBetaHeader { + found = true + break + } + } + assert.True(t, found, "expected beta header %q not found in %v for model=%s provider=%s", + tt.expectedBetaHeader, headers, tt.model, tt.provider) + }) + } +} + +// TestRawBodyToolVersionRemapping verifies that when a raw request body contains +// a tool version unsupported by the target provider, it gets remapped automatically +func TestRawBodyToolVersionRemapping(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider schemas.ModelProvider + inputJSON string + expectedToolType string + expectErr bool + errSubstr string + }{ + // ── Vertex: web_search_20260209 → web_search_20250305 ── + { + name: "Vertex/web_search_20260209_remapped_to_20250305", + provider: schemas.Vertex, + inputJSON: `{ + "model": "claude-opus-4-6", + "max_tokens": 1024, + "tools": [{"type": "web_search_20260209", "name": "web_search"}], + "messages": [{"role": "user", "content": "hello"}] + }`, + expectedToolType: "web_search_20250305", + }, + // ── Vertex: web_search_20250305 unchanged ── + { + name: "Vertex/web_search_20250305_unchanged", + provider: schemas.Vertex, + inputJSON: `{ + "model": "claude-opus-4-6", + "max_tokens": 1024, + "tools": [{"type": "web_search_20250305", "name": "web_search"}], + "messages": [{"role": "user", "content": "hello"}] + }`, + expectedToolType: "web_search_20250305", + }, + // ── Vertex: web_fetch rejected (no remap possible) ── + { + name: "Vertex/web_fetch_rejected_in_raw_body", + provider: schemas.Vertex, + inputJSON: `{ + "model": "claude-opus-4-6", + "tools": [{"type": "web_fetch_20250910", "name": "web_fetch"}], + "messages": [{"role": "user", "content": "hello"}] + }`, + expectErr: true, + errSubstr: "web_fetch_20250910", + }, + // ── Vertex: code_execution rejected ── + { + name: "Vertex/code_execution_rejected_in_raw_body", + provider: schemas.Vertex, + inputJSON: `{ + "model": "claude-opus-4-6", + "tools": [{"type": "code_execution_20250825", "name": "code_execution"}], + "messages": [{"role": "user", "content": "hello"}] + }`, + expectErr: true, + errSubstr: "code_execution", + }, + // ── Bedrock: web_search rejected ── + { + name: "Bedrock/web_search_rejected_in_raw_body", + provider: schemas.Bedrock, + inputJSON: `{ + "model": "claude-opus-4-6", + "tools": [{"type": "web_search_20250305", "name": "web_search"}], + "messages": [{"role": "user", "content": "hello"}] + }`, + expectErr: true, + errSubstr: "web_search_20250305", + }, + // ── Anthropic: no remapping needed ── + { + name: "Anthropic/web_search_20260209_unchanged", + provider: schemas.Anthropic, + inputJSON: `{ + "model": "claude-opus-4-6", + "tools": [{"type": "web_search_20260209", "name": "web_search"}], + "messages": [{"role": "user", "content": "hello"}] + }`, + expectedToolType: "web_search_20260209", + }, + // ── No tools in body: no error ── + { + name: "Vertex/no_tools_no_error", + provider: schemas.Vertex, + inputJSON: `{ + "model": "claude-opus-4-6", + "messages": [{"role": "user", "content": "hello"}] + }`, + }, + // ── Vertex: bash tool unchanged (supported) ── + { + name: "Vertex/bash_tool_unchanged", + provider: schemas.Vertex, + inputJSON: `{ + "model": "claude-opus-4-6", + "tools": [{"type": "bash_20250124", "name": "bash"}], + "messages": [{"role": "user", "content": "hello"}] + }`, + expectedToolType: "bash_20250124", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result, err := anthropic.RemapRawToolVersionsForProvider([]byte(tt.inputJSON), tt.provider) + + if tt.expectErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errSubstr) + return + } + + require.NoError(t, err) + + if tt.expectedToolType != "" { + // Extract the tool type from the result JSON + toolType := providerUtils.GetJSONField(result, "tools.0.type").String() + assert.Equal(t, tt.expectedToolType, toolType, + "expected tool type %s but got %s for provider %s", + tt.expectedToolType, toolType, tt.provider) + } + }) + } +} diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index 23b1b21731..25943def32 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -436,7 +436,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k if convErr != nil { return nil, convErr } - addMissingBetaHeadersToContext(ctx, anthropicReq) + AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic) return anthropicReq, nil }, provider.GetProviderKey()) @@ -520,7 +520,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont return nil, convErr } anthropicReq.Stream = schemas.Ptr(true) - addMissingBetaHeadersToContext(ctx, anthropicReq) + AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic) return anthropicReq, nil }, provider.GetProviderKey()) diff --git a/core/providers/anthropic/chat.go b/core/providers/anthropic/chat.go index 153e187ee6..6d3c3c61fb 100644 --- a/core/providers/anthropic/chat.go +++ b/core/providers/anthropic/chat.go @@ -53,9 +53,9 @@ func ToAnthropicChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.Bif if cm, ok := cmVal.(*ContextManagement); ok && cm != nil { delete(anthropicReq.ExtraParams, "context_management") anthropicReq.ContextManagement = cm - } else if data, err := json.Marshal(cmVal); err == nil { + } else if data, err := providerUtils.MarshalSorted(cmVal); err == nil { var cm ContextManagement - if json.Unmarshal(data, &cm) == nil { + if sonic.Unmarshal(data, &cm) == nil { delete(anthropicReq.ExtraParams, "context_management") anthropicReq.ContextManagement = &cm } @@ -457,7 +457,7 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse(ctx *schemas.Bif // This is a structured output tool - convert to text content var jsonStr string if c.Input != nil { - if argBytes, err := sonic.Marshal(c.Input); err == nil { + if argBytes, err := providerUtils.MarshalSorted(c.Input); err == nil { jsonStr = string(argBytes) } else { jsonStr = fmt.Sprintf("%v", c.Input) @@ -475,7 +475,7 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse(ctx *schemas.Bif // Marshal the input to JSON string if c.Input != nil { - args, err := json.Marshal(c.Input) + args, err := providerUtils.MarshalSorted(c.Input) if err != nil { function.Arguments = fmt.Sprintf("%v", c.Input) } else { @@ -667,21 +667,24 @@ func ToAnthropicChatResponse(bifrostResp *schemas.BifrostChatResponse) *Anthropi // Add tool calls as tool_use content if choice.ChatNonStreamResponseChoice != nil && choice.Message != nil && choice.Message.ChatAssistantMessage != nil && choice.Message.ChatAssistantMessage.ToolCalls != nil { for _, toolCall := range choice.Message.ChatAssistantMessage.ToolCalls { - // Parse arguments JSON string back to map - var input map[string]interface{} + // Parse arguments JSON string to raw message + var inputRaw json.RawMessage if toolCall.Function.Arguments != "" { - if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { - input = map[string]interface{}{} + // Validate it's valid JSON, otherwise use empty object + if json.Valid([]byte(toolCall.Function.Arguments)) { + inputRaw = json.RawMessage(toolCall.Function.Arguments) + } else { + inputRaw = json.RawMessage("{}") } } else { - input = map[string]interface{}{} + inputRaw = json.RawMessage("{}") } content = append(content, AnthropicContentBlock{ Type: AnthropicContentBlockTypeToolUse, ID: toolCall.ID, Name: toolCall.Function.Name, - Input: input, + Input: inputRaw, }) } } @@ -1057,7 +1060,7 @@ func ToAnthropicChatStreamResponse(bifrostResp *schemas.BifrostChatResponse) str } // Marshal to JSON and format as SSE - jsonData, err := json.Marshal(streamResp) + jsonData, err := providerUtils.MarshalSorted(streamResp) if err != nil { return "" } @@ -1073,7 +1076,7 @@ func ToAnthropicChatStreamError(bifrostErr *schemas.BifrostError) string { return "" } // Marshal to JSON - jsonData, err := json.Marshal(errorResp) + jsonData, err := providerUtils.MarshalSorted(errorResp) if err != nil { return "" } diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index 48aa0fd1ce..9bcfb32aa5 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -9,7 +9,9 @@ import ( "sync" "time" + "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" + "github.com/tidwall/gjson" providerUtils "github.com/maximhq/bifrost/core/providers/utils" ) @@ -909,7 +911,7 @@ func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, var action *schemas.ResponsesComputerToolCallAction if state.AccumulatedJSON != "" { - if err := json.Unmarshal([]byte(state.AccumulatedJSON), &inputMap); err == nil { + if err := sonic.Unmarshal([]byte(state.AccumulatedJSON), &inputMap); err == nil { action = convertAnthropicToResponsesComputerAction(inputMap) } } @@ -968,12 +970,9 @@ func (chunk *AnthropicStreamEvent) ToBifrostResponsesStream(ctx context.Context, var query string var queries []string if state.AccumulatedJSON != "" { - var inputMap map[string]interface{} - if err := json.Unmarshal([]byte(state.AccumulatedJSON), &inputMap); err == nil { - if q, ok := inputMap["query"].(string); ok { - query = q - queries = []string{q} - } + if q := providerUtils.GetJSONField([]byte(state.AccumulatedJSON), "query"); q.Exists() && q.Type == gjson.String { + query = q.Str + queries = []string{q.Str} } } @@ -1428,7 +1427,7 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp } // Always start with empty input for streaming compatibility - contentBlock.Input = map[string]interface{}{} + contentBlock.Input = json.RawMessage("{}") streamResp.ContentBlock = contentBlock } else if bifrostResp.Item != nil && @@ -1452,7 +1451,7 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp } // Start with empty input for streaming compatibility - contentBlock.Input = map[string]interface{}{} + contentBlock.Input = json.RawMessage("{}") streamResp.ContentBlock = contentBlock } else { @@ -1529,7 +1528,7 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp contentBlock.ID = bifrostResp.Item.ResponsesToolMessage.CallID contentBlock.Name = bifrostResp.Item.ResponsesToolMessage.Name // Always start with empty input for streaming compatibility - contentBlock.Input = map[string]interface{}{} + contentBlock.Input = json.RawMessage("{}") // Track WebSearch tools so we can skip their argument deltas if bifrostResp.Item.ResponsesToolMessage.Name != nil && @@ -1549,7 +1548,7 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp contentBlock.ServerName = &bifrostResp.Item.ResponsesToolMessage.ResponsesMCPToolCall.ServerLabel } // Always start with empty input for streaming compatibility - contentBlock.Input = map[string]interface{}{} + contentBlock.Input = json.RawMessage("{}") } } } @@ -1582,7 +1581,7 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp case schemas.ResponsesMessageTypeComputerCall: if bifrostResp.Item.ResponsesToolMessage.Action != nil && bifrostResp.Item.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil { actionInput := convertResponsesToAnthropicComputerAction(bifrostResp.Item.ResponsesToolMessage.Action.ResponsesComputerToolCallAction) - if jsonBytes, err := json.Marshal(actionInput); err == nil { + if jsonBytes, err := providerUtils.MarshalSorted(actionInput); err == nil { argumentsJSON = string(jsonBytes) shouldGenerateDeltas = true } @@ -1772,7 +1771,7 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp ) // Marshal the action to JSON string - if jsonBytes, err := json.Marshal(actionInput); err == nil { + if jsonBytes, err := providerUtils.MarshalSorted(actionInput); err == nil { jsonStr := string(jsonBytes) streamResp.Delta = &AnthropicStreamDelta{ Type: AnthropicStreamDeltaTypeInputJSON, @@ -1798,7 +1797,7 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp inputMap := map[string]interface{}{ "query": *bifrostResp.Item.ResponsesToolMessage.Action.ResponsesWebSearchToolCallAction.Query, } - if jsonBytes, err := json.Marshal(inputMap); err == nil { + if jsonBytes, err := providerUtils.MarshalSorted(inputMap); err == nil { queryJSON = string(jsonBytes) } } @@ -2330,9 +2329,9 @@ func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schema anthropicReq.CacheControl = &v parsed = true default: - if data, err := json.Marshal(v); err == nil { + if data, err := providerUtils.MarshalSorted(v); err == nil { var cc schemas.CacheControl - if json.Unmarshal(data, &cc) == nil { + if sonic.Unmarshal(data, &cc) == nil { anthropicReq.CacheControl = &cc parsed = true } @@ -2359,9 +2358,9 @@ func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schema if cm, ok := cmVal.(*ContextManagement); ok && cm != nil { delete(anthropicReq.ExtraParams, "context_management") anthropicReq.ContextManagement = cm - } else if data, err := json.Marshal(cmVal); err == nil { + } else if data, err := providerUtils.MarshalSorted(cmVal); err == nil { var cm ContextManagement - if json.Unmarshal(data, &cm) == nil { + if sonic.Unmarshal(data, &cm) == nil { delete(anthropicReq.ExtraParams, "context_management") anthropicReq.ContextManagement = &cm } @@ -2382,7 +2381,7 @@ func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schema } continue // Skip converting MCP tools to anthropicTools since they're handled separately } - anthropicTool := convertBifrostToolToAnthropic(bifrostReq.Model, &tool) + anthropicTool := convertBifrostToolToAnthropic(bifrostReq.Model, &tool, bifrostReq.Provider) if anthropicTool != nil { anthropicTools = append(anthropicTools, *anthropicTool) } @@ -3077,8 +3076,11 @@ func ConvertBifrostMessagesToAnthropicMessages(ctx *schemas.BifrostContext, bifr } if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Action != nil && msg.ResponsesToolMessage.Action.ResponsesWebFetchToolCallAction != nil { - serverToolUseBlock.Input = map[string]interface{}{ + inputBytes, err := providerUtils.MarshalSorted(map[string]interface{}{ "url": msg.ResponsesToolMessage.Action.ResponsesWebFetchToolCallAction.URL, + }) + if err == nil { + serverToolUseBlock.Input = json.RawMessage(inputBytes) } } pendingToolCalls = append(pendingToolCalls, serverToolUseBlock) @@ -3462,7 +3464,8 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant if toolBlock.Name != nil && *toolBlock.Name == string(AnthropicToolNameComputer) { bifrostMsg.Type = schemas.Ptr(schemas.ResponsesMessageTypeComputerCall) bifrostMsg.ResponsesToolMessage.Name = nil - if inputMap, ok := toolBlock.Input.(map[string]interface{}); ok { + var inputMap map[string]interface{} + if err := sonic.Unmarshal(toolBlock.Input, &inputMap); err == nil { bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ ResponsesComputerToolCallAction: convertAnthropicToResponsesComputerAction(inputMap), } @@ -3470,31 +3473,30 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant } else if toolBlock.Name != nil && *toolBlock.Name == string(AnthropicToolNameWebSearch) { bifrostMsg.Type = schemas.Ptr(schemas.ResponsesMessageTypeWebSearchCall) bifrostMsg.ResponsesToolMessage.Name = nil - if inputMap, ok := toolBlock.Input.(map[string]interface{}); ok { - if query, ok := inputMap["query"].(string); ok { - bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ - ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ - Type: "search", - Query: schemas.Ptr(query), - Queries: []string{query}, - }, - } + if q := providerUtils.GetJSONField(toolBlock.Input, "query"); q.Exists() && q.Type == gjson.String { + query := q.Str + bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + Query: schemas.Ptr(query), + Queries: []string{query}, + }, } } } else if toolBlock.Name != nil && *toolBlock.Name == string(AnthropicToolNameWebFetch) { bifrostMsg.Type = schemas.Ptr(schemas.ResponsesMessageTypeWebFetchCall) bifrostMsg.ResponsesToolMessage.Name = nil - if inputMap, ok := toolBlock.Input.(map[string]interface{}); ok { - if url, ok := inputMap["url"].(string); ok { - bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ - ResponsesWebFetchToolCallAction: &schemas.ResponsesWebFetchToolCallAction{ - URL: url, - }, - } + if u := providerUtils.GetJSONField(toolBlock.Input, "url"); u.Exists() && u.Type == gjson.String { + bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesWebFetchToolCallAction: &schemas.ResponsesWebFetchToolCallAction{ + URL: u.Str, + }, } } } else { - bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(schemas.JsonifyInput(toolBlock.Input)) + if len(toolBlock.Input) > 0 { + bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(string(toolBlock.Input)) + } } bifrostMessages = append(bifrostMessages, bifrostMsg) @@ -3651,7 +3653,7 @@ func convertAnthropicContentBlocksToResponsesMessages(ctx *schemas.BifrostContex // This is a structured output tool - convert to text message var jsonStr string if block.Input != nil { - jsonStr = schemas.JsonifyInput(block.Input) + jsonStr = string(block.Input) } else { jsonStr = "{}" } @@ -3697,13 +3699,14 @@ func convertAnthropicContentBlocksToResponsesMessages(ctx *schemas.BifrostContex if block.Name != nil && *block.Name == string(AnthropicToolNameComputer) { bifrostMsg.Type = schemas.Ptr(schemas.ResponsesMessageTypeComputerCall) bifrostMsg.ResponsesToolMessage.Name = nil - if inputMap, ok := block.Input.(map[string]interface{}); ok { + var inputMap map[string]interface{} + if err := sonic.Unmarshal(block.Input, &inputMap); err == nil { bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ ResponsesComputerToolCallAction: convertAnthropicToResponsesComputerAction(inputMap), } } - } else { - bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(schemas.JsonifyInput(block.Input)) + } else if len(block.Input) > 0 { + bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(string(block.Input)) } bifrostMessages = append(bifrostMessages, bifrostMsg) } @@ -3772,15 +3775,14 @@ func convertAnthropicContentBlocksToResponsesMessages(ctx *schemas.BifrostContex // Extract query from input if block.Input != nil { - if inputMap, ok := block.Input.(map[string]interface{}); ok { - if query, ok := inputMap["query"].(string); ok { - bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ - ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ - Type: "search", - Query: schemas.Ptr(query), - Queries: []string{query}, // Anthropic uses single query - }, - } + if q := providerUtils.GetJSONField(block.Input, "query"); q.Exists() && q.Type == gjson.String { + query := q.Str + bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesWebSearchToolCallAction: &schemas.ResponsesWebSearchToolCallAction{ + Type: "search", + Query: schemas.Ptr(query), + Queries: []string{query}, // Anthropic uses single query + }, } } } @@ -3797,13 +3799,11 @@ func convertAnthropicContentBlocksToResponsesMessages(ctx *schemas.BifrostContex } if block.Input != nil { - if inputMap, ok := block.Input.(map[string]interface{}); ok { - if url, ok := inputMap["url"].(string); ok { - bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ - ResponsesWebFetchToolCallAction: &schemas.ResponsesWebFetchToolCallAction{ - URL: url, - }, - } + if u := providerUtils.GetJSONField(block.Input, "url"); u.Exists() && u.Type == gjson.String { + bifrostMsg.ResponsesToolMessage.Action = &schemas.ResponsesToolMessageActionStruct{ + ResponsesWebFetchToolCallAction: &schemas.ResponsesWebFetchToolCallAction{ + URL: u.Str, + }, } } } @@ -3843,10 +3843,12 @@ func convertAnthropicContentBlocksToResponsesMessages(ctx *schemas.BifrostContex Type: schemas.Ptr(schemas.ResponsesMessageTypeMCPCall), ID: block.ID, ResponsesToolMessage: &schemas.ResponsesToolMessage{ - Name: block.Name, - Arguments: schemas.Ptr(schemas.JsonifyInput(block.Input)), + Name: block.Name, }, } + if len(block.Input) > 0 { + bifrostMsg.ResponsesToolMessage.Arguments = schemas.Ptr(string(block.Input)) + } if block.ServerName != nil { bifrostMsg.ResponsesToolMessage.ResponsesMCPToolCall = &schemas.ResponsesMCPToolCall{ ServerLabel: *block.ServerName, @@ -4211,7 +4213,10 @@ func convertBifrostComputerCallToAnthropicToolUse(msg *schemas.ResponsesMessage) } if msg.ResponsesToolMessage.Action != nil && msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil { - toolUseBlock.Input = convertResponsesToAnthropicComputerAction(msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction) + inputMap := convertResponsesToAnthropicComputerAction(msg.ResponsesToolMessage.Action.ResponsesComputerToolCallAction) + if inputBytes, err := providerUtils.MarshalSorted(inputMap); err == nil { + toolUseBlock.Input = json.RawMessage(inputBytes) + } } return &toolUseBlock @@ -4313,10 +4318,12 @@ func convertBifrostWebSearchCallToAnthropicBlocks(msg *schemas.ResponsesMessage) // Extract the query from the action if action.Query != nil { - input := map[string]interface{}{ + inputBytes, err := providerUtils.MarshalSorted(map[string]interface{}{ "query": *action.Query, + }) + if err == nil { + serverToolUseBlock.Input = json.RawMessage(inputBytes) } - serverToolUseBlock.Input = input } blocks = append(blocks, serverToolUseBlock) @@ -4661,7 +4668,7 @@ func convertToolOutputToAnthropicContent(output *schemas.ResponsesToolMessageOut } // Helper function to convert Tool back to AnthropicTool -func convertBifrostToolToAnthropic(model string, tool *schemas.ResponsesTool) *AnthropicTool { +func convertBifrostToolToAnthropic(model string, tool *schemas.ResponsesTool, provider schemas.ModelProvider) *AnthropicTool { if tool == nil { return nil } @@ -4687,7 +4694,10 @@ func convertBifrostToolToAnthropic(model string, tool *schemas.ResponsesTool) *A } case schemas.ResponsesToolTypeWebSearch: webSearchType := AnthropicToolTypeWebSearch20250305 - if strings.Contains(model, "4.6") || strings.Contains(model, "4-6") { + // Dynamic filtering (web_search_20260209) only available on Anthropic + Azure + features, ok := ProviderFeatures[provider] + if ok && features.WebSearchDynamic && + (strings.Contains(model, "4.6") || strings.Contains(model, "4-6")) { webSearchType = AnthropicToolTypeWebSearch20260209 } anthropicTool := &AnthropicTool{ @@ -4716,7 +4726,10 @@ func convertBifrostToolToAnthropic(model string, tool *schemas.ResponsesTool) *A return anthropicTool case schemas.ResponsesToolTypeWebFetch: webFetchType := AnthropicToolTypeWebFetch20250910 - if strings.Contains(model, "4.6") || strings.Contains(model, "4-6") { + // Dynamic filtering versions only available on Anthropic + Azure + features, ok := ProviderFeatures[provider] + if ok && features.WebSearchDynamic && + (strings.Contains(model, "4.6") || strings.Contains(model, "4-6")) { webFetchType = AnthropicToolTypeWebFetch20260309 } anthropicTool := &AnthropicTool{ diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index f335d3e712..35c15ba61c 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -7,6 +7,7 @@ import ( "time" "github.com/bytedance/sonic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -31,6 +32,12 @@ const ( // AnthropicContextManagementBetaHeader is required for context management. AnthropicContextManagementBetaHeader = "context-management-2025-06-27" + // AnthropicComputerUseBetaHeader is required for computer use (version-specific). + // computer_20251124 (Opus 4.6, Sonnet 4.6, Opus 4.5) uses the newer beta header. + AnthropicComputerUseBetaHeader20251124 = "computer-use-2025-11-24" + // computer_20250124 (all other supported models) uses the older beta header. + AnthropicComputerUseBetaHeader20250124 = "computer-use-2025-01-24" + // Prefixes for Vertex-unsupported beta headers (version-bump proof). // Use these with strings.HasPrefix when filtering headers for Vertex AI, // so that future date bumps (e.g. structured-outputs-2025-12-15) are still matched. @@ -40,6 +47,54 @@ const ( AnthropicMCPClientBetaHeaderPrefix = "mcp-client-" ) +// ProviderFeatureSupport defines which Anthropic features a given provider supports. +// Source: https://docs.anthropic.com/en/build-with-claude/overview (March 2026) +type ProviderFeatureSupport struct { + WebSearch bool // web_search server tool + WebSearchDynamic bool // web_search_20260209 (dynamic filtering, requires code_execution) + WebFetch bool // web_fetch server tool + CodeExecution bool // code_execution server tool + ComputerUse bool // computer_use client tool + Bash bool // bash client tool + Memory bool // memory client tool + TextEditor bool // text_editor client tool + ToolSearch bool // tool_search server tool + MCP bool // MCP connector + AdvancedToolUse bool // advanced-tool-use (defer_loading, input_examples, allowed_callers) + StructuredOutputs bool // strict tool validation and output_format + PromptCachingScope bool // prompt caching scope + Compaction bool // server-side context compaction + ContextEditing bool // context editing (clear_tool_uses, clear_thinking) + FilesAPI bool // Files API + FileSearch bool // file_search server tool (OpenAI-only) + ImageGeneration bool // image_generation server tool (OpenAI-only) +} + +// ProviderFeatures maps each provider to its supported Anthropic features. +var ProviderFeatures = map[schemas.ModelProvider]ProviderFeatureSupport{ + schemas.Anthropic: { + WebSearch: true, WebSearchDynamic: true, WebFetch: true, CodeExecution: true, + ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, + MCP: true, AdvancedToolUse: true, StructuredOutputs: true, PromptCachingScope: true, + Compaction: true, ContextEditing: true, FilesAPI: true, + }, + schemas.Vertex: { + WebSearch: true, // only web_search_20250305 (basic), NOT dynamic filtering + ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, + Compaction: true, ContextEditing: true, + }, + schemas.Bedrock: { + ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, + StructuredOutputs: true, Compaction: true, ContextEditing: true, + }, + schemas.Azure: { + WebSearch: true, WebSearchDynamic: true, WebFetch: true, CodeExecution: true, + ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, + MCP: true, AdvancedToolUse: true, StructuredOutputs: true, PromptCachingScope: true, + Compaction: true, ContextEditing: true, FilesAPI: true, + }, +} + // ==================== REQUEST TYPES ==================== // AnthropicTextRequest represents an Anthropic text completion request @@ -71,7 +126,7 @@ func (req *AnthropicTextRequest) IsStreamingRequested() bool { // AnthropicOutputConfig represents the GA structured outputs config (output_config.format) // and the effort parameter (output_config.effort) for controlling token spending. type AnthropicOutputConfig struct { - Format interface{} `json:"format,omitempty"` + Format json.RawMessage `json:"format,omitempty"` Effort *string `json:"effort,omitempty"` // "low", "medium", "high", "max" (Opus 4.5+) } @@ -92,7 +147,7 @@ type AnthropicMessageRequest struct { ToolChoice *AnthropicToolChoice `json:"tool_choice,omitempty"` MCPServers []AnthropicMCPServer `json:"mcp_servers,omitempty"` // This feature requires the beta header: "anthropic-beta": "mcp-client-2025-04-04" Thinking *AnthropicThinking `json:"thinking,omitempty"` - OutputFormat interface{} `json:"output_format,omitempty"` // Beta: requires header "anthropic-beta": "structured-outputs-2025-11-13" + OutputFormat json.RawMessage `json:"output_format,omitempty"` // Beta: requires header "anthropic-beta": "structured-outputs-2025-11-13" (json.RawMessage preserves key ordering) OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"` // GA: structured outputs without beta header ServiceTier *string `json:"service_tier,omitempty"` // "auto" or "standard_only" InferenceGeo *string `json:"inference_geo,omitempty"` // the geographic region for inference processing. If not specified, the workspace's default_inference_geo is used. @@ -154,12 +209,12 @@ func (tv CompactManagementEditTypeAndValue) MarshalJSON() ([]byte, error) { } if tv.TypeAndValueString != nil { - return sonic.Marshal(*tv.TypeAndValueString) + return providerUtils.MarshalSorted(*tv.TypeAndValueString) } if tv.TypeAndValueObject != nil { - return sonic.Marshal(tv.TypeAndValueObject) + return providerUtils.MarshalSorted(tv.TypeAndValueObject) } - return sonic.Marshal(nil) + return providerUtils.MarshalSorted(nil) } // UnmarshalJSON implements custom JSON unmarshalling for CompactManagementEditTypeAndValue. @@ -206,12 +261,12 @@ func (ct ClearToolInputs) MarshalJSON() ([]byte, error) { } if ct.ClearToolInputsBoolean != nil { - return sonic.Marshal(*ct.ClearToolInputsBoolean) + return providerUtils.MarshalSorted(*ct.ClearToolInputsBoolean) } if ct.ClearToolInputsArray != nil { - return sonic.Marshal(ct.ClearToolInputsArray) + return providerUtils.MarshalSorted(ct.ClearToolInputsArray) } - return sonic.Marshal(nil) + return providerUtils.MarshalSorted(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ClearToolInputs. @@ -257,13 +312,13 @@ func (edit ContextManagementEdit) MarshalJSON() ([]byte, error) { switch edit.Type { case ContextManagementEditTypeCompact: if edit.CompactManagementEditConfig == nil { - return sonic.Marshal(struct { + return providerUtils.MarshalSorted(struct { Type ContextManagementEditType `json:"type"` }{ Type: edit.Type, }) } - return sonic.Marshal(struct { + return providerUtils.MarshalSorted(struct { Type ContextManagementEditType `json:"type"` *CompactManagementEditConfig }{ @@ -274,7 +329,7 @@ func (edit ContextManagementEdit) MarshalJSON() ([]byte, error) { if edit.CompactManagementEditClearThinking == nil { return nil, fmt.Errorf("compact management edit clear thinking is nil for type clear_thinking_20251015") } - return sonic.Marshal(struct { + return providerUtils.MarshalSorted(struct { Type ContextManagementEditType `json:"type"` *CompactManagementEditClearThinking }{ @@ -285,7 +340,7 @@ func (edit ContextManagementEdit) MarshalJSON() ([]byte, error) { if edit.CompactManagementEditClearToolUses == nil { return nil, fmt.Errorf("compact management edit clear tool uses is nil for type clear_tool_uses_20250919") } - return sonic.Marshal(struct { + return providerUtils.MarshalSorted(struct { Type ContextManagementEditType `json:"type"` *CompactManagementEditClearToolUses }{ @@ -417,6 +472,20 @@ func (req *AnthropicMessageRequest) UnmarshalJSON(data []byte) error { } } + // Compact known json.RawMessage fields for deterministic cache keys + if len(req.OutputFormat) > 0 { + var buf bytes.Buffer + if err := json.Compact(&buf, req.OutputFormat); err == nil { + req.OutputFormat = json.RawMessage(buf.Bytes()) + } + } + if req.OutputConfig != nil && len(req.OutputConfig.Format) > 0 { + var buf bytes.Buffer + if err := json.Compact(&buf, req.OutputConfig.Format); err == nil { + req.OutputConfig.Format = json.RawMessage(buf.Bytes()) + } + } + return nil } @@ -477,10 +546,10 @@ func (req *AnthropicMessageRequest) MarshalJSON() ([]byte, error) { reqCopy.Messages = messagesCopy } - return sonic.Marshal((*Alias)(&reqCopy)) + return providerUtils.MarshalSorted((*Alias)(&reqCopy)) } - return sonic.Marshal((*Alias)(req)) + return providerUtils.MarshalSorted((*Alias)(req)) } // stripScopeFromContent strips scope from all cache control blocks in content @@ -540,10 +609,10 @@ func (mc AnthropicContent) MarshalJSON() ([]byte, error) { } if mc.ContentStr != nil { - return sonic.Marshal(*mc.ContentStr) + return providerUtils.MarshalSorted(*mc.ContentStr) } if mc.ContentBlocks != nil { - return sonic.Marshal(mc.ContentBlocks) + return providerUtils.MarshalSorted(mc.ContentBlocks) } // If both are nil, return empty array instead of null. // Anthropic's API requires content to be an array, not null. @@ -608,7 +677,7 @@ type AnthropicContentBlock struct { ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content ID *string `json:"id,omitempty"` // For tool_use content Name *string `json:"name,omitempty"` // For tool_use content - Input any `json:"input,omitempty"` // For tool_use content + Input json.RawMessage `json:"input,omitempty"` // For tool_use content (json.RawMessage preserves key ordering for prompt caching) ServerName *string `json:"server_name,omitempty"` // For mcp_tool_use content Content *AnthropicContent `json:"content,omitempty"` // For tool_result content IsError *bool `json:"is_error,omitempty"` // For tool_result content, indicates error state @@ -696,12 +765,12 @@ func (ac *AnthropicCitations) MarshalJSON() ([]byte, error) { } if ac.Config != nil { - return sonic.Marshal(ac.Config) + return providerUtils.MarshalSorted(ac.Config) } if ac.TextCitations != nil { - return sonic.Marshal(ac.TextCitations) + return providerUtils.MarshalSorted(ac.TextCitations) } - return sonic.Marshal(nil) + return providerUtils.MarshalSorted(nil) } // UnmarshalJSON implements the json.Unmarshaler interface @@ -812,7 +881,7 @@ type AnthropicToolWebFetch struct { // AnthropicToolInputExample represents an input example for a tool (beta feature) type AnthropicToolInputExample struct { - Input any `json:"input"` + Input json.RawMessage `json:"input"` Description *string `json:"description,omitempty"` } diff --git a/core/providers/anthropic/utils.go b/core/providers/anthropic/utils.go index ddabfbe5d8..32553ca206 100644 --- a/core/providers/anthropic/utils.go +++ b/core/providers/anthropic/utils.go @@ -6,12 +6,70 @@ import ( "fmt" "strings" + "github.com/bytedance/sonic" "github.com/valyala/fasthttp" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) +// ValidateToolsForProvider checks if all tools in the request are supported by the given provider. +// Returns an error for the first unsupported tool found. +func ValidateToolsForProvider(tools []schemas.ResponsesTool, provider schemas.ModelProvider) error { + features, ok := ProviderFeatures[provider] + if !ok { + // Unknown provider — allow all tools (safe default for custom providers) + return nil + } + + for _, tool := range tools { + switch tool.Type { + case schemas.ResponsesToolTypeWebSearch, schemas.ResponsesToolTypeWebSearchPreview: + if !features.WebSearch { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeWebFetch: + if !features.WebFetch { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeCodeInterpreter: + if !features.CodeExecution { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeComputerUsePreview: + if !features.ComputerUse { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeMCP: + if !features.MCP { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeLocalShell: + if !features.Bash { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeMemory: + if !features.Memory { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeToolSearch: + if !features.ToolSearch { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeFileSearch: + if !features.FileSearch { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + case schemas.ResponsesToolTypeImageGeneration: + if !features.ImageGeneration { + return fmt.Errorf("tool type '%s' is not supported by provider '%s'", tool.Type, provider) + } + // ResponsesToolTypeFunction, ResponsesToolTypeCustom, etc. are always allowed + } + } + return nil +} + var ( // Maps provider-specific finish reasons to Bifrost format anthropicFinishReasonToBifrost = map[AnthropicStopReason]string{ @@ -131,7 +189,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi if reqBody == nil { return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) } - addMissingBetaHeadersToContext(ctx, reqBody) + AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Anthropic) if isStreaming { reqBody.Stream = schemas.Ptr(true) } @@ -170,15 +228,32 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi return jsonBody, nil } -// addMissingBetaHeadersToContext analyzes the Anthropic request and adds missing beta headers to the context -func addMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicMessageRequest) error { +// AddMissingBetaHeadersToContext analyzes the Anthropic request and adds missing beta headers to the context. +// The provider parameter controls which headers are included — unsupported headers for the given provider are skipped. +func AddMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicMessageRequest, provider schemas.ModelProvider) error { + features, hasProvider := ProviderFeatures[provider] headers := []string{} hasCachingScope := false if req.Tools != nil { for _, tool := range req.Tools { + // Check for version-specific beta headers based on tool type + if tool.Type != nil { + switch *tool.Type { + case AnthropicToolTypeComputer20251124: + if !hasProvider || features.ComputerUse { + headers = appendUniqueHeader(headers, AnthropicComputerUseBetaHeader20251124) + } + case AnthropicToolTypeComputer20250124: + if !hasProvider || features.ComputerUse { + headers = appendUniqueHeader(headers, AnthropicComputerUseBetaHeader20250124) + } + } + } // Check for strict (structured-outputs) if tool.Strict != nil && *tool.Strict { - headers = appendUniqueHeader(headers, AnthropicStructuredOutputsBetaHeader) + if !hasProvider || features.StructuredOutputs { + headers = appendUniqueHeader(headers, AnthropicStructuredOutputsBetaHeader) + } } // Check for advanced-tool-use features if tool.DeferLoading != nil && *tool.DeferLoading { @@ -192,8 +267,10 @@ func addMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicM } // Check for cache control with scope if !hasCachingScope && tool.CacheControl != nil && tool.CacheControl.Scope != nil { - headers = appendUniqueHeader(headers, AnthropicPromptCachingScopeBetaHeader) - hasCachingScope = true + if !hasProvider || features.PromptCachingScope { + headers = appendUniqueHeader(headers, AnthropicPromptCachingScopeBetaHeader) + hasCachingScope = true + } } } } @@ -201,27 +278,37 @@ func addMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicM if req.ContextManagement != nil { for _, edit := range req.ContextManagement.Edits { if edit.Type == ContextManagementEditTypeCompact { - headers = appendUniqueHeader(headers, AnthropicCompactionBetaHeader) + if !hasProvider || features.Compaction { + headers = appendUniqueHeader(headers, AnthropicCompactionBetaHeader) + } } if edit.Type == ContextManagementEditTypeClearToolUses || edit.Type == ContextManagementEditTypeClearThinking { - headers = appendUniqueHeader(headers, AnthropicContextManagementBetaHeader) + if !hasProvider || features.ContextEditing { + headers = appendUniqueHeader(headers, AnthropicContextManagementBetaHeader) + } } } } // Check for MCP servers if len(req.MCPServers) > 0 { - headers = appendUniqueHeader(headers, AnthropicMCPClientBetaHeader) + if !hasProvider || features.MCP { + headers = appendUniqueHeader(headers, AnthropicMCPClientBetaHeader) + } } // Check for output format (structured outputs) if req.OutputFormat != nil { - headers = appendUniqueHeader(headers, AnthropicStructuredOutputsBetaHeader) + if !hasProvider || features.StructuredOutputs { + headers = appendUniqueHeader(headers, AnthropicStructuredOutputsBetaHeader) + } } // Check for cache control with scope in system message (only if not already found) if !hasCachingScope && req.System != nil && req.System.ContentBlocks != nil { for _, block := range req.System.ContentBlocks { if block.CacheControl != nil && block.CacheControl.Scope != nil { - headers = appendUniqueHeader(headers, AnthropicPromptCachingScopeBetaHeader) - hasCachingScope = true + if !hasProvider || features.PromptCachingScope { + headers = appendUniqueHeader(headers, AnthropicPromptCachingScopeBetaHeader) + hasCachingScope = true + } break } } @@ -232,8 +319,10 @@ func addMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicM if message.Content.ContentBlocks != nil { for _, block := range message.Content.ContentBlocks { if block.CacheControl != nil && block.CacheControl.Scope != nil { - headers = appendUniqueHeader(headers, AnthropicPromptCachingScopeBetaHeader) - hasCachingScope = true + if !hasProvider || features.PromptCachingScope { + headers = appendUniqueHeader(headers, AnthropicPromptCachingScopeBetaHeader) + hasCachingScope = true + } break } } @@ -263,6 +352,148 @@ func addMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicM return nil } +// ToolVersionRemap defines a mapping from an unsupported tool version to a supported one. +type ToolVersionRemap struct { + From string + To string +} + +// providerToolVersionRemaps defines version downgrades per provider. +// When a raw request contains a tool type not supported by the target provider, +// it gets remapped to the supported version. +var providerToolVersionRemaps = map[schemas.ModelProvider][]ToolVersionRemap{ + schemas.Vertex: { + // Vertex only supports basic web search, not dynamic filtering + {From: string(AnthropicToolTypeWebSearch20260209), To: string(AnthropicToolTypeWebSearch20250305)}, + // Vertex does not support web fetch at all — no remap, these should error + // Vertex does not support code execution — no remap, these should error + }, + // Bedrock does not support web search, web fetch, or code execution at all — no remaps + // Anthropic and Azure support all versions — no remaps needed +} + +// unsupportedRawToolTypes lists tool type prefixes that should be rejected per provider +// when found in raw request bodies (no remap possible, the feature itself is unsupported). +var unsupportedRawToolTypes = map[schemas.ModelProvider][]string{ + schemas.Vertex: { + "web_fetch_", // No web fetch support on Vertex + "code_execution", // No code execution on Vertex + }, + schemas.Bedrock: { + "web_search_", // No web search on Bedrock + "web_fetch_", // No web fetch on Bedrock + "code_execution", // No code execution on Bedrock + }, +} + +// RemapRawToolVersionsForProvider inspects tools in a raw JSON body and remaps +// unsupported tool versions to supported ones for the target provider. +// Returns an error if a tool type is fundamentally unsupported (no remap possible). +func RemapRawToolVersionsForProvider(jsonBody []byte, provider schemas.ModelProvider) ([]byte, error) { + toolsResult := providerUtils.GetJSONField(jsonBody, "tools") + if !toolsResult.Exists() || !toolsResult.IsArray() { + return jsonBody, nil + } + + var err error + tools := toolsResult.Array() + + // Check for unsupported types first + if prefixes, ok := unsupportedRawToolTypes[provider]; ok { + for _, tool := range tools { + toolType := tool.Get("type").String() + for _, prefix := range prefixes { + if strings.HasPrefix(toolType, prefix) { + return nil, fmt.Errorf("tool type '%s' is not supported by provider '%s'", toolType, provider) + } + } + } + } + + // Apply version remaps + remaps, ok := providerToolVersionRemaps[provider] + if !ok { + return jsonBody, nil + } + + for i, tool := range tools { + toolType := tool.Get("type").String() + for _, remap := range remaps { + if toolType == remap.From { + path := fmt.Sprintf("tools.%d.type", i) + jsonBody, err = providerUtils.SetJSONField(jsonBody, path, remap.To) + if err != nil { + return nil, fmt.Errorf("failed to remap tool type: %w", err) + } + break + } + } + } + + return jsonBody, nil +} + +// FilterBetaHeadersForProvider validates that all beta headers are supported by the given provider. +// Returns an error if a known beta header is not supported by the provider. +// Unknown headers (not matched by any known prefix) are forwarded as-is for forward compatibility. +func FilterBetaHeadersForProvider(headers []string, provider schemas.ModelProvider) ([]string, error) { + features, hasProvider := ProviderFeatures[provider] + if !hasProvider { + // Unknown provider — allow all headers (safe default for custom providers) + return headers, nil + } + + filtered := make([]string, 0, len(headers)) + for _, h := range headers { + switch { + case strings.HasPrefix(h, "computer-use-"): + if !features.ComputerUse { + return nil, fmt.Errorf("beta header '%s' is not supported by provider '%s'", h, provider) + } + filtered = append(filtered, h) + case strings.HasPrefix(h, AnthropicStructuredOutputsBetaHeaderPrefix): + if !features.StructuredOutputs { + return nil, fmt.Errorf("beta header '%s' is not supported by provider '%s'", h, provider) + } + filtered = append(filtered, h) + case strings.HasPrefix(h, AnthropicMCPClientBetaHeaderPrefix): + if !features.MCP { + return nil, fmt.Errorf("beta header '%s' is not supported by provider '%s'", h, provider) + } + filtered = append(filtered, h) + case strings.HasPrefix(h, AnthropicPromptCachingScopeBetaHeaderPrefix): + if !features.PromptCachingScope { + return nil, fmt.Errorf("beta header '%s' is not supported by provider '%s'", h, provider) + } + filtered = append(filtered, h) + case strings.HasPrefix(h, "compact-"): + if !features.Compaction { + return nil, fmt.Errorf("beta header '%s' is not supported by provider '%s'", h, provider) + } + filtered = append(filtered, h) + case strings.HasPrefix(h, "context-management-"): + if !features.ContextEditing { + return nil, fmt.Errorf("beta header '%s' is not supported by provider '%s'", h, provider) + } + filtered = append(filtered, h) + case strings.HasPrefix(h, "files-api-"): + if !features.FilesAPI { + return nil, fmt.Errorf("beta header '%s' is not supported by provider '%s'", h, provider) + } + filtered = append(filtered, h) + case strings.HasPrefix(h, AnthropicAdvancedToolUseBetaHeaderPrefix): + if !features.AdvancedToolUse { + return nil, fmt.Errorf("beta header '%s' is not supported by provider '%s'", h, provider) + } + filtered = append(filtered, h) + default: + // Unknown headers are forwarded for forward compatibility + filtered = append(filtered, h) + } + } + return filtered, nil +} + // appendUniqueHeader adds a header to the slice if not already present func appendUniqueHeader(slice []string, item string) []string { for _, s := range slice { @@ -874,7 +1105,7 @@ func getImageURLFromBlock(block AnthropicContentBlock) string { // parseJSONInput returns a json.RawMessage that preserves the original key ordering // of the JSON input. This is critical for prompt caching, which relies on exact // byte-for-byte matching of the request prefix sent to providers. -func parseJSONInput(jsonStr string) interface{} { +func parseJSONInput(jsonStr string) json.RawMessage { if jsonStr == "" || jsonStr == "{}" { return json.RawMessage("{}") } @@ -885,8 +1116,8 @@ func parseJSONInput(jsonStr string) interface{} { return json.RawMessage(compacted) } - // If compaction fails (invalid JSON), return as string - return jsonStr + // If compaction fails (invalid JSON), return json.RawMessage of the raw string + return json.RawMessage(jsonStr) } // compactJSONBytes compacts JSON bytes, removing insignificant whitespace while @@ -1155,7 +1386,7 @@ func normalizeSchemaForAnthropic(schema map[string]interface{}) map[string]inter // "schema": {...}, // "strict": true // } -func convertChatResponseFormatToAnthropicOutputFormat(responseFormat *interface{}) interface{} { +func convertChatResponseFormatToAnthropicOutputFormat(responseFormat *interface{}) json.RawMessage { if responseFormat == nil { return nil } @@ -1189,7 +1420,11 @@ func convertChatResponseFormatToAnthropicOutputFormat(responseFormat *interface{ outputFormat["schema"] = normalizedSchema } - return outputFormat + result, err := providerUtils.MarshalSorted(outputFormat) + if err != nil { + return nil + } + return json.RawMessage(result) } // convertResponsesTextConfigToAnthropicOutputFormat converts OpenAI Responses API text config @@ -1212,7 +1447,7 @@ func convertChatResponseFormatToAnthropicOutputFormat(responseFormat *interface{ // "type": "json_schema", // "schema": {...} // } -func convertResponsesTextConfigToAnthropicOutputFormat(textConfig *schemas.ResponsesTextConfig) interface{} { +func convertResponsesTextConfigToAnthropicOutputFormat(textConfig *schemas.ResponsesTextConfig) json.RawMessage { if textConfig == nil || textConfig.Format == nil { return nil } @@ -1255,7 +1490,11 @@ func convertResponsesTextConfigToAnthropicOutputFormat(textConfig *schemas.Respo outputFormat["schema"] = normalizedSchema } - return outputFormat + result, err := providerUtils.MarshalSorted(outputFormat) + if err != nil { + return nil + } + return json.RawMessage(result) } // convertAnthropicOutputFormatToResponsesTextConfig converts Anthropic's output_format structure @@ -1280,14 +1519,14 @@ func convertResponsesTextConfigToAnthropicOutputFormat(textConfig *schemas.Respo // } // } // } -func convertAnthropicOutputFormatToResponsesTextConfig(outputFormat interface{}) *schemas.ResponsesTextConfig { +func convertAnthropicOutputFormatToResponsesTextConfig(outputFormat json.RawMessage) *schemas.ResponsesTextConfig { if outputFormat == nil { return nil } - // Try to convert to map - formatMap, ok := outputFormat.(map[string]interface{}) - if !ok { + // Unmarshal to map + var formatMap map[string]interface{} + if err := sonic.Unmarshal(outputFormat, &formatMap); err != nil { return nil } @@ -1494,7 +1733,7 @@ func convertAnthropicOutputFormatToResponsesTextConfig(outputFormat interface{}) // - If both arrays are empty, delete blocked_domains func sanitizeWebSearchArguments(argumentsJSON string) string { var toolArgs map[string]interface{} - if err := json.Unmarshal([]byte(argumentsJSON), &toolArgs); err != nil { + if err := sonic.Unmarshal([]byte(argumentsJSON), &toolArgs); err != nil { return argumentsJSON // Return original if parse fails } @@ -1529,7 +1768,7 @@ func sanitizeWebSearchArguments(argumentsJSON string) string { delete(toolArgs, shouldDelete) // Re-marshal the sanitized arguments - if sanitizedBytes, err := json.Marshal(toolArgs); err == nil { + if sanitizedBytes, err := providerUtils.MarshalSorted(toolArgs); err == nil { return string(sanitizedBytes) } } diff --git a/core/providers/anthropic/utils_test.go b/core/providers/anthropic/utils_test.go index 0b6f1eed46..db8429979f 100644 --- a/core/providers/anthropic/utils_test.go +++ b/core/providers/anthropic/utils_test.go @@ -1,10 +1,13 @@ package anthropic import ( + "encoding/json" "reflect" "testing" + "time" "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" ) func TestExtractTypesFromValue(t *testing.T) { @@ -568,3 +571,292 @@ func TestConvertChatResponseFormatToAnthropicOutputFormat(t *testing.T) { }) } } + +func TestValidateToolsForProvider(t *testing.T) { + tests := []struct { + name string + tools []schemas.ResponsesTool + provider schemas.ModelProvider + expectErr bool + }{ + { + name: "Anthropic allows web_search", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, + provider: schemas.Anthropic, + expectErr: false, + }, + { + name: "Anthropic allows web_fetch", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, + provider: schemas.Anthropic, + expectErr: false, + }, + { + name: "Vertex allows web_search", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, + provider: schemas.Vertex, + expectErr: false, + }, + { + name: "Vertex rejects web_fetch", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, + provider: schemas.Vertex, + expectErr: true, + }, + { + name: "Vertex rejects code_interpreter", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeCodeInterpreter}}, + provider: schemas.Vertex, + expectErr: true, + }, + { + name: "Vertex rejects MCP", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMCP}}, + provider: schemas.Vertex, + expectErr: true, + }, + { + name: "Bedrock rejects web_search", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, + provider: schemas.Bedrock, + expectErr: true, + }, + { + name: "Bedrock rejects web_fetch", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, + provider: schemas.Bedrock, + expectErr: true, + }, + { + name: "Bedrock allows computer_use", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeComputerUsePreview}}, + provider: schemas.Bedrock, + expectErr: false, + }, + { + name: "Azure allows everything", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}, {Type: schemas.ResponsesToolTypeCodeInterpreter}, {Type: schemas.ResponsesToolTypeMCP}}, + provider: schemas.Azure, + expectErr: false, + }, + { + name: "Unknown provider allows all", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, + provider: "custom_provider", + expectErr: false, + }, + { + name: "Function tools always allowed", + tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeFunction}}, + provider: schemas.Bedrock, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateToolsForProvider(tt.tools, tt.provider) + if tt.expectErr && err == nil { + t.Errorf("expected error but got nil") + } + if !tt.expectErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestAddMissingBetaHeadersToContext_PerProvider(t *testing.T) { + tests := []struct { + name string + provider schemas.ModelProvider + req *AnthropicMessageRequest + expectHeaders []string + unexpectHeaders []string + }{ + { + name: "Anthropic gets structured outputs header", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{ + OutputFormat: json.RawMessage(`{"type":"json_schema"}`), + }, + expectHeaders: []string{AnthropicStructuredOutputsBetaHeader}, + }, + { + name: "Vertex skips structured outputs header", + provider: schemas.Vertex, + req: &AnthropicMessageRequest{ + OutputFormat: json.RawMessage(`{"type":"json_schema"}`), + }, + unexpectHeaders: []string{AnthropicStructuredOutputsBetaHeader}, + }, + { + name: "Vertex skips MCP header", + provider: schemas.Vertex, + req: &AnthropicMessageRequest{ + MCPServers: []AnthropicMCPServer{{URL: "http://example.com"}}, + }, + unexpectHeaders: []string{AnthropicMCPClientBetaHeader}, + }, + { + name: "Anthropic gets MCP header", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{ + MCPServers: []AnthropicMCPServer{{URL: "http://example.com"}}, + }, + expectHeaders: []string{AnthropicMCPClientBetaHeader}, + }, + { + name: "Vertex gets compaction header", + provider: schemas.Vertex, + req: &AnthropicMessageRequest{ + ContextManagement: &ContextManagement{ + Edits: []ContextManagementEdit{{Type: ContextManagementEditTypeCompact}}, + }, + }, + expectHeaders: []string{AnthropicCompactionBetaHeader}, + }, + { + name: "Bedrock gets compaction header", + provider: schemas.Bedrock, + req: &AnthropicMessageRequest{ + ContextManagement: &ContextManagement{ + Edits: []ContextManagementEdit{{Type: ContextManagementEditTypeCompact}}, + }, + }, + expectHeaders: []string{AnthropicCompactionBetaHeader}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := schemas.NewBifrostContext(nil, time.Time{}) + AddMissingBetaHeadersToContext(ctx, tt.req, tt.provider) + + var headers []string + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { + headers = extraHeaders["anthropic-beta"] + } + + for _, expected := range tt.expectHeaders { + found := false + for _, h := range headers { + if h == expected { + found = true + break + } + } + if !found { + t.Errorf("expected header %q not found in %v", expected, headers) + } + } + + for _, unexpected := range tt.unexpectHeaders { + for _, h := range headers { + if h == unexpected { + t.Errorf("unexpected header %q found in %v", unexpected, headers) + } + } + } + }) + } +} + +func TestFilterBetaHeadersForProvider(t *testing.T) { + allHeaders := []string{ + AnthropicComputerUseBetaHeader20251124, + AnthropicStructuredOutputsBetaHeader, + AnthropicMCPClientBetaHeader, + AnthropicPromptCachingScopeBetaHeader, + AnthropicCompactionBetaHeader, + AnthropicContextManagementBetaHeader, + AnthropicAdvancedToolUseBetaHeader, + AnthropicFilesAPIBetaHeader, + } + + t.Run("Anthropic/keeps_all_headers", func(t *testing.T) { + result, err := FilterBetaHeadersForProvider(allHeaders, schemas.Anthropic) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, h := range allHeaders { + found := false + for _, r := range result { + if r == h { + found = true + break + } + } + if !found { + t.Errorf("expected header %q to be kept for Anthropic, got %v", h, result) + } + } + }) + + t.Run("Vertex/errors_on_unsupported_headers", func(t *testing.T) { + unsupported := []string{ + AnthropicStructuredOutputsBetaHeader, + AnthropicMCPClientBetaHeader, + AnthropicPromptCachingScopeBetaHeader, + AnthropicAdvancedToolUseBetaHeader, + AnthropicFilesAPIBetaHeader, + } + for _, h := range unsupported { + _, err := FilterBetaHeadersForProvider([]string{h}, schemas.Vertex) + if err == nil { + t.Errorf("expected error for header %q on Vertex, got nil", h) + } + } + }) + + t.Run("Vertex/allows_supported_headers", func(t *testing.T) { + supported := []string{ + AnthropicComputerUseBetaHeader20251124, + AnthropicCompactionBetaHeader, + AnthropicContextManagementBetaHeader, + } + result, err := FilterBetaHeadersForProvider(supported, schemas.Vertex) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != len(supported) { + t.Errorf("expected %d headers, got %d: %v", len(supported), len(result), result) + } + }) + + t.Run("Bedrock/errors_on_unsupported_headers", func(t *testing.T) { + unsupported := []string{ + AnthropicMCPClientBetaHeader, + AnthropicPromptCachingScopeBetaHeader, + AnthropicAdvancedToolUseBetaHeader, + AnthropicFilesAPIBetaHeader, + } + for _, h := range unsupported { + _, err := FilterBetaHeadersForProvider([]string{h}, schemas.Bedrock) + if err == nil { + t.Errorf("expected error for header %q on Bedrock, got nil", h) + } + } + }) + + t.Run("unknown_headers_forwarded", func(t *testing.T) { + headers := []string{"some-future-beta-2025"} + result, err := FilterBetaHeadersForProvider(headers, schemas.Vertex) + if err != nil { + t.Fatalf("unexpected error for unknown headers: %v", err) + } + if len(result) != len(headers) { + t.Errorf("expected all unknown headers to be forwarded, got %v", result) + } + }) + + t.Run("unknown_provider_allows_all", func(t *testing.T) { + result, err := FilterBetaHeadersForProvider(allHeaders, schemas.ModelProvider("custom-provider")) + if err != nil { + t.Fatalf("unexpected error for unknown provider: %v", err) + } + if len(result) != len(allHeaders) { + t.Errorf("expected all headers for unknown provider, got %v", result) + } + }) +} diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 474f933168..1d71d585fe 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -554,6 +554,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s } if reqBody != nil { reqBody.Model = deployment + // Add provider-aware beta headers for Azure + anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) } return reqBody, nil } else { @@ -681,6 +683,8 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, if reqBody != nil { reqBody.Model = deployment reqBody.Stream = schemas.Ptr(true) + // Add provider-aware beta headers for Azure + anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) } return reqBody, nil }, diff --git a/core/providers/azure/utils.go b/core/providers/azure/utils.go index d4f982aa56..acc0e0ffa1 100644 --- a/core/providers/azure/utils.go +++ b/core/providers/azure/utils.go @@ -62,6 +62,9 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s reqBody.Stream = schemas.Ptr(true) } + // Add provider-aware beta headers for Azure + anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) + // Marshal struct to JSON bytes, preserving field order jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { diff --git a/core/providers/bedrock/batch.go b/core/providers/bedrock/batch.go index 8fbc92ba21..18ebad9a90 100644 --- a/core/providers/bedrock/batch.go +++ b/core/providers/bedrock/batch.go @@ -1,6 +1,7 @@ package bedrock import ( + "encoding/json" "fmt" "time" @@ -89,7 +90,7 @@ type BedrockBatchJobSummary struct { // BedrockBatchResultRecord represents a single result record in Bedrock batch output JSONL. type BedrockBatchResultRecord struct { RecordID string `json:"recordId"` - ModelOutput map[string]interface{} `json:"modelOutput,omitempty"` + ModelOutput json.RawMessage `json:"modelOutput,omitempty"` Error *BedrockBatchError `json:"error,omitempty"` } @@ -165,9 +166,17 @@ func parseBatchResultsJSONL(content []byte, provider *BedrockProvider) ([]schema } if bedrockResult.ModelOutput != nil { - resultItem.Response = &schemas.BatchResultResponse{ - StatusCode: 200, - Body: bedrockResult.ModelOutput, + var bodyMap map[string]interface{} + if err := sonic.Unmarshal(bedrockResult.ModelOutput, &bodyMap); err == nil { + resultItem.Response = &schemas.BatchResultResponse{ + StatusCode: 200, + Body: bodyMap, + } + } else { + resultItem.Error = &schemas.BatchResultError{ + Code: "parse_error", + Message: fmt.Sprintf("failed to parse model output: %v", err), + } } } diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index 1ff37eeb83..b3bc5bf421 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -14,6 +14,36 @@ import ( "github.com/stretchr/testify/require" ) +func mustMarshalJSON(v interface{}) json.RawMessage { + b, _ := json.Marshal(v) + return json.RawMessage(b) +} + +// jsonEqual compares two json.RawMessage values semantically (ignoring key order). +func jsonEqual(t *testing.T, expected, actual json.RawMessage, msgAndArgs ...interface{}) { + t.Helper() + if expected == nil && actual == nil { + return + } + var e, a interface{} + if err := json.Unmarshal(expected, &e); err != nil { + t.Errorf("failed to unmarshal expected JSON: %v", err) + return + } + if err := json.Unmarshal(actual, &a); err != nil { + t.Errorf("failed to unmarshal actual JSON: %v", err) + return + } + assert.Equal(t, e, a, msgAndArgs...) +} + +// mustMarshalToolParams marshals ToolFunctionParameters to json.RawMessage, +// matching the conversion code path for deterministic output. +func mustMarshalToolParams(params *schemas.ToolFunctionParameters) json.RawMessage { + b, _ := json.Marshal(params) + return json.RawMessage(b) +} + // Common test variables var ( testMaxTokens = 100 @@ -28,6 +58,14 @@ var ( "description": "The city name", }), ) + // testPropsFromJSON is the same as testProps but with nested values as *OrderedMap + // (as produced by json.Unmarshal -> OrderedMap.UnmarshalJSON) + testPropsFromJSON = *schemas.NewOrderedMapFromPairs( + schemas.KV("location", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "string"), + schemas.KV("description", "The city name"), + )), + ) ) // assertBedrockRequestEqual compares two BedrockConverseRequest objects @@ -433,11 +471,11 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) { Name: "get_weather", Description: schemas.Ptr("Get weather information"), InputSchema: bedrock.BedrockToolInputSchema{ - JSON: map[string]interface{}{ - "type": "object", - "properties": &props, - "required": []string{"location"}, - }, + JSON: mustMarshalToolParams(&schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }), }, }, }, @@ -644,10 +682,10 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) { Name: "hello", Description: schemas.Ptr("Tool extracted from conversation history"), InputSchema: bedrock.BedrockToolInputSchema{ - JSON: map[string]interface{}{ + JSON: mustMarshalJSON(map[string]interface{}{ "type": "object", "properties": map[string]interface{}{}, - }, + }), }, }, }, @@ -656,10 +694,10 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) { Name: "world", Description: schemas.Ptr("Tool extracted from conversation history"), InputSchema: bedrock.BedrockToolInputSchema{ - JSON: map[string]interface{}{ + JSON: mustMarshalJSON(map[string]interface{}{ "type": "object", "properties": map[string]interface{}{}, - }, + }), }, }, }, @@ -779,12 +817,12 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) { ToolUseID: "tooluse_Yl388l8ES0G_3TQtDcKq_g", Content: []bedrock.BedrockContentBlock{ { - JSON: map[string]any{ + JSON: mustMarshalJSON(map[string]any{ "results": []any{ any(map[string]any{"period": "now", "weather": "sunny"}), any(map[string]any{"period": "next_1_hour", "weather": "cloudy"}), }, - }, + }), }, }, Status: schemas.Ptr("success"), @@ -801,11 +839,11 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) { Name: "get_weather", Description: schemas.Ptr("Get weather information"), InputSchema: bedrock.BedrockToolInputSchema{ - JSON: map[string]interface{}{ - "type": "object", - "properties": &props, - "required": []string{"location"}, - }, + JSON: mustMarshalToolParams(&schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }), }, }, }, @@ -827,11 +865,7 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) { } } else { require.NoError(t, err) - if tt.name == "ParallelToolCalls" { - assertBedrockRequestEqual(t, tt.expected, actual) - } else { - assert.Equal(t, tt.expected, actual) - } + assertBedrockRequestEqual(t, tt.expected, actual) } }) } @@ -845,6 +879,22 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { trace := testTrace latency := testLatency props := testProps + _ = props // used in input construction + + // Build expected params via JSON round-trip so keyOrder and nested OrderedMap match + expectedParamsJSON := mustMarshalJSON(map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + }, + "required": []string{"location"}, + }) + var expectedParams schemas.ToolFunctionParameters + _ = json.Unmarshal(expectedParamsJSON, &expectedParams) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) tests := []struct { @@ -1057,7 +1107,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Name: "get_weather", Description: schemas.Ptr("Get weather information"), InputSchema: bedrock.BedrockToolInputSchema{ - JSON: map[string]interface{}{ + JSON: mustMarshalJSON(map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "location": map[string]interface{}{ @@ -1066,7 +1116,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { }, }, "required": []string{"location"}, - }, + }), }, }, }, @@ -1098,11 +1148,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { Name: schemas.Ptr("get_weather"), Description: schemas.Ptr("Get weather information"), ResponsesToolFunction: &schemas.ResponsesToolFunction{ - Parameters: &schemas.ToolFunctionParameters{ - Type: "object", - Properties: &props, - Required: []string{"location"}, - }, + Parameters: &expectedParams, }, }, }, @@ -1513,7 +1559,7 @@ func TestBifrostToBedrockResponseConversion(t *testing.T) { ToolUse: &bedrock.BedrockToolUse{ ToolUseID: callID, Name: toolName, - Input: "invalid json {", // Should fallback to raw string + Input: json.RawMessage("invalid json {"), // Should fallback to raw string }, }, }, @@ -1707,10 +1753,10 @@ func TestBifrostToBedrockResponseConversion(t *testing.T) { Status: schemas.Ptr("success"), Content: []bedrock.BedrockContentBlock{ { - JSON: map[string]interface{}{ + JSON: mustMarshalJSON(map[string]interface{}{ "temperature": float64(72), "location": "NYC", - }, + }), }, }, }, @@ -1813,9 +1859,9 @@ func TestBifrostToBedrockResponseConversion(t *testing.T) { Status: schemas.Ptr("success"), Content: []bedrock.BedrockContentBlock{ { - JSON: map[string]interface{}{ + JSON: mustMarshalJSON(map[string]interface{}{ "temperature": float64(72), - }, + }), }, }, }, @@ -1912,9 +1958,7 @@ func TestBedrockToBifrostResponseConversion(t *testing.T) { totalTokens := 30 toolUseID := "call-123" toolName := "get_weather" - toolInput := map[string]interface{}{ - "location": "NYC", - } + toolInput := json.RawMessage(`{"location":"NYC"}`) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) tests := []struct { @@ -2035,7 +2079,7 @@ func TestBedrockToBifrostResponseConversion(t *testing.T) { ResponsesToolMessage: &schemas.ResponsesToolMessage{ CallID: &toolUseID, Name: &toolName, - Arguments: schemas.Ptr(schemas.JsonifyInput(toolInput)), + Arguments: schemas.Ptr(string(toolInput)), }, }, }, @@ -2175,7 +2219,7 @@ func TestToolResultJSONParsingResponsesAPI(t *testing.T) { name string toolResultContent string expectedContentType string // "text" or "json" - expectedJSON map[string]any + expectedJSON json.RawMessage expectedText *string }{ { @@ -2194,54 +2238,54 @@ func TestToolResultJSONParsingResponsesAPI(t *testing.T) { name: "JSONObjectResult", toolResultContent: `{"location":"NYC","temperature":72}`, expectedContentType: "json", - expectedJSON: map[string]any{"location": "NYC", "temperature": float64(72)}, + expectedJSON: mustMarshalJSON(map[string]any{"location": "NYC", "temperature": float64(72)}), }, { name: "JSONArrayResult", toolResultContent: `[{"period":"now","weather":"sunny"},{"period":"next_1_hour","weather":"cloudy"}]`, expectedContentType: "json", - expectedJSON: map[string]any{ + expectedJSON: mustMarshalJSON(map[string]any{ "results": []any{ map[string]any{"period": "now", "weather": "sunny"}, map[string]any{"period": "next_1_hour", "weather": "cloudy"}, }, - }, + }), }, { name: "JSONPrimitiveNumberResult", toolResultContent: `42`, expectedContentType: "json", - expectedJSON: map[string]any{"value": float64(42)}, + expectedJSON: mustMarshalJSON(map[string]any{"value": float64(42)}), }, { name: "JSONPrimitiveStringResult", toolResultContent: `"hello world"`, expectedContentType: "json", - expectedJSON: map[string]any{"value": "hello world"}, + expectedJSON: mustMarshalJSON(map[string]any{"value": "hello world"}), }, { name: "JSONPrimitiveBooleanResult", toolResultContent: `true`, expectedContentType: "json", - expectedJSON: map[string]any{"value": true}, + expectedJSON: mustMarshalJSON(map[string]any{"value": true}), }, { name: "JSONPrimitiveNullResult", toolResultContent: `null`, expectedContentType: "json", - expectedJSON: map[string]any{"value": nil}, + expectedJSON: mustMarshalJSON(map[string]any{"value": nil}), }, { name: "EmptyJSONObjectResult", toolResultContent: `{}`, expectedContentType: "json", - expectedJSON: map[string]any{}, + expectedJSON: mustMarshalJSON(map[string]any{}), }, { name: "EmptyJSONArrayResult", toolResultContent: `[]`, expectedContentType: "json", - expectedJSON: map[string]any{"results": []any{}}, + expectedJSON: mustMarshalJSON(map[string]any{"results": []any{}}), }, } diff --git a/core/providers/bedrock/chat.go b/core/providers/bedrock/chat.go index 1b747592f7..9e92f9d0cb 100644 --- a/core/providers/bedrock/chat.go +++ b/core/providers/bedrock/chat.go @@ -5,7 +5,6 @@ import ( "fmt" "time" - "github.com/bytedance/sonic" "github.com/google/uuid" "github.com/maximhq/bifrost/core/schemas" ) @@ -74,13 +73,8 @@ func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Conte if structuredOutputToolName, ok := ctx.Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok && contentBlock.ToolUse.Name == structuredOutputToolName { // This is structured output - set contentStr and skip adding to toolCalls if contentBlock.ToolUse.Input != nil { - if argBytes, err := sonic.Marshal(contentBlock.ToolUse.Input); err == nil { - jsonStr := string(argBytes) - contentStr = &jsonStr - } else { - jsonStr := fmt.Sprintf("%v", contentBlock.ToolUse.Input) - contentStr = &jsonStr - } + jsonStr := string(contentBlock.ToolUse.Input) + contentStr = &jsonStr } continue // Skip adding to toolCalls } @@ -88,11 +82,7 @@ func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Conte // Regular tool call processing var arguments string if contentBlock.ToolUse.Input != nil { - if argBytes, err := sonic.Marshal(contentBlock.ToolUse.Input); err == nil { - arguments = string(argBytes) - } else { - arguments = fmt.Sprintf("%v", contentBlock.ToolUse.Input) - } + arguments = string(contentBlock.ToolUse.Input) } else { arguments = "{}" } diff --git a/core/providers/bedrock/invoke.go b/core/providers/bedrock/invoke.go index 44174cae6f..601edd5226 100644 --- a/core/providers/bedrock/invoke.go +++ b/core/providers/bedrock/invoke.go @@ -8,6 +8,7 @@ import ( "github.com/bytedance/sonic" "github.com/google/uuid" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -345,7 +346,7 @@ func (r *BedrockInvokeRequest) parseSystemMessages() []BedrockSystemMessage { for _, item := range s { if m, ok := item.(map[string]interface{}); ok { // Re-marshal and unmarshal to capture all fields (text, guardContent, cachePoint) - itemBytes, err := sonic.Marshal(m) + itemBytes, err := providerUtils.MarshalSorted(m) if err != nil { continue } @@ -390,7 +391,8 @@ func (r *BedrockInvokeRequest) convertAnthropicTools() *BedrockToolConfig { spec.Description = &desc } if inputSchema, ok := toolMap["input_schema"]; ok { - spec.InputSchema = BedrockToolInputSchema{JSON: inputSchema} + inputSchemaBytes, _ := providerUtils.MarshalSorted(inputSchema) + spec.InputSchema = BedrockToolInputSchema{JSON: json.RawMessage(inputSchemaBytes)} } bedrockTools = append(bedrockTools, BedrockTool{ToolSpec: spec}) @@ -857,7 +859,7 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) return nil, nil } - bytes, err := sonic.Marshal(event) + bytes, err := providerUtils.MarshalSorted(event) if err != nil { return nil, fmt.Errorf("failed to marshal invoke stream event: %w", err) } diff --git a/core/providers/bedrock/responses.go b/core/providers/bedrock/responses.go index 256e33000e..f89b72ef20 100644 --- a/core/providers/bedrock/responses.go +++ b/core/providers/bedrock/responses.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/bytedance/sonic" "github.com/google/uuid" "github.com/maximhq/bifrost/core/providers/anthropic" providerUtils "github.com/maximhq/bifrost/core/providers/utils" @@ -1419,12 +1420,18 @@ func (request *BedrockConverseRequest) ToBifrostResponsesRequest(ctx *schemas.Bi } // Handle different types for InputSchema.JSON - if params, ok := tool.ToolSpec.InputSchema.JSON.(*schemas.ToolFunctionParameters); ok { - bifrostTool.ResponsesToolFunction.Parameters = params - } else if paramsMap, ok := tool.ToolSpec.InputSchema.JSON.(map[string]interface{}); ok { - // Convert map to ToolFunctionParameters - params := convertMapToToolFunctionParameters(paramsMap) - bifrostTool.ResponsesToolFunction.Parameters = params + if len(tool.ToolSpec.InputSchema.JSON) > 0 { + var params schemas.ToolFunctionParameters + if err := sonic.Unmarshal(tool.ToolSpec.InputSchema.JSON, ¶ms); err == nil { + bifrostTool.ResponsesToolFunction.Parameters = ¶ms + } else { + // Fallback: unmarshal as map and convert + var paramsMap map[string]interface{} + if err := sonic.Unmarshal(tool.ToolSpec.InputSchema.JSON, ¶msMap); err == nil { + params := convertMapToToolFunctionParameters(paramsMap) + bifrostTool.ResponsesToolFunction.Parameters = params + } + } } bifrostReq.Params.Tools = append(bifrostReq.Params.Tools, bifrostTool) @@ -1634,6 +1641,13 @@ func ToBedrockResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas. return nil, fmt.Errorf("bifrost request is nil") } + // Validate tools are supported by Bedrock + if bifrostReq.Params != nil && bifrostReq.Params.Tools != nil { + if toolErr := anthropic.ValidateToolsForProvider(bifrostReq.Params.Tools, schemas.Bedrock); toolErr != nil { + return nil, toolErr + } + } + bedrockReq := &BedrockConverseRequest{ ModelID: bifrostReq.Model, } @@ -1899,12 +1913,16 @@ func ToBedrockResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas. description = *tool.Description } + schemaObjectBytes, err := providerUtils.MarshalSorted(schemaObject) + if err != nil { + return nil, fmt.Errorf("failed to serialize tool schema %q: %w", name, err) + } bedrockTool := BedrockTool{ ToolSpec: &BedrockToolSpec{ Name: name, Description: &description, InputSchema: BedrockToolInputSchema{ - JSON: schemaObject, + JSON: json.RawMessage(schemaObjectBytes), }, }, } @@ -2132,12 +2150,13 @@ func extractToolsFromResponsesConversationHistory(messages []schemas.ResponsesMe description = *tool.Description } + schemaObjectBytes2, _ := providerUtils.MarshalSorted(schemaObject) bedrockTool := BedrockTool{ ToolSpec: &BedrockToolSpec{ Name: *tool.Name, Description: &description, InputSchema: BedrockToolInputSchema{ - JSON: schemaObject, + JSON: json.RawMessage(schemaObjectBytes2), }, }, } @@ -2453,10 +2472,10 @@ func ConvertBifrostMessagesToBedrockMessages(bifrostMessages []schemas.Responses }, } // Preserve original key ordering of tool arguments for prompt caching. - var input interface{} + var input json.RawMessage var buf bytes.Buffer if err := json.Compact(&buf, []byte(toolCall.Arguments)); err == nil { - input = json.RawMessage(buf.Bytes()) + input = buf.Bytes() } else { input = json.RawMessage("{}") } @@ -2583,10 +2602,10 @@ func ConvertBifrostMessagesToBedrockMessages(bifrostMessages []schemas.Responses }, } // Preserve original key ordering of tool arguments for prompt caching. - var input interface{} + var input json.RawMessage var buf bytes.Buffer if err := json.Compact(&buf, []byte(toolCall.Arguments)); err == nil { - input = json.RawMessage(buf.Bytes()) + input = buf.Bytes() } else { input = json.RawMessage("{}") } @@ -2664,10 +2683,10 @@ func ConvertBifrostMessagesToBedrockMessages(bifrostMessages []schemas.Responses }, } // Preserve original key ordering of tool arguments for prompt caching. - var input interface{} + var input json.RawMessage var buf bytes.Buffer if err := json.Compact(&buf, []byte(toolCall.Arguments)); err == nil { - input = json.RawMessage(buf.Bytes()) + input = buf.Bytes() } else { input = json.RawMessage("{}") } @@ -2980,7 +2999,7 @@ func convertSingleBedrockMessageToBifrostMessages(ctx *schemas.BifrostContext, m // Marshal the tool input to JSON string var contentStr string if block.ToolUse.Input != nil { - contentStr = schemas.JsonifyInput(block.ToolUse.Input) + contentStr = string(block.ToolUse.Input) } else { contentStr = "{}" } @@ -2992,13 +3011,17 @@ func convertSingleBedrockMessageToBifrostMessages(ctx *schemas.BifrostContext, m outputMessages = append(outputMessages, bifrostMsg) } else { // Normal tool call message + arguments := "{}" + if block.ToolUse.Input != nil { + arguments = string(block.ToolUse.Input) + } toolMsg := schemas.ResponsesMessage{ Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), Status: schemas.Ptr("completed"), ResponsesToolMessage: &schemas.ResponsesToolMessage{ CallID: &toolUseID, Name: &toolUseName, - Arguments: schemas.Ptr(schemas.JsonifyInput(block.ToolUse.Input)), + Arguments: schemas.Ptr(arguments), }, } if isOutputMessage { @@ -3073,7 +3096,7 @@ func convertSingleBedrockMessageToBifrostMessages(ctx *schemas.BifrostContext, m // JSON first (no unmarshal; just one marshal to string when present) for _, c := range block.ToolResult.Content { if c.JSON != nil { - resultContent = schemas.JsonifyInput(c.JSON) + resultContent = string(c.JSON) break } } diff --git a/core/providers/bedrock/types.go b/core/providers/bedrock/types.go index 14c1d45f72..9cac894ac6 100644 --- a/core/providers/bedrock/types.go +++ b/core/providers/bedrock/types.go @@ -196,7 +196,7 @@ type BedrockContentBlock struct { ReasoningContent *BedrockReasoningContent `json:"reasoningContent,omitempty"` // For Tool Call Result content - JSON interface{} `json:"json,omitempty"` + JSON json.RawMessage `json:"json,omitempty"` // Cache point for the content block CachePoint *BedrockCachePoint `json:"cachePoint,omitempty"` @@ -241,7 +241,7 @@ type BedrockDocumentSourceData struct { type BedrockToolUse struct { ToolUseID string `json:"toolUseId"` // Required: Unique identifier for this tool use Name string `json:"name"` // Required: Name of the tool to use - Input interface{} `json:"input"` // Required: Input parameters for the tool (JSON object) + Input json.RawMessage `json:"input"` // Required: Input parameters for the tool (json.RawMessage preserves key ordering for prompt caching) } // BedrockToolResult represents the result of a tool use @@ -309,7 +309,7 @@ type BedrockToolSpec struct { // BedrockToolInputSchema represents the input schema for a tool (union type) type BedrockToolInputSchema struct { - JSON interface{} `json:"json,omitempty"` // The JSON schema for the tool + JSON json.RawMessage `json:"json,omitempty"` // The JSON schema for the tool } // BedrockToolChoice represents tool choice configuration @@ -374,7 +374,7 @@ type BedrockConverseResponse struct { StopReason string `json:"stopReason"` // Required: Reason for stopping Usage *BedrockTokenUsage `json:"usage"` // Required: Token usage information Metrics *BedrockConverseMetrics `json:"metrics"` // Required: Response metrics - AdditionalModelResponseFields map[string]interface{} `json:"additionalModelResponseFields,omitempty"` // Optional: Additional model-specific response fields + AdditionalModelResponseFields json.RawMessage `json:"additionalModelResponseFields,omitempty"` // Optional: Additional model-specific response fields (json.RawMessage preserves key ordering) PerformanceConfig *BedrockPerformanceConfig `json:"performanceConfig,omitempty"` // Optional: Performance configuration used ServiceTier *BedrockServiceTier `json:"serviceTier,omitempty"` // Optional: Service tier that was used Trace *BedrockConverseTrace `json:"trace,omitempty"` // Optional: Guardrail trace information diff --git a/core/providers/bedrock/utils.go b/core/providers/bedrock/utils.go index 5ebbf69964..b698262869 100644 --- a/core/providers/bedrock/utils.go +++ b/core/providers/bedrock/utils.go @@ -8,7 +8,6 @@ import ( "regexp" "strings" - "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/providers/anthropic" providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" @@ -505,31 +504,38 @@ func convertToolMessages(msgs []schemas.ChatMessage) (BedrockMessage, error) { var toolResultContent []BedrockContentBlock if msg.Content.ContentStr != nil { // Bedrock expects JSON to be a parsed object, not a string - // Try to unmarshal the string content as JSON - var parsedOutput interface{} - if err := json.Unmarshal([]byte(*msg.Content.ContentStr), &parsedOutput); err != nil { + // Validate and compact JSON without parsing into Go types (preserves key ordering) + var buf bytes.Buffer + if err := json.Compact(&buf, []byte(*msg.Content.ContentStr)); err != nil { // If it's not valid JSON, wrap it as a text block instead toolResultContent = append(toolResultContent, BedrockContentBlock{ Text: msg.Content.ContentStr, }) } else { - // Use the parsed JSON object + compacted := buf.Bytes() // Bedrock does not accept primitives or arrays directly in the json field - switch v := parsedOutput.(type) { - case map[string]any: + if len(compacted) > 0 && compacted[0] == '{' { // Objects are valid as-is toolResultContent = append(toolResultContent, BedrockContentBlock{ - JSON: v, + JSON: json.RawMessage(compacted), }) - case []any: + } else if len(compacted) > 0 && compacted[0] == '[' { // Arrays need to be wrapped + wrapped := make([]byte, 0, len(compacted)+len(`{"results":}`)) + wrapped = append(wrapped, `{"results":`...) + wrapped = append(wrapped, compacted...) + wrapped = append(wrapped, '}') toolResultContent = append(toolResultContent, BedrockContentBlock{ - JSON: map[string]any{"results": v}, + JSON: json.RawMessage(wrapped), }) - default: + } else { // Primitives (string, number, boolean, null) need to be wrapped + wrapped := make([]byte, 0, len(compacted)+len(`{"value":}`)) + wrapped = append(wrapped, `{"value":`...) + wrapped = append(wrapped, compacted...) + wrapped = append(wrapped, '}') toolResultContent = append(toolResultContent, BedrockContentBlock{ - JSON: map[string]any{"value": v}, + JSON: json.RawMessage(wrapped), }) } } @@ -862,12 +868,16 @@ func convertResponseFormatToTool(ctx *schemas.BifrostContext, params *schemas.Ch ctx.SetValue(schemas.BifrostContextKeyStructuredOutputToolName, toolName) // Create the Bedrock tool + schemaObjBytes, err := providerUtils.MarshalSorted(schemaObj) + if err != nil { + return nil + } return &BedrockTool{ ToolSpec: &BedrockToolSpec{ Name: toolName, Description: schemas.Ptr(description), InputSchema: BedrockToolInputSchema{ - JSON: schemaObj, + JSON: json.RawMessage(schemaObjBytes), }, }, } @@ -904,12 +914,16 @@ func convertTextFormatToTool(ctx *schemas.BifrostContext, textConfig *schemas.Re return nil // Schema is required for Bedrock tooling } + schemaObjBytes2, err := providerUtils.MarshalSorted(schemaObj) + if err != nil { + return nil + } return &BedrockTool{ ToolSpec: &BedrockToolSpec{ Name: toolName, Description: schemas.Ptr(description), InputSchema: BedrockToolInputSchema{ - JSON: schemaObj, + JSON: json.RawMessage(schemaObjBytes2), }, }, } @@ -946,99 +960,19 @@ func convertToolConfig(model string, params *schemas.ChatParameters) *BedrockToo var bedrockTools []BedrockTool for _, tool := range params.Tools { if tool.Function != nil { - // Create the complete schema object that Bedrock expects - var schemaObject interface{} + // Serialize the parameters (or a default empty schema) to json.RawMessage + var schemaObjectBytes []byte if tool.Function.Parameters != nil { - // Use the complete parameters object which includes type, properties, required, etc. - schemaMap := map[string]interface{}{ - "type": tool.Function.Parameters.Type, - } - if tool.Function.Parameters.Properties != nil { - schemaMap["properties"] = tool.Function.Parameters.Properties - } - // Add required field if present - if len(tool.Function.Parameters.Required) > 0 { - schemaMap["required"] = tool.Function.Parameters.Required - } - // Add description if present - if tool.Function.Parameters.Description != nil { - schemaMap["description"] = *tool.Function.Parameters.Description - } - // Add enum if present - if len(tool.Function.Parameters.Enum) > 0 { - schemaMap["enum"] = tool.Function.Parameters.Enum - } - // Add additionalProperties if present - if tool.Function.Parameters.AdditionalProperties != nil { - schemaMap["additionalProperties"] = tool.Function.Parameters.AdditionalProperties - } - // Add JSON Schema definition fields - if tool.Function.Parameters.Defs != nil { - schemaMap["$defs"] = tool.Function.Parameters.Defs - } - if tool.Function.Parameters.Definitions != nil { - schemaMap["definitions"] = tool.Function.Parameters.Definitions - } - if tool.Function.Parameters.Ref != nil { - schemaMap["$ref"] = *tool.Function.Parameters.Ref - } - // Add array schema fields - if tool.Function.Parameters.Items != nil { - schemaMap["items"] = tool.Function.Parameters.Items - } - if tool.Function.Parameters.MinItems != nil { - schemaMap["minItems"] = *tool.Function.Parameters.MinItems - } - if tool.Function.Parameters.MaxItems != nil { - schemaMap["maxItems"] = *tool.Function.Parameters.MaxItems + // ToolFunctionParameters.MarshalJSON handles all fields including + // properties, required, enum, additionalProperties, $defs, etc. + var err error + schemaObjectBytes, err = providerUtils.MarshalSorted(tool.Function.Parameters) + if err != nil { + continue } - // Add composition fields - if len(tool.Function.Parameters.AnyOf) > 0 { - schemaMap["anyOf"] = tool.Function.Parameters.AnyOf - } - if len(tool.Function.Parameters.OneOf) > 0 { - schemaMap["oneOf"] = tool.Function.Parameters.OneOf - } - if len(tool.Function.Parameters.AllOf) > 0 { - schemaMap["allOf"] = tool.Function.Parameters.AllOf - } - // Add string validation fields - if tool.Function.Parameters.Format != nil { - schemaMap["format"] = *tool.Function.Parameters.Format - } - if tool.Function.Parameters.Pattern != nil { - schemaMap["pattern"] = *tool.Function.Parameters.Pattern - } - if tool.Function.Parameters.MinLength != nil { - schemaMap["minLength"] = *tool.Function.Parameters.MinLength - } - if tool.Function.Parameters.MaxLength != nil { - schemaMap["maxLength"] = *tool.Function.Parameters.MaxLength - } - // Add number validation fields - if tool.Function.Parameters.Minimum != nil { - schemaMap["minimum"] = *tool.Function.Parameters.Minimum - } - if tool.Function.Parameters.Maximum != nil { - schemaMap["maximum"] = *tool.Function.Parameters.Maximum - } - // Add misc fields - if tool.Function.Parameters.Title != nil { - schemaMap["title"] = *tool.Function.Parameters.Title - } - if tool.Function.Parameters.Default != nil { - schemaMap["default"] = tool.Function.Parameters.Default - } - if tool.Function.Parameters.Nullable != nil { - schemaMap["nullable"] = *tool.Function.Parameters.Nullable - } - schemaObject = schemaMap } else { // Fallback to empty object schema if no parameters - schemaObject = map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{}, - } + schemaObjectBytes = []byte(`{"type":"object","properties":{}}`) } // Use the tool description if available, otherwise use a generic description @@ -1052,7 +986,7 @@ func convertToolConfig(model string, params *schemas.ChatParameters) *BedrockToo Name: tool.Function.Name, Description: schemas.Ptr(description), InputSchema: BedrockToolInputSchema{ - JSON: schemaObject, + JSON: json.RawMessage(schemaObjectBytes), }, }, } @@ -1156,13 +1090,14 @@ func checkMessageForToolContent(msg schemas.ChatMessage, toolsMap map[string]Bed "type": "object", "properties": map[string]interface{}{}, } + extractedSchemaBytes, _ := providerUtils.MarshalSorted(schemaObject) toolsMap[*toolCall.Function.Name] = BedrockTool{ ToolSpec: &BedrockToolSpec{ Name: *toolCall.Function.Name, Description: schemas.Ptr("Tool extracted from conversation history"), InputSchema: BedrockToolInputSchema{ - JSON: schemaObject, + JSON: json.RawMessage(extractedSchemaBytes), }, }, } @@ -1203,10 +1138,10 @@ func convertToolCallToContentBlock(toolCall schemas.ChatAssistantMessageToolCall // Preserve original key ordering of tool arguments for prompt caching. // Using json.RawMessage avoids the map[string]interface{} round-trip // that would destroy key order. - var input interface{} + var input json.RawMessage var buf bytes.Buffer if err := json.Compact(&buf, []byte(toolCall.Function.Arguments)); err == nil { - input = json.RawMessage(buf.Bytes()) + input = buf.Bytes() } else { // Preserve original payload instead of silently dropping args. input = json.RawMessage([]byte(toolCall.Function.Arguments)) @@ -1459,22 +1394,30 @@ func bedrockExtractFloat64(v interface{}) (float64, bool) { // tryParseJSONIntoContentBlock try to parse input text into a JSON and returns a proper // BedrockContentBlock based on the result. func tryParseJSONIntoContentBlock(text string) BedrockContentBlock { - var parsed interface{} - // Try to parse as JSON, otherwise treat as text - if err := sonic.UnmarshalString(text, &parsed); err != nil { + // Validate and compact JSON without parsing into Go types (preserves key ordering) + var buf bytes.Buffer + if err := json.Compact(&buf, []byte(text)); err != nil { return BedrockContentBlock{Text: schemas.Ptr(text)} + } + compacted := buf.Bytes() + + // Bedrock does not accept primitives or arrays directly in the json field + if len(compacted) > 0 && compacted[0] == '{' { + // Objects are valid as-is + return BedrockContentBlock{JSON: json.RawMessage(compacted)} + } else if len(compacted) > 0 && compacted[0] == '[' { + // Arrays need to be wrapped + wrapped := make([]byte, 0, len(compacted)+len(`{"results":}`)) + wrapped = append(wrapped, `{"results":`...) + wrapped = append(wrapped, compacted...) + wrapped = append(wrapped, '}') + return BedrockContentBlock{JSON: json.RawMessage(wrapped)} } else { - // Bedrock does not accept primitives or arrays directly in the json field - switch v := parsed.(type) { - case map[string]any: - // Objects are valid as-is - return BedrockContentBlock{JSON: v} - case []any: - // Arrays need to be wrapped - return BedrockContentBlock{JSON: map[string]any{"results": v}} - default: - // Primitives (string, number, boolean, null) need to be wrapped - return BedrockContentBlock{JSON: map[string]any{"value": v}} - } + // Primitives (string, number, boolean, null) need to be wrapped + wrapped := make([]byte, 0, len(compacted)+len(`{"value":}`)) + wrapped = append(wrapped, `{"value":`...) + wrapped = append(wrapped, compacted...) + wrapped = append(wrapped, '}') + return BedrockContentBlock{JSON: json.RawMessage(wrapped)} } } diff --git a/core/providers/cohere/models.go b/core/providers/cohere/models.go index fb399731ca..83ffd8752f 100644 --- a/core/providers/cohere/models.go +++ b/core/providers/cohere/models.go @@ -1,6 +1,7 @@ package cohere import ( + "encoding/json" "slices" "github.com/maximhq/bifrost/core/schemas" @@ -26,7 +27,7 @@ func (r *CohereRerankRequest) GetExtraParams() map[string]interface{} { type CohereRerankResult struct { Index int `json:"index"` RelevanceScore float64 `json:"relevance_score"` - Document map[string]interface{} `json:"document,omitempty"` + Document json.RawMessage `json:"document,omitempty"` } // CohereRerankResponse represents a Cohere rerank API response. diff --git a/core/providers/cohere/rerank.go b/core/providers/cohere/rerank.go index 7f8b2ef70f..b820e3796b 100644 --- a/core/providers/cohere/rerank.go +++ b/core/providers/cohere/rerank.go @@ -3,6 +3,7 @@ package cohere import ( "sort" + "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" "gopkg.in/yaml.v3" @@ -92,25 +93,43 @@ func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schema } // Convert document if present - if result.Document != nil { - doc := &schemas.RerankDocument{} - if text, ok := result.Document["text"].(string); ok { - doc.Text = text - } - if id, ok := result.Document["id"].(string); ok { - doc.ID = &id - } - // Collect remaining fields as meta - meta := make(map[string]interface{}) - for k, v := range result.Document { - if k != "text" && k != "id" { - meta[k] = v + if len(result.Document) > 0 { + var docMap map[string]interface{} + if err := sonic.Unmarshal(result.Document, &docMap); err == nil { + doc := &schemas.RerankDocument{} + populated := false + if text, ok := docMap["text"].(string); ok { + doc.Text = text + populated = true + } + if id, ok := docMap["id"].(string); ok { + doc.ID = &id + populated = true + } + // Collect metadata: unwrap "metadata"/"meta" keys to avoid nesting + meta := make(map[string]interface{}) + if rawMeta, ok := docMap["metadata"].(map[string]interface{}); ok { + for k, v := range rawMeta { + meta[k] = v + } + } else if rawMeta, ok := docMap["meta"].(map[string]interface{}); ok { + for k, v := range rawMeta { + meta[k] = v + } + } + for k, v := range docMap { + if k != "text" && k != "id" && k != "metadata" && k != "meta" { + meta[k] = v + } + } + if len(meta) > 0 { + doc.Meta = meta + populated = true + } + if populated { + rerankResult.Document = doc } } - if len(meta) > 0 { - doc.Meta = meta - } - rerankResult.Document = doc } bifrostResponse.Results = append(bifrostResponse.Results, rerankResult) diff --git a/core/providers/cohere/rerank_test.go b/core/providers/cohere/rerank_test.go index 526c65b6de..38caeac816 100644 --- a/core/providers/cohere/rerank_test.go +++ b/core/providers/cohere/rerank_test.go @@ -1,6 +1,7 @@ package cohere import ( + "encoding/json" "testing" "github.com/maximhq/bifrost/core/schemas" @@ -15,18 +16,12 @@ func TestCohereRerankResponseToBifrostRerankResponse(t *testing.T) { { Index: 1, RelevanceScore: 0.62, - Document: map[string]interface{}{ - "text": "provider-doc-1", - "id": "doc-1", - "topic": "geography", - }, + Document: json.RawMessage(`{"text":"provider-doc-1","id":"doc-1","topic":"geography"}`), }, { Index: 0, RelevanceScore: 0.91, - Document: map[string]interface{}{ - "text": "provider-doc-0", - }, + Document: json.RawMessage(`{"text":"provider-doc-0"}`), }, }, }).ToBifrostRerankResponse(nil, false) @@ -56,16 +51,12 @@ func TestCohereRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T { Index: 1, RelevanceScore: 0.62, - Document: map[string]interface{}{ - "text": "provider-doc-1", - }, + Document: json.RawMessage(`{"text":"provider-doc-1"}`), }, { Index: 0, RelevanceScore: 0.91, - Document: map[string]interface{}{ - "text": "provider-doc-0", - }, + Document: json.RawMessage(`{"text":"provider-doc-0"}`), }, }, }).ToBifrostRerankResponse(requestDocs, true) diff --git a/core/providers/cohere/responses.go b/core/providers/cohere/responses.go index 2929592afb..3c4d855b01 100644 --- a/core/providers/cohere/responses.go +++ b/core/providers/cohere/responses.go @@ -9,6 +9,7 @@ import ( "github.com/maximhq/bifrost/core/providers/anthropic" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" + "github.com/tidwall/gjson" ) // CohereResponsesStreamState tracks state during streaming conversion for responses API @@ -878,16 +879,21 @@ func (chunk *CohereStreamEvent) ToBifrostResponsesStream(sequenceNumber int, sta } if source.Document != nil { - if title, ok := (*source.Document)["title"].(string); ok { + doc := []byte(*source.Document) + if t := providerUtils.GetJSONField(doc, "title"); t.Exists() && t.Type == gjson.String { + title := t.String() annotation.Title = &title } - if id, ok := (*source.Document)["id"].(string); ok && annotation.FileID == nil { - annotation.FileID = &id + if id := providerUtils.GetJSONField(doc, "id"); id.Exists() && id.Type == gjson.String && annotation.FileID == nil { + idStr := id.String() + annotation.FileID = &idStr } - if snippet, ok := (*source.Document)["snippet"].(string); ok { + if s := providerUtils.GetJSONField(doc, "snippet"); s.Exists() && s.Type == gjson.String { + snippet := s.String() annotation.Text = &snippet } - if url, ok := (*source.Document)["url"].(string); ok { + if u := providerUtils.GetJSONField(doc, "url"); u.Exists() && u.Type == gjson.String { + url := u.String() annotation.URL = &url } } diff --git a/core/providers/cohere/types.go b/core/providers/cohere/types.go index 3af7823a8c..06da064902 100644 --- a/core/providers/cohere/types.go +++ b/core/providers/cohere/types.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/bytedance/sonic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -79,10 +80,10 @@ type CohereMessageContent struct { // MarshalJSON implements custom JSON marshaling for CohereMessageContent func (c *CohereMessageContent) MarshalJSON() ([]byte, error) { if c.StringContent != nil { - return json.Marshal(*c.StringContent) + return providerUtils.MarshalSorted(*c.StringContent) } if c.BlocksContent != nil { - return json.Marshal(c.BlocksContent) + return providerUtils.MarshalSorted(c.BlocksContent) } return []byte("null"), nil } @@ -91,14 +92,14 @@ func (c *CohereMessageContent) MarshalJSON() ([]byte, error) { func (c *CohereMessageContent) UnmarshalJSON(data []byte) error { // Try to unmarshal as string first var str string - if err := json.Unmarshal(data, &str); err == nil { + if err := sonic.Unmarshal(data, &str); err == nil { c.StringContent = &str return nil } // Try to unmarshal as content blocks array var blocks []CohereContentBlock - if err := json.Unmarshal(data, &blocks); err == nil { + if err := sonic.Unmarshal(data, &blocks); err == nil { c.BlocksContent = blocks return nil } @@ -420,8 +421,8 @@ type CohereCitation struct { type CohereSource struct { Type CohereSourceType `json:"type"` // Source type ("tool" or "document") ID *string `json:"id,omitempty"` // Source ID (nullable) - ToolOutput *map[string]any `json:"tool_output,omitempty"` // Tool output (for tool sources) - Document *map[string]any `json:"document,omitempty"` // Document data (for document sources) + ToolOutput *json.RawMessage `json:"tool_output,omitempty"` // Tool output (for tool sources, json.RawMessage preserves key ordering) + Document *json.RawMessage `json:"document,omitempty"` // Document data (for document sources, json.RawMessage preserves key ordering) } // ==================== STREAMING TYPES ==================== @@ -467,12 +468,12 @@ type CohereStreamToolCallStruct struct { // JSON marshaling for CohereStreamToolCall func (c *CohereStreamToolCallStruct) MarshalJSON() ([]byte, error) { if c.CohereToolCallObject != nil { - return sonic.Marshal(c.CohereToolCallObject) + return providerUtils.MarshalSorted(c.CohereToolCallObject) } if c.CohereToolCallArray != nil { - return sonic.Marshal(c.CohereToolCallArray) + return providerUtils.MarshalSorted(c.CohereToolCallArray) } - return sonic.Marshal(nil) + return providerUtils.MarshalSorted(nil) } func (c *CohereStreamToolCallStruct) UnmarshalJSON(data []byte) error { @@ -503,12 +504,12 @@ type CohereStreamContentStruct struct { func (c *CohereStreamContentStruct) MarshalJSON() ([]byte, error) { if c.CohereStreamContentObject != nil { - return sonic.Marshal(c.CohereStreamContentObject) + return providerUtils.MarshalSorted(c.CohereStreamContentObject) } if c.CohereStreamContentArray != nil { - return sonic.Marshal(c.CohereStreamContentArray) + return providerUtils.MarshalSorted(c.CohereStreamContentArray) } - return sonic.Marshal(nil) + return providerUtils.MarshalSorted(nil) } func (c *CohereStreamContentStruct) UnmarshalJSON(data []byte) error { @@ -539,12 +540,12 @@ type CohereStreamCitationStruct struct { func (c *CohereStreamCitationStruct) MarshalJSON() ([]byte, error) { if c.CohereStreamCitationObject != nil { - return sonic.Marshal(c.CohereStreamCitationObject) + return providerUtils.MarshalSorted(c.CohereStreamCitationObject) } if c.CohereStreamCitationArray != nil { - return sonic.Marshal(c.CohereStreamCitationArray) + return providerUtils.MarshalSorted(c.CohereStreamCitationArray) } - return sonic.Marshal(nil) + return providerUtils.MarshalSorted(nil) } func (c *CohereStreamCitationStruct) UnmarshalJSON(data []byte) error { diff --git a/core/providers/gemini/responses.go b/core/providers/gemini/responses.go index 8cbd600cf0..e9f54598f1 100644 --- a/core/providers/gemini/responses.go +++ b/core/providers/gemini/responses.go @@ -9,7 +9,6 @@ import ( "sync" "time" - "github.com/bytedance/sonic" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -339,7 +338,12 @@ func ToGeminiResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *G responseMap := make(map[string]any) if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { - responseMap["output"] = *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr + output := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr + if json.Valid([]byte(output)) { + responseMap["output"] = json.RawMessage(output) + } else { + responseMap["output"] = output + } } funcName := "" if msg.ResponsesToolMessage.Name != nil && strings.TrimSpace(*msg.ResponsesToolMessage.Name) != "" { @@ -348,9 +352,10 @@ func ToGeminiResponsesResponse(bifrostResp *schemas.BifrostResponsesResponse) *G funcName = *msg.ResponsesToolMessage.CallID } + responseBytes, _ := providerUtils.MarshalSorted(responseMap) functionResponse := &FunctionResponse{ Name: funcName, - Response: responseMap, + Response: json.RawMessage(responseBytes), } if msg.ResponsesToolMessage.CallID != nil { functionResponse.ID = *msg.ResponsesToolMessage.CallID @@ -1917,13 +1922,13 @@ func convertGeminiContentsToResponsesMessages(contents []Content) []schemas.Resp } } - // Convert response map to string + // Convert response to string — extract output field if present responseStr := "" if part.FunctionResponse.Response != nil { - if output, ok := part.FunctionResponse.Response["output"].(string); ok { - responseStr = output - } else if responseBytes, err := sonic.Marshal(part.FunctionResponse.Response); err == nil { - responseStr = string(responseBytes) + if r := providerUtils.GetJSONField(part.FunctionResponse.Response, "output"); r.Exists() { + responseStr = r.String() + } else { + responseStr = string(part.FunctionResponse.Response) } } @@ -3003,10 +3008,20 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag // Extract output from ResponsesToolMessage.Output if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { - responseMap["output"] = *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr + output := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr + if json.Valid([]byte(output)) { + responseMap["output"] = json.RawMessage(output) + } else { + responseMap["output"] = output + } } else if msg.Content != nil && msg.Content.ContentStr != nil { // Fallback to Content.ContentStr for backward compatibility - responseMap["output"] = *msg.Content.ContentStr + output := *msg.Content.ContentStr + if json.Valid([]byte(output)) { + responseMap["output"] = json.RawMessage(output) + } else { + responseMap["output"] = output + } } // Prefer the declared tool name; fallback to callIDToName lookup, then raw CallID @@ -3019,10 +3034,11 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag funcName = *msg.ResponsesToolMessage.CallID } + responseBytes, _ := providerUtils.MarshalSorted(responseMap) part := &Part{ FunctionResponse: &FunctionResponse{ Name: funcName, - Response: responseMap, + Response: json.RawMessage(responseBytes), ID: *msg.ResponsesToolMessage.CallID, }, } diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index acd758e185..2464a9913e 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -13,6 +13,7 @@ import ( "cloud.google.com/go/civil" "github.com/bytedance/sonic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -297,7 +298,7 @@ func (i *Interval) UnmarshalJSON(data []byte) error { Alias: (*Alias)(i), } - if err := json.Unmarshal(data, &aux); err != nil { + if err := sonic.Unmarshal(data, &aux); err != nil { return err } @@ -335,7 +336,7 @@ func (i *Interval) MarshalJSON() ([]byte, error) { aux.EndTime = (*time.Time)(&i.EndTime) } - return json.Marshal(aux) + return providerUtils.MarshalSorted(aux) } // GoogleSearch is a tool to support Google Search in Model. Powered by Google. @@ -360,7 +361,7 @@ func (g *GoogleSearch) UnmarshalJSON(data []byte) error { Alias: (*Alias)(g), } - if err := json.Unmarshal(data, aux); err != nil { + if err := sonic.Unmarshal(data, aux); err != nil { return err } @@ -747,7 +748,7 @@ func (t *Tool) UnmarshalJSON(data []byte) error { Alias: (*Alias)(t), } - if err := json.Unmarshal(data, aux); err != nil { + if err := sonic.Unmarshal(data, aux); err != nil { return err } @@ -995,7 +996,7 @@ func (p *PrebuiltVoiceConfig) UnmarshalJSON(data []byte) error { VoiceName string `json:"voice_name,omitempty"` } var aux Alias - if err := json.Unmarshal(data, &aux); err != nil { + if err := sonic.Unmarshal(data, &aux); err != nil { return err } p.VoiceName = aux.VoiceName @@ -1008,7 +1009,7 @@ func (p PrebuiltVoiceConfig) MarshalJSON() ([]byte, error) { type Alias struct { VoiceName string `json:"voiceName,omitempty"` } - return json.Marshal(Alias(p)) + return providerUtils.MarshalSorted(Alias(p)) } // VoiceConfig represents the configuration for the voice to use. @@ -1024,7 +1025,7 @@ func (v *VoiceConfig) UnmarshalJSON(data []byte) error { PrebuiltVoiceConfig *PrebuiltVoiceConfig `json:"prebuilt_voice_config,omitempty"` } var aux Alias - if err := json.Unmarshal(data, &aux); err != nil { + if err := sonic.Unmarshal(data, &aux); err != nil { return err } v.PrebuiltVoiceConfig = aux.PrebuiltVoiceConfig @@ -1037,7 +1038,7 @@ func (v VoiceConfig) MarshalJSON() ([]byte, error) { type Alias struct { PrebuiltVoiceConfig *PrebuiltVoiceConfig `json:"prebuiltVoiceConfig,omitempty"` } - return json.Marshal(Alias(v)) + return providerUtils.MarshalSorted(Alias(v)) } // SpeakerVoiceConfig represents the configuration for the speaker to use. @@ -1201,7 +1202,7 @@ func (p *Part) UnmarshalJSON(data []byte) error { } var aux PartAlias - if err := json.Unmarshal(data, &aux); err != nil { + if err := sonic.Unmarshal(data, &aux); err != nil { return err } @@ -1255,7 +1256,7 @@ func (b *Blob) UnmarshalJSON(data []byte) error { } var aux BlobAlias - if err := json.Unmarshal(data, &aux); err != nil { + if err := sonic.Unmarshal(data, &aux); err != nil { return err } @@ -1373,7 +1374,7 @@ type FunctionResponse struct { // Required. The function response in JSON object format. Use "output" key to specify // function output and "error" key to specify error details (if any). If "output" and // "error" keys are not specified, then whole "response" is treated as function output. - Response map[string]any `json:"response,omitempty"` + Response json.RawMessage `json:"response,omitempty"` } // ==================== RESPONSE TYPES ==================== @@ -1479,7 +1480,7 @@ type dateJSON civil.Date func (d *dateJSON) UnmarshalJSON(data []byte) error { m := make(map[string]int) - if err := json.Unmarshal(data, &m); err != nil { + if err := sonic.Unmarshal(data, &m); err != nil { return fmt.Errorf("failed to unmarshal date from map: %w", err) } @@ -1503,7 +1504,7 @@ func (d *dateJSON) UnmarshalJSON(data []byte) error { func (d *dateJSON) MarshalJSON() ([]byte, error) { m := make(map[string]int) if d == nil || (civil.Date)(*d).IsZero() { - return json.Marshal(nil) + return providerUtils.MarshalSorted(nil) } if d.Year != 0 { m["year"] = d.Year @@ -1514,7 +1515,7 @@ func (d *dateJSON) MarshalJSON() ([]byte, error) { if d.Day != 0 { m["day"] = d.Day } - return json.Marshal(m) + return providerUtils.MarshalSorted(m) } func (c *Citation) UnmarshalJSON(data []byte) error { @@ -1526,7 +1527,7 @@ func (c *Citation) UnmarshalJSON(data []byte) error { Alias: (*Alias)(c), } - if err := json.Unmarshal(data, &aux); err != nil { + if err := sonic.Unmarshal(data, &aux); err != nil { return err } @@ -1550,7 +1551,7 @@ func (c *Citation) MarshalJSON() ([]byte, error) { aux.PublicationDate = (*dateJSON)(&c.PublicationDate) } - return json.Marshal(aux) + return providerUtils.MarshalSorted(aux) } // Citation information when the model quotes another source. @@ -1854,7 +1855,7 @@ func (g *GenerateContentResponse) UnmarshalJSON(data []byte) error { Alias: (*Alias)(g), } - if err := json.Unmarshal(data, &aux); err != nil { + if err := sonic.Unmarshal(data, &aux); err != nil { return err } @@ -1878,7 +1879,7 @@ func (g *GenerateContentResponse) MarshalJSON() ([]byte, error) { aux.CreateTime = (*time.Time)(&g.CreateTime) } - return json.Marshal(aux) + return providerUtils.MarshalSorted(aux) } type GeminiGenerationError struct { @@ -1973,7 +1974,7 @@ type GeminiBatchStats struct { // MarshalJSON implements the json.Marshaler interface. func (g *GeminiBatchStats) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { + return providerUtils.MarshalSorted(struct { RequestCount string `json:"requestCount"` PendingRequestCount string `json:"pendingRequestCount"` SuccessfulRequestCount string `json:"successfulRequestCount"` @@ -1992,7 +1993,7 @@ func (g *GeminiBatchStats) UnmarshalJSON(data []byte) error { PendingRequestCount string `json:"pendingRequestCount"` SuccessfulRequestCount string `json:"successfulRequestCount"` } - if err := json.Unmarshal(data, &raw); err != nil { + if err := sonic.Unmarshal(data, &raw); err != nil { return err } if raw.RequestCount != "" { @@ -2344,7 +2345,7 @@ func (v *VideoImageData) UnmarshalJSON(data []byte) error { } var aux VideoImageDataAlias - if err := json.Unmarshal(data, &aux); err != nil { + if err := sonic.Unmarshal(data, &aux); err != nil { return err } @@ -2502,7 +2503,7 @@ type HTTPOptions struct { // The structure must match the backend API's request structure. // - VertexAI backend API docs: https://cloud.google.com/vertex-ai/docs/reference/rest // - GeminiAPI backend API docs: https://ai.google.dev/api/rest - ExtraBody map[string]any `json:"extraBody,omitempty"` + ExtraBody json.RawMessage `json:"extraBody,omitempty"` // Optional. A function that allows for request body customization. // It is executed after ExtraBody has been merged, offering more advanced // control over the request body than the static ExtraBody. @@ -2611,12 +2612,12 @@ type GenerateVideosOperation struct { // progress information and common metadata such as create time. Some services might // not provide such metadata. Any method that returns a long-running operation should // document the metadata type, if any. - Metadata map[string]any `json:"metadata,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` // If the value is `false`, it means the operation is still in progress. If `true`, // the operation is completed, and either `error` or `response` is available. Done bool `json:"done,omitempty"` // Optional. The error result of the operation in case of failure or cancellation. - Error map[string]any `json:"error,omitempty"` + Error json.RawMessage `json:"error,omitempty"` // Optional. The long-running operation response payload. Response *GenerateVideosOperationResponse `json:"response,omitempty"` } diff --git a/core/providers/gemini/utils.go b/core/providers/gemini/utils.go index 98e24d4528..907ea8aa8e 100644 --- a/core/providers/gemini/utils.go +++ b/core/providers/gemini/utils.go @@ -368,7 +368,7 @@ func convertSchemaToOrderedMap(schema *Schema) *schemas.OrderedMap { func convertSchemaToMap(schema *Schema) *schemas.OrderedMap { // Convert map[string]*Schema to map[string]interface{} using JSON marshaling - data, err := sonic.Marshal(schema.Properties) + data, err := providerUtils.MarshalSorted(schema.Properties) if err != nil { return schemas.NewOrderedMap() } @@ -1528,7 +1528,7 @@ func convertBifrostMessagesToGemini(messages []schemas.ChatMessage) ([]Content, // must be sent in a single message with only functionResponse parts (no text parts) if isToolResponse { // Parse the response content - var responseData map[string]any + var responseData json.RawMessage var contentStr string if message.Content != nil { @@ -1549,18 +1549,21 @@ func convertBifrostMessagesToGemini(messages []schemas.ChatMessage) ([]Content, } } - // Try to unmarshal as JSON + // Try to use raw JSON if it's a valid JSON object (Gemini requires Struct/object) if contentStr != "" { - err := sonic.Unmarshal([]byte(contentStr), &responseData) - if err != nil { - // If unmarshaling fails, wrap the original string to preserve it - responseData = map[string]any{ + var buf bytes.Buffer + if err := json.Compact(&buf, []byte(contentStr)); err == nil && buf.Len() > 0 && buf.Bytes()[0] == '{' { + // Valid JSON object — use raw bytes directly + responseData = json.RawMessage(buf.Bytes()) + } else { + // Not valid JSON or not an object — wrap to preserve content + responseData, _ = providerUtils.MarshalSorted(map[string]any{ "content": contentStr, - } + }) } } else { - // If no content at all, use empty map to avoid nil - responseData = map[string]any{} + // If no content at all, use empty object to avoid nil + responseData = json.RawMessage(`{}`) } // Use ToolCallID if available, ensuring it's not nil @@ -2091,7 +2094,7 @@ func buildOpenAIResponseFormat(responseJsonSchema interface{}, responseSchema *S } } else if responseSchema != nil { // Convert responseSchema to map using JSON marshaling and type normalization - data, err := sonic.Marshal(responseSchema) + data, err := providerUtils.MarshalSorted(responseSchema) if err != nil { // If marshaling fails, fall back to json_object mode return &schemas.ResponsesTextConfig{ @@ -2323,18 +2326,19 @@ func extractFunctionResponseOutput(funcResp *FunctionResponse) string { } // Try to extract "output" field first - if outputVal, ok := funcResp.Response["output"]; ok { - if outputStr, ok := outputVal.(string); ok { - return outputStr + var respMap map[string]json.RawMessage + if err := sonic.Unmarshal(funcResp.Response, &respMap); err == nil { + if outputVal, ok := respMap["output"]; ok { + var outputStr string + if err := sonic.Unmarshal(outputVal, &outputStr); err == nil { + return outputStr + } + return string(outputVal) } } - // If no "output" key, marshal the entire response - if jsonResponse, err := sonic.Marshal(funcResp.Response); err == nil { - return string(jsonResponse) - } - - return "" + // If no "output" key or unmarshal failed, return raw JSON + return string(funcResp.Response) } // decodeBase64StringToBytes decodes a base64-encoded string into raw bytes. diff --git a/core/providers/gemini/videos.go b/core/providers/gemini/videos.go index 699c226003..62ce110c26 100644 --- a/core/providers/gemini/videos.go +++ b/core/providers/gemini/videos.go @@ -2,6 +2,7 @@ package gemini import ( "encoding/base64" + "encoding/json" "fmt" "net/url" "strconv" @@ -180,7 +181,7 @@ func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe if referenceImages, ok := bifrostReq.Params.ExtraParams["referenceImages"]; ok { if referenceImages, ok := referenceImages.([]VideoReferenceImage); ok && referenceImages != nil { params.ReferenceImages = referenceImages - } else if data, err := sonic.Marshal(referenceImages); err == nil { + } else if data, err := providerUtils.MarshalSorted(referenceImages); err == nil { var referenceImages []VideoReferenceImage if sonic.Unmarshal(data, &referenceImages) == nil { params.ReferenceImages = referenceImages @@ -190,7 +191,7 @@ func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe if lastFrame, ok := bifrostReq.Params.ExtraParams["lastFrame"]; ok { if lastFrame, ok := lastFrame.(*VideoImageData); ok { params.LastFrame = lastFrame - } else if data, err := sonic.Marshal(lastFrame); err == nil { + } else if data, err := providerUtils.MarshalSorted(lastFrame); err == nil { var lastFrame VideoImageData if sonic.Unmarshal(data, &lastFrame) == nil { params.LastFrame = &lastFrame @@ -232,15 +233,24 @@ func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model if !operation.Done { response.Status = schemas.VideoStatusInProgress if operation.Metadata != nil { - if progress, ok := operation.Metadata["progress"].(float64); ok { + if p := providerUtils.GetJSONField([]byte(operation.Metadata), "progress"); p.Exists() { + progress := p.Float() response.Progress = &progress } } } else if operation.Error != nil { response.Status = schemas.VideoStatusFailed + code := providerUtils.GetJSONField(operation.Error, "code").String() + message := providerUtils.GetJSONField(operation.Error, "message").String() + if code == "" { + code = "video_generation_failed" + } + if message == "" { + message = string(operation.Error) + } response.Error = &schemas.VideoCreateError{ - Code: "video_generation_failed", - Message: fmt.Sprintf("%v", operation.Error), + Code: code, + Message: message, } } else if operation.Response != nil { // Check new response format with content filtering support @@ -358,13 +368,13 @@ func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model // Try to extract timestamps from metadata if operation.Metadata != nil { - if createTime, ok := operation.Metadata["createTime"].(string); ok { - if t, err := time.Parse(time.RFC3339, createTime); err == nil { + if ct := providerUtils.GetJSONField([]byte(operation.Metadata), "createTime"); ct.Exists() { + if t, err := time.Parse(time.RFC3339, ct.String()); err == nil { response.CreatedAt = t.Unix() } } - if updateTime, ok := operation.Metadata["updateTime"].(string); ok { - if t, err := time.Parse(time.RFC3339, updateTime); err == nil && operation.Done { + if ut := providerUtils.GetJSONField([]byte(operation.Metadata), "updateTime"); ut.Exists() { + if t, err := time.Parse(time.RFC3339, ut.String()); err == nil && operation.Done { response.CompletedAt = schemas.Ptr(t.Unix()) } } @@ -571,10 +581,11 @@ func ToGeminiVideoGenerationResponse(response *schemas.BifrostVideoGenerationRes }, } } else if response.Error != nil { - operation.Error = map[string]any{ + errBytes, _ := providerUtils.MarshalSorted(map[string]any{ "message": response.Error.Message, "code": response.Error.Code, - } + }) + operation.Error = json.RawMessage(errBytes) } default: operation.Done = false diff --git a/core/providers/replicate/types.go b/core/providers/replicate/types.go index 6251aeebcc..98f84e613e 100644 --- a/core/providers/replicate/types.go +++ b/core/providers/replicate/types.go @@ -1,6 +1,7 @@ package replicate import ( + "encoding/json" "fmt" "time" @@ -86,7 +87,7 @@ func (r *ReplicatePredictionRequestInput) MarshalJSON() ([]byte, error) { type Alias ReplicatePredictionRequestInput // Marshal the struct normally (ExtraParams will be omitted due to json:"-" tag) - aliasData, err := sonic.Marshal((*Alias)(r)) + aliasData, err := providerUtils.MarshalSorted((*Alias)(r)) if err != nil { return nil, err } @@ -194,7 +195,7 @@ type ReplicatePredictionResponse struct { ID string `json:"id"` Model string `json:"model"` // Model identifier (owner/name or owner/name:version) Version string `json:"version"` // Model version ID - Input map[string]interface{} `json:"input"` // Input parameters used + Input json.RawMessage `json:"input"` // Input parameters used (json.RawMessage preserves key ordering) Output *ReplicateOutput `json:"output,omitempty"` // Output data (can be various types) Logs *string `json:"logs,omitempty"` // Execution logs Error *string `json:"error,omitempty"` // Error message if failed @@ -239,16 +240,16 @@ func (mc ReplicateOutput) MarshalJSON() ([]byte, error) { } if mc.OutputStr != nil { - return sonic.Marshal(*mc.OutputStr) + return providerUtils.MarshalSorted(*mc.OutputStr) } if mc.OutputArray != nil { - return sonic.Marshal(mc.OutputArray) + return providerUtils.MarshalSorted(mc.OutputArray) } if mc.OutputObject != nil { - return sonic.Marshal(mc.OutputObject) + return providerUtils.MarshalSorted(mc.OutputObject) } // If all are nil, return null - return sonic.Marshal(nil) + return providerUtils.MarshalSorted(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ReplicateOutput. @@ -322,7 +323,7 @@ type ReplicateModelResponse struct { LicenseURL *string `json:"license_url,omitempty"` // License URL RunCount *int `json:"run_count,omitempty"` // Number of times run CoverImageURL *string `json:"cover_image_url,omitempty"` // Cover image URL - DefaultExample *map[string]interface{} `json:"default_example,omitempty"` // Default example prediction + DefaultExample *json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering) LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details } @@ -332,7 +333,7 @@ type ReplicateModelVersion struct { ID string `json:"id"` // Version ID CreatedAt string `json:"created_at"` // ISO 8601 timestamp CogVersion *string `json:"cog_version,omitempty"` // Cog version used - OpenAPISchema map[string]interface{} `json:"openapi_schema,omitempty"` // OpenAPI schema for the model + OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"` // OpenAPI schema for the model (json.RawMessage preserves key ordering) DockerImageID *string `json:"docker_image_id,omitempty"` // Docker image ID } @@ -409,7 +410,7 @@ type ReplicateWebhookPayload struct { ID string `json:"id"` Model string `json:"model"` Version string `json:"version"` - Input map[string]interface{} `json:"input"` + Input json.RawMessage `json:"input"` Output interface{} `json:"output,omitempty"` Logs *string `json:"logs,omitempty"` Error *string `json:"error,omitempty"` @@ -482,7 +483,7 @@ type ReplicateFileResponse struct { ContentType string `json:"content_type"` // MIME type CreatedAt string `json:"created_at"` // ISO 8601 timestamp ExpiresAt string `json:"expires_at,omitempty"` // ISO 8601 timestamp - Metadata map[string]interface{} `json:"metadata,omitempty"` // User-provided metadata + Metadata json.RawMessage `json:"metadata,omitempty"` // User-provided metadata (json.RawMessage preserves key ordering) Name string `json:"name,omitempty"` // File name Size int64 `json:"size"` // File size in bytes URLs *ReplicateFileURLs `json:"urls,omitempty"` // Associated URLs diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go index 69bb4d51a6..a08d9fbed8 100644 --- a/core/providers/vertex/utils.go +++ b/core/providers/vertex/utils.go @@ -66,6 +66,12 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) } + // Remap unsupported tool versions for Vertex (e.g., web_search_20260209 → web_search_20250305) + jsonBody, err = anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName) + } + // Add anthropic_version if not present if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion) @@ -74,6 +80,13 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s } } } else { + // Validate tools are supported by Vertex + if request.Params != nil && request.Params.Tools != nil { + if toolErr := anthropic.ValidateToolsForProvider(request.Params.Tools, schemas.Vertex); toolErr != nil { + return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil, providerName) + } + } + // Convert request to Anthropic format reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request) if convErr != nil { @@ -90,6 +103,9 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s reqBody.SetStripCacheControlScope(true) + // Add provider-aware beta headers + anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) + // Marshal struct to JSON bytes jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { @@ -104,6 +120,20 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s } } + // Inject beta headers into body as anthropic_beta (Vertex uses body field, not HTTP header) + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { + betaHeaders, betaErr := anthropic.FilterBetaHeadersForProvider(extraHeaders["anthropic-beta"], schemas.Vertex) + if betaErr != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, betaErr, providerName) + } + if len(betaHeaders) > 0 { + jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_beta", betaHeaders) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + } + } + } + if isCountTokens { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens") if err != nil { diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index e803acf673..b2455a606a 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -400,6 +400,8 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } extraParams = reqBody.GetExtraParams() reqBody.Model = deployment + // Add provider-aware beta headers for Vertex + anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) // Marshal to JSON bytes, preserving struct field order rawBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { @@ -412,6 +414,19 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, fmt.Errorf("failed to set anthropic_version: %w", err) } } + // Inject beta headers into body as anthropic_beta (Vertex uses body field, not HTTP header) + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { + betaHeaders, betaErr := anthropic.FilterBetaHeadersForProvider(extraHeaders["anthropic-beta"], schemas.Vertex) + if betaErr != nil { + return nil, fmt.Errorf("unsupported beta header: %w", betaErr) + } + if len(betaHeaders) > 0 { + rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_beta", betaHeaders) + if err != nil { + return nil, fmt.Errorf("failed to set anthropic_beta: %w", err) + } + } + } // Remove model field (it's in URL for Vertex) rawBody, err = providerUtils.DeleteJSONField(rawBody, "model") if err != nil { @@ -468,6 +483,15 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) } + // Remap unsupported tool versions for Vertex (handles raw passthrough bodies) + if schemas.IsAnthropicModel(deployment) && jsonBody != nil { + remappedBody, remapErr := anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex) + if remapErr != nil { + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + } + jsonBody = remappedBody + } + // Auth query is used for fine-tuned models to pass the API key in the query string authQuery := "" // Determine the URL based on model type @@ -745,6 +769,8 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext extraParams = reqBody.GetExtraParams() reqBody.Model = deployment reqBody.Stream = schemas.Ptr(true) + // Add provider-aware beta headers for Vertex + anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) // Marshal to JSON bytes, preserving struct field order for prompt caching rawBody, err := providerUtils.MarshalSorted(reqBody) @@ -759,6 +785,19 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("failed to set anthropic_version: %w", err) } } + // Inject beta headers into body as anthropic_beta (Vertex uses body field, not HTTP header) + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { + betaHeaders, betaErr := anthropic.FilterBetaHeadersForProvider(extraHeaders["anthropic-beta"], schemas.Vertex) + if betaErr != nil { + return nil, fmt.Errorf("unsupported beta header: %w", betaErr) + } + if len(betaHeaders) > 0 { + rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_beta", betaHeaders) + if err != nil { + return nil, fmt.Errorf("failed to set anthropic_beta: %w", err) + } + } + } // Remove model and region fields (using sjson to preserve order) rawBody, err = providerUtils.DeleteJSONField(rawBody, "model") @@ -776,6 +815,15 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, bifrostErr } + // Remap unsupported tool versions for Vertex streaming (handles raw passthrough bodies) + if jsonData != nil { + var remapErr error + jsonData, remapErr = anthropic.RemapRawToolVersionsForProvider(jsonData, schemas.Vertex) + if remapErr != nil { + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + } + } + var completeURL string if region == "global" { completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment) diff --git a/ui/public/images/google-workspace.png b/ui/public/images/google-workspace.png new file mode 100644 index 0000000000..a7443a0eeb Binary files /dev/null and b/ui/public/images/google-workspace.png differ diff --git a/ui/public/images/keycloak.png b/ui/public/images/keycloak.png new file mode 100644 index 0000000000..48e18430cb Binary files /dev/null and b/ui/public/images/keycloak.png differ diff --git a/ui/public/images/sailpoint.png b/ui/public/images/sailpoint.png new file mode 100644 index 0000000000..a3813d93f6 Binary files /dev/null and b/ui/public/images/sailpoint.png differ diff --git a/ui/public/images/zitadel.png b/ui/public/images/zitadel.png new file mode 100644 index 0000000000..b8ede7439c Binary files /dev/null and b/ui/public/images/zitadel.png differ