diff --git a/core/changelog.md b/core/changelog.md index 3431d11bc8..45b09cb9eb 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -2,3 +2,4 @@ - fix: case-insensitive `anthropic-beta` merge in `MergeBetaHeaders` - fix: Bedrock provider - emit message_stop event for Anthropic invoke stream [@tefimov](https://github.com/tefimov) - fix: gemini preserves thinkingLevel parameters during round-trip and finish reason mapping +- fix: WebSearch tool argument handling for all clients by removing the Claude Code user agent restriction diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index 5ec1c17a80..ddc09fd912 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -77,12 +77,28 @@ var anthropicResponsesStreamStatePool = sync.Pool{ }, } -// webSearchItemIDs tracks item IDs for WebSearch tools to skip their argument deltas -// Maps item_id (string) -> true for WebSearch tools that need delta skipping -var webSearchItemIDs sync.Map +// anthropicToResponsesStreamState holds per-request state for the Bifrost→Anthropic +// stream conversion direction. +type anthropicToResponsesStreamState struct { + // webSearchItemIDs tracks item IDs for WebSearch tools so their argument deltas + // can be skipped and regenerated synthetically (with sanitization) at output_item.done. + webSearchItemIDs map[string]bool +} + +type anthropicToResponsesStreamStateKeyType struct{} + +var anthropicToResponsesStreamStateKey = anthropicToResponsesStreamStateKeyType{} -// webFetchItemIDs tracks item IDs for WebFetch tools to skip their argument deltas -var webFetchItemIDs sync.Map +// getOrCreateAnthropicToResponsesStreamState returns the per-request conversion state, +// creating and storing it in ctx on first access. +func getOrCreateAnthropicToResponsesStreamState(ctx *schemas.BifrostContext) *anthropicToResponsesStreamState { + if v := ctx.Value(anthropicToResponsesStreamStateKey); v != nil { + return v.(*anthropicToResponsesStreamState) + } + state := &anthropicToResponsesStreamState{} + ctx.SetValue(anthropicToResponsesStreamStateKey, state) + return state +} // acquireAnthropicResponsesStreamState gets an Anthropic responses stream state from the pool. func acquireAnthropicResponsesStreamState() *AnthropicResponsesStreamState { @@ -1580,10 +1596,15 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp contentBlock.Input = json.RawMessage("{}") // Track WebSearch tools so we can skip their argument deltas + // and regenerate them synthetically (with sanitization) at output_item.done if bifrostResp.Item.ResponsesToolMessage.Name != nil && *bifrostResp.Item.ResponsesToolMessage.Name == "WebSearch" && bifrostResp.Item.ID != nil { - webSearchItemIDs.Store(*bifrostResp.Item.ID, true) + streamState := getOrCreateAnthropicToResponsesStreamState(ctx) + if streamState.webSearchItemIDs == nil { + streamState.webSearchItemIDs = make(map[string]bool) + } + streamState.webSearchItemIDs[*bifrostResp.Item.ID] = true } } } @@ -1691,12 +1712,10 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp } case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta: - // Skip WebSearch/WebFetch tool argument deltas - they will be sent synthetically in output_item.done + // Skip WebSearch tool argument deltas - they will be sent synthetically in output_item.done if bifrostResp.ItemID != nil { - if _, isWebSearch := webSearchItemIDs.Load(*bifrostResp.ItemID); isWebSearch { - return nil - } - if _, isWebFetch := webFetchItemIDs.Load(*bifrostResp.ItemID); isWebFetch { + streamState := getOrCreateAnthropicToResponsesStreamState(ctx) + if streamState.webSearchItemIDs[*bifrostResp.ItemID] { return nil } } @@ -1768,52 +1787,46 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp case schemas.ResponsesStreamResponseTypeOutputItemDone: // Handle WebSearch tool completion with sanitization and synthetic delta generation + if bifrostResp.Item != nil && + bifrostResp.Item.Type != nil && + *bifrostResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall && + bifrostResp.Item.ResponsesToolMessage != nil && + bifrostResp.Item.ResponsesToolMessage.Name != nil && + *bifrostResp.Item.ResponsesToolMessage.Name == "WebSearch" && + bifrostResp.Item.ResponsesToolMessage.Arguments != nil { - // check for claude-cli user agent - if ctx != nil { - if IsClaudeCodeRequest(ctx) { - // check for WebSearch tool - if bifrostResp.Item != nil && - bifrostResp.Item.Type != nil && - *bifrostResp.Item.Type == schemas.ResponsesMessageTypeFunctionCall && - bifrostResp.Item.ResponsesToolMessage != nil && - bifrostResp.Item.ResponsesToolMessage.Name != nil && - *bifrostResp.Item.ResponsesToolMessage.Name == "WebSearch" && - bifrostResp.Item.ResponsesToolMessage.Arguments != nil { - - argumentsJSON := sanitizeWebSearchArguments(*bifrostResp.Item.ResponsesToolMessage.Arguments) - bifrostResp.Item.ResponsesToolMessage.Arguments = &argumentsJSON - - // Generate synthetic input_json_delta events for the sanitized WebSearch arguments - // This replaces the delta events that were skipped earlier - var events []*AnthropicStreamEvent - - // Use OutputIndex for proper Anthropic indexing, fallback to ContentIndex - var indexToUse *int - if bifrostResp.OutputIndex != nil { - indexToUse = bifrostResp.OutputIndex - } else if bifrostResp.ContentIndex != nil { - indexToUse = bifrostResp.ContentIndex - } + argumentsJSON := sanitizeWebSearchArguments(*bifrostResp.Item.ResponsesToolMessage.Arguments) + bifrostResp.Item.ResponsesToolMessage.Arguments = &argumentsJSON + + // Generate synthetic input_json_delta events for the sanitized WebSearch arguments + // This replaces the delta events that were skipped earlier + var events []*AnthropicStreamEvent - deltaEvents := generateSyntheticInputJSONDeltas(argumentsJSON, indexToUse) - events = append(events, deltaEvents...) + // Use OutputIndex for proper Anthropic indexing, fallback to ContentIndex + var indexToUse *int + if bifrostResp.OutputIndex != nil { + indexToUse = bifrostResp.OutputIndex + } else if bifrostResp.ContentIndex != nil { + indexToUse = bifrostResp.ContentIndex + } - // Add the content_block_stop event at the end - stopEvent := &AnthropicStreamEvent{ - Type: AnthropicStreamEventTypeContentBlockStop, - Index: indexToUse, - } - events = append(events, stopEvent) + deltaEvents := generateSyntheticInputJSONDeltas(argumentsJSON, indexToUse) + events = append(events, deltaEvents...) - // Clean up the tracking for this WebSearch item - if bifrostResp.Item.ID != nil { - webSearchItemIDs.Delete(*bifrostResp.Item.ID) - } + // Add the content_block_stop event at the end + stopEvent := &AnthropicStreamEvent{ + Type: AnthropicStreamEventTypeContentBlockStop, + Index: indexToUse, + } + events = append(events, stopEvent) - return events - } + // Clean up the tracking for this WebSearch item + if bifrostResp.Item.ID != nil { + streamState := getOrCreateAnthropicToResponsesStreamState(ctx) + delete(streamState.webSearchItemIDs, *bifrostResp.Item.ID) } + + return events } if bifrostResp.Item != nil && diff --git a/core/providers/anthropic/websearch_test.go b/core/providers/anthropic/websearch_test.go new file mode 100644 index 0000000000..48c4214c40 --- /dev/null +++ b/core/providers/anthropic/websearch_test.go @@ -0,0 +1,270 @@ +package anthropic + +import ( + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +// TestWebSearch_OutputItemAdded_StoresID verifies that a WebSearch function_call +// output_item.added event stores the item ID in the per-request stream state so that +// subsequent argument deltas can be skipped. +func TestWebSearch_OutputItemAdded_StoresID(t *testing.T) { + t.Parallel() + + const itemID = "toolu_ws_storesid_test" + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + + bifrostResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + OutputIndex: schemas.Ptr(0), + Item: &schemas.ResponsesMessage{ + ID: schemas.Ptr(itemID), + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(itemID), + Name: schemas.Ptr("WebSearch"), + Arguments: schemas.Ptr(""), + }, + }, + } + + events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp) + + // Should emit content_block_start + if len(events) == 0 { + t.Fatal("expected at least one event") + } + if events[0].Type != AnthropicStreamEventTypeContentBlockStart { + t.Errorf("event[0].Type = %v, want content_block_start", events[0].Type) + } + if events[0].ContentBlock == nil || events[0].ContentBlock.Input == nil { + t.Fatal("expected ContentBlock with Input") + } + if string(events[0].ContentBlock.Input) != "{}" { + t.Errorf("ContentBlock.Input = %s, want {}", events[0].ContentBlock.Input) + } + + // ID must now be tracked in per-request state + state := getOrCreateAnthropicToResponsesStreamState(ctx) + if !state.webSearchItemIDs[itemID] { + t.Error("expected item ID to be stored in per-request stream state after output_item.added") + } +} + +// TestWebSearch_FunctionCallArgumentsDelta_Skipped verifies that argument deltas +// for a tracked WebSearch item are skipped (returning nil) regardless of the +// user agent — the fix for the original bug where non-Claude Code clients lost +// the query. +func TestWebSearch_FunctionCallArgumentsDelta_Skipped(t *testing.T) { + t.Parallel() + + const itemID = "toolu_ws_skip_test" + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + + // Pre-seed per-request state as if output_item.added already fired + state := getOrCreateAnthropicToResponsesStreamState(ctx) + state.webSearchItemIDs = map[string]bool{itemID: true} + + partial := `{"query": "world news"` + bifrostResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta, + OutputIndex: schemas.Ptr(0), + ItemID: schemas.Ptr(itemID), + Delta: &partial, + } + + events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp) + + if len(events) != 0 { + t.Errorf("expected deltas to be skipped (0 events), got %d", len(events)) + } +} + +// TestWebSearch_OutputItemDone_GeneratesSyntheticDeltas verifies that when +// output_item.done fires for a tracked WebSearch item, synthetic input_json_delta +// events carrying the full query are emitted, followed by content_block_stop. +// This applies for ALL clients regardless of user agent. +func TestWebSearch_OutputItemDone_GeneratesSyntheticDeltas(t *testing.T) { + t.Parallel() + + const itemID = "toolu_ws_synth_test" + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + + // Pre-seed per-request state as if output_item.added already fired + state := getOrCreateAnthropicToResponsesStreamState(ctx) + state.webSearchItemIDs = map[string]bool{itemID: true} + + query := `{"query":"world news today"}` + bifrostResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + OutputIndex: schemas.Ptr(1), + Item: &schemas.ResponsesMessage{ + ID: schemas.Ptr(itemID), + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(itemID), + Name: schemas.Ptr("WebSearch"), + Arguments: &query, + }, + }, + } + + events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp) + + // Must have at least one input_json_delta and a final content_block_stop + if len(events) < 2 { + t.Fatalf("expected at least 2 events (deltas + stop), got %d", len(events)) + } + + // All events except last must be input_json_delta + for i, ev := range events[:len(events)-1] { + if ev.Type != AnthropicStreamEventTypeContentBlockDelta { + t.Errorf("event[%d].Type = %v, want content_block_delta", i, ev.Type) + continue + } + if ev.Delta == nil || ev.Delta.Type != AnthropicStreamDeltaTypeInputJSON { + t.Errorf("event[%d].Delta.Type = %v, want input_json", i, ev.Delta) + } + } + + // Last event must be content_block_stop + last := events[len(events)-1] + if last.Type != AnthropicStreamEventTypeContentBlockStop { + t.Errorf("last event.Type = %v, want content_block_stop", last.Type) + } + + // Reconstruct the accumulated JSON from the deltas + var accumulated string + for _, ev := range events[:len(events)-1] { + if ev.Delta != nil && ev.Delta.PartialJSON != nil { + accumulated += *ev.Delta.PartialJSON + } + } + var got map[string]interface{} + if err := json.Unmarshal([]byte(accumulated), &got); err != nil { + t.Fatalf("accumulated JSON invalid: %v — got %q", err, accumulated) + } + if got["query"] != "world news today" { + t.Errorf("query = %v, want %q", got["query"], "world news today") + } + + // ID must have been cleaned up from per-request state + if state.webSearchItemIDs[itemID] { + t.Error("expected item ID to be removed from per-request stream state after output_item.done") + } +} + +// TestWebSearch_FullFlow_AnyUserAgent is the regression test for the original bug. +// It simulates the complete streaming sequence: +// +// output_item.added → FunctionCallArgumentsDelta (×N) → output_item.done +// +// and verifies that the client-facing Anthropic stream contains proper +// input_json_delta events with the query, regardless of user agent. +func TestWebSearch_FullFlow_AnyUserAgent(t *testing.T) { + t.Parallel() + + const itemID = "toolu_ws_fullflow_test" + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + + var allEvents []*AnthropicStreamEvent + + // Step 1: output_item.added + addedResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + OutputIndex: schemas.Ptr(0), + Item: &schemas.ResponsesMessage{ + ID: schemas.Ptr(itemID), + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(itemID), + Name: schemas.Ptr("WebSearch"), + Arguments: schemas.Ptr(""), + }, + }, + } + allEvents = append(allEvents, ToAnthropicResponsesStreamResponse(ctx, addedResp)...) + + // Step 2: FunctionCallArgumentsDelta events (should be skipped) + for _, partial := range []string{`{"query": "`, `latest AI`, `news"}`} { + p := partial + deltaResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta, + OutputIndex: schemas.Ptr(0), + ItemID: schemas.Ptr(itemID), + Delta: &p, + } + allEvents = append(allEvents, ToAnthropicResponsesStreamResponse(ctx, deltaResp)...) + } + + // Step 3: output_item.done with full accumulated arguments + fullArgs := `{"query":"latest AI news"}` + doneResp := &schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + OutputIndex: schemas.Ptr(0), + Item: &schemas.ResponsesMessage{ + ID: schemas.Ptr(itemID), + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(itemID), + Name: schemas.Ptr("WebSearch"), + Arguments: &fullArgs, + }, + }, + } + allEvents = append(allEvents, ToAnthropicResponsesStreamResponse(ctx, doneResp)...) + + // Verify the sequence: + // [0] content_block_start (input:{}) + // [1..N-1] input_json_delta events + // [N] content_block_stop + if len(allEvents) < 3 { + t.Fatalf("expected at least 3 events, got %d: %v", len(allEvents), allEvents) + } + + // First event: content_block_start with empty input + if allEvents[0].Type != AnthropicStreamEventTypeContentBlockStart { + t.Errorf("allEvents[0].Type = %v, want content_block_start", allEvents[0].Type) + } + + // Last event: content_block_stop + last := allEvents[len(allEvents)-1] + if last.Type != AnthropicStreamEventTypeContentBlockStop { + t.Errorf("last event.Type = %v, want content_block_stop", last.Type) + } + + // Middle events: all input_json_delta + for i, ev := range allEvents[1 : len(allEvents)-1] { + if ev.Type != AnthropicStreamEventTypeContentBlockDelta { + t.Errorf("allEvents[%d].Type = %v, want content_block_delta", i+1, ev.Type) + } + if ev.Delta == nil || ev.Delta.Type != AnthropicStreamDeltaTypeInputJSON { + t.Errorf("allEvents[%d].Delta.Type = %v, want input_json", i+1, ev.Delta) + } + } + + // Reconstruct query from synthetic deltas + var accumulated string + for _, ev := range allEvents[1 : len(allEvents)-1] { + if ev.Delta != nil && ev.Delta.PartialJSON != nil { + accumulated += *ev.Delta.PartialJSON + } + } + var got map[string]interface{} + if err := json.Unmarshal([]byte(accumulated), &got); err != nil { + t.Fatalf("reconstructed JSON is invalid: %v — got %q", err, accumulated) + } + if got["query"] != "latest AI news" { + t.Errorf("reconstructed query = %v, want %q", got["query"], "latest AI news") + } +} diff --git a/transports/changelog.md b/transports/changelog.md index 4739b360ba..321cdde017 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -1,3 +1,4 @@ - fix: case-insensitive `anthropic-beta` merge in `MergeBetaHeaders` - fix: Bedrock integration - update to use InvokeModelRawChunks for multi-event support [@tefimov](https://github.com/tefimov) - fix: gemini preserves thinkingLevel parameters during round-trip and finish reason mapping +- fix: WebSearch tool argument handling for all clients by removing the Claude Code user agent restriction