diff --git a/docs/proposals/004-vendor-specific-fields/proposal.md b/docs/proposals/004-vendor-specific-fields/proposal.md index af224f416e..1b1b66bc0e 100644 --- a/docs/proposals/004-vendor-specific-fields/proposal.md +++ b/docs/proposals/004-vendor-specific-fields/proposal.md @@ -36,23 +36,20 @@ type ChatCompletionRequest struct { // Vendor-specific fields are added as inline fields *GCPVertexAIVendorFields `json:",inline,omitempty"` - *AnthropicVendorFields `json:",inline,omitempty"` } // GCPVertexAIVendorFields contains GCP Vertex AI (Gemini) vendor-specific fields. type GCPVertexAIVendorFields struct { // GenerationConfig holds Gemini generation configuration options. - GenerationConfig *GCPVertexAIGenerationConfig `json:"generationConfig,omitempty"` -} - -// GCPVertexAIGenerationConfig represents Gemini generation configuration options. -type GCPVertexAIGenerationConfig struct { - ThinkingConfig *genai.GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"` -} - -// AnthropicVendorFields contains GCP Anthropic-specific fields. -type AnthropicVendorFields struct { - Thinking *anthropic.ThinkingConfigParamUnion `json:"thinking,omitzero"` + // Currently only a subset of the options are supported. + // + // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig + GenerationConfig *GCPVertexAIGenerationConfig `json:"generationConfig,omitzero"` + + // SafetySettings: Safety settings in the request to block unsafe content in the response. + // + // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/SafetySetting + SafetySettings []*genai.SafetySetting `json:"safetySettings,omitzero"` } ``` diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go index 89d56d5e1e..c2afcb7f06 100644 --- a/internal/apischema/openai/openai.go +++ b/internal/apischema/openai/openai.go @@ -819,6 +819,71 @@ type WebSearchLocation struct { Country string `json:"country,omitempty"` } +// ThinkingConfig contains thinking config for reasoning models +type ThinkingUnion struct { + OfEnabled *ThinkingEnabled `json:",omitzero,inline"` + OfDisabled *ThinkingDisabled `json:",omitzero,inline"` +} + +type ThinkingEnabled struct { + // Determines how many tokens the model can use for its internal reasoning process. + // Larger budgets can enable more thorough analysis for complex problems, improving + // response quality. + BudgetTokens int64 `json:"budget_tokens"` + // This field can be elided, and will marshal its zero value as "enabled". + Type string `json:"type"` + + // Optional. Indicates the thinking budget in tokens. + IncludeThoughts bool `json:"includeThoughts,omitempty"` +} + +type ThinkingDisabled struct { + Type string `json:"type,"` +} + +// MarshalJSON implements the json.Marshaler interface for ThinkingUnion. +func (t *ThinkingUnion) MarshalJSON() ([]byte, error) { + if t.OfEnabled != nil { + return json.Marshal(t.OfEnabled) + } + if t.OfDisabled != nil { + return json.Marshal(t.OfDisabled) + } + // If both are nil, return an empty object or an error, depending on your desired behavior. + return []byte(`{}`), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface for ThinkingUnion. +func (t *ThinkingUnion) UnmarshalJSON(data []byte) error { + // Use a temporary struct to determine the type + typeResult := gjson.GetBytes(data, "type") + if !typeResult.Exists() { + return errors.New("thinking config does not have a type") + } + + // Based on the 'type' field, unmarshal into the correct struct. + typeVal := typeResult.String() + + switch typeVal { + case "enabled": + var enabled ThinkingEnabled + if err := json.Unmarshal(data, &enabled); err != nil { + return err + } + t.OfEnabled = &enabled + case "disabled": + var disabled ThinkingDisabled + if err := json.Unmarshal(data, &disabled); err != nil { + return err + } + t.OfDisabled = &disabled + default: + return fmt.Errorf("invalid thinking union type: %s", typeVal) + } + + return nil +} + type ChatCompletionRequest struct { // Messages: A list of messages comprising the conversation so far. // Depending on the model you use, different message types (modalities) are supported, @@ -982,9 +1047,6 @@ type ChatCompletionRequest struct { // GCPVertexAIVendorFields configures the GCP VertexAI specific fields during schema translation. *GCPVertexAIVendorFields `json:",inline,omitempty"` - // AnthropicVendorFields configures the Anthropic specific fields during schema translation. - *AnthropicVendorFields `json:",inline,omitempty"` - // GuidedChoice: The output will be exactly one of the choices. GuidedChoice []string `json:"guided_choice,omitzero"` @@ -993,6 +1055,9 @@ type ChatCompletionRequest struct { // GuidedJSON: The output will follow the JSON schema. GuidedJSON json.RawMessage `json:"guided_json,omitzero"` + + // Thinking: The thinking config for reasoning models + Thinking *ThinkingUnion `json:"thinking,omitzero"` } type StreamOptions struct { @@ -1578,23 +1643,10 @@ type GCPVertexAIVendorFields struct { // GCPVertexAIGenerationConfig represents Gemini generation configuration options. type GCPVertexAIGenerationConfig struct { - // ThinkingConfig holds Gemini thinking configuration options. - // - // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig#ThinkingConfig - ThinkingConfig *genai.ThinkingConfig `json:"thinkingConfig,omitzero"` - // MediaResolution is to set global media resolution in gemini models: https://ai.google.dev/api/caching#MediaResolution MediaResolution genai.MediaResolution `json:"media_resolution,omitempty"` } -// AnthropicVendorFields contains Anthropic vendor-specific fields. -type AnthropicVendorFields struct { - // Thinking holds Anthropic thinking configuration options. - // - // https://docs.anthropic.com/en/api/messages#body-thinking - Thinking *anthropic.ThinkingConfigParamUnion `json:"thinking,omitzero"` -} - // ReasoningContentUnion content regarding the reasoning that is carried out by the model. // Reasoning refers to a Chain of Thought (CoT) that the model generates to enhance the accuracy of its final response. type ReasoningContentUnion struct { diff --git a/internal/apischema/openai/vendor_fields_test.go b/internal/apischema/openai/vendor_fields_test.go index d12509f2fb..2b5815abe6 100644 --- a/internal/apischema/openai/vendor_fields_test.go +++ b/internal/apischema/openai/vendor_fields_test.go @@ -16,7 +16,6 @@ import ( "github.com/openai/openai-go/v2/packages/param" "github.com/stretchr/testify/require" "google.golang.org/genai" - "k8s.io/utils/ptr" ) func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) { @@ -36,12 +35,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) { "content": "Hello, world!" } ], - "generationConfig": { - "thinkingConfig": { - "includeThoughts": true, - "thinkingBudget": 1000 - } - }, "safetySettings": [{ "category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH" @@ -58,12 +51,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) { }, }, GCPVertexAIVendorFields: &GCPVertexAIVendorFields{ - GenerationConfig: &GCPVertexAIGenerationConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - IncludeThoughts: true, - ThinkingBudget: ptr.To(int32(1000)), - }, - }, SafetySettings: []*genai.SafetySetting{ { Category: genai.HarmCategoryHarassment, @@ -73,55 +60,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) { }, }, }, - { - name: "Request with multiple vendor fields", - jsonData: []byte(`{ - "model": "claude-3", - "messages": [ - { - "role": "user", - "content": "Multiple vendors test" - } - ], - "generationConfig": { - "thinkingConfig": { - "includeThoughts": true, - "thinkingBudget": 1000 - } - }, - "thinking": { - "type": "enabled", - "budget_tokens": 1000 - } - }`), - expected: &ChatCompletionRequest{ - Model: "claude-3", - Messages: []ChatCompletionMessageParamUnion{ - { - OfUser: &ChatCompletionUserMessageParam{ - Role: ChatMessageRoleUser, - Content: StringOrUserRoleContentUnion{Value: "Multiple vendors test"}, - }, - }, - }, - AnthropicVendorFields: &AnthropicVendorFields{ - Thinking: &anthropic.ThinkingConfigParamUnion{ - OfEnabled: &anthropic.ThinkingConfigEnabledParam{ - BudgetTokens: 1000, - Type: "enabled", - }, - }, - }, - GCPVertexAIVendorFields: &GCPVertexAIVendorFields{ - GenerationConfig: &GCPVertexAIGenerationConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - IncludeThoughts: true, - ThinkingBudget: ptr.To(int32(1000)), - }, - }, - }, - }, - }, { name: "Request without vendor fields", jsonData: []byte(`{ @@ -252,45 +190,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) { }, }, }, - { - name: "Request with both detail and thinkingConfig fields", - jsonData: []byte(`{ - "model": "gemini-1.5-pro", - "messages": [ - { - "role": "user", - "content": "Test with both detail and thinking config" - } - ], - "generationConfig": { - "media_resolution": "medium", - "thinkingConfig": { - "includeThoughts": true, - "thinkingBudget": 500 - } - } - }`), - expected: &ChatCompletionRequest{ - Model: "gemini-1.5-pro", - Messages: []ChatCompletionMessageParamUnion{ - { - OfUser: &ChatCompletionUserMessageParam{ - Role: ChatMessageRoleUser, - Content: StringOrUserRoleContentUnion{Value: "Test with both detail and thinking config"}, - }, - }, - }, - GCPVertexAIVendorFields: &GCPVertexAIVendorFields{ - GenerationConfig: &GCPVertexAIGenerationConfig{ - MediaResolution: "medium", - ThinkingConfig: &genai.ThinkingConfig{ - IncludeThoughts: true, - ThinkingBudget: ptr.To(int32(500)), - }, - }, - }, - }, - }, } for _, tt := range tests { diff --git a/internal/translator/openai_awsbedrock.go b/internal/translator/openai_awsbedrock.go index 10ce03fe44..df1c53e947 100644 --- a/internal/translator/openai_awsbedrock.go +++ b/internal/translator/openai_awsbedrock.go @@ -48,6 +48,29 @@ type openAIToAWSBedrockTranslatorV1ChatCompletion struct { activeToolStream bool } +func getAwsBedrockThinkingMap(tu *openai.ThinkingUnion) map[string]any { + if tu == nil { + return nil + } + + resultMap := make(map[string]any) + + if tu.OfEnabled != nil { + reasoningConfigMap := map[string]any{ + "type": "enabled", + "budget_tokens": tu.OfEnabled.BudgetTokens, + } + resultMap["thinking"] = reasoningConfigMap + } else if tu.OfDisabled != nil { + reasoningConfigMap := map[string]any{ + "type": "disabled", + } + resultMap["thinking"] = reasoningConfigMap + } + + return resultMap +} + // RequestBody implements [OpenAIChatCompletionTranslator.RequestBody]. func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, _ bool) ( newHeaders []internalapi.Header, newBody []byte, err error, @@ -83,12 +106,12 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(_ []byte, ope bedrockReq.InferenceConfig.StopSequences = openAIReq.Stop.OfStringArray } - // Handle Anthropic vendor fields if present. Currently only supports thinking fields. - if openAIReq.AnthropicVendorFields != nil && openAIReq.Thinking != nil { + // Handle thinking config + if openAIReq.Thinking != nil { if bedrockReq.AdditionalModelRequestFields == nil { bedrockReq.AdditionalModelRequestFields = make(map[string]interface{}) } - bedrockReq.AdditionalModelRequestFields["thinking"] = openAIReq.Thinking + bedrockReq.AdditionalModelRequestFields = getAwsBedrockThinkingMap(openAIReq.Thinking) } // Convert Chat Completion messages. diff --git a/internal/translator/openai_awsbedrock_test.go b/internal/translator/openai_awsbedrock_test.go index 2d512716d2..a70ef23bff 100644 --- a/internal/translator/openai_awsbedrock_test.go +++ b/internal/translator/openai_awsbedrock_test.go @@ -17,7 +17,6 @@ import ( "testing" "time" - "github.com/anthropics/anthropic-sdk-go" "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -897,11 +896,10 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T) }, }, }, - AnthropicVendorFields: &openai.AnthropicVendorFields{ - Thinking: &anthropic.ThinkingConfigParamUnion{ - OfEnabled: &anthropic.ThinkingConfigEnabledParam{ - BudgetTokens: int64(1024), - }, + Thinking: &openai.ThinkingUnion{ + OfEnabled: &openai.ThinkingEnabled{ + BudgetTokens: int64(1024), + Type: "enabled", }, }, }, @@ -1115,12 +1113,10 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T) }, }, }, - AnthropicVendorFields: &openai.AnthropicVendorFields{ - Thinking: &anthropic.ThinkingConfigParamUnion{ - OfEnabled: &anthropic.ThinkingConfigEnabledParam{ - Type: "enabled", - BudgetTokens: 1024, - }, + Thinking: &openai.ThinkingUnion{ + OfEnabled: &openai.ThinkingEnabled{ + Type: "enabled", + BudgetTokens: 1024, }, }, }, @@ -1149,11 +1145,9 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T) }, }, }, - AnthropicVendorFields: &openai.AnthropicVendorFields{ - Thinking: &anthropic.ThinkingConfigParamUnion{ - OfDisabled: &anthropic.ThinkingConfigDisabledParam{ - Type: "disabled", - }, + Thinking: &openai.ThinkingUnion{ + OfDisabled: &openai.ThinkingDisabled{ + Type: "disabled", }, }, }, diff --git a/internal/translator/openai_gcpanthropic.go b/internal/translator/openai_gcpanthropic.go index 1c3ac4f4d7..60873b0751 100644 --- a/internal/translator/openai_gcpanthropic.go +++ b/internal/translator/openai_gcpanthropic.go @@ -546,6 +546,28 @@ func openAIToAnthropicMessages(openAIMsgs []openai.ChatCompletionMessageParamUni return } +// NewThinkingConfigParamUnion converts a ThinkingUnion into a ThinkingConfigParamUnion. +func getThinkingConfigParamUnion(tu *openai.ThinkingUnion) *anthropic.ThinkingConfigParamUnion { + if tu == nil { + return nil + } + + result := &anthropic.ThinkingConfigParamUnion{} + + if tu.OfEnabled != nil { + result.OfEnabled = &anthropic.ThinkingConfigEnabledParam{ + BudgetTokens: tu.OfEnabled.BudgetTokens, + Type: constant.Enabled(tu.OfEnabled.Type), + } + } else if tu.OfDisabled != nil { + result.OfDisabled = &anthropic.ThinkingConfigDisabledParam{ + Type: constant.Disabled(tu.OfDisabled.Type), + } + } + + return result +} + // buildAnthropicParams is a helper function that translates an OpenAI request // into the parameter struct required by the Anthropic SDK. func buildAnthropicParams(openAIReq *openai.ChatCompletionRequest) (params *anthropic.MessageNewParams, err error) { @@ -595,11 +617,8 @@ func buildAnthropicParams(openAIReq *openai.ChatCompletionRequest) (params *anth // 5. Handle Vendor specific fields. // Since GCPAnthropic follows the Anthropic API, we also check for Anthropic vendor fields. - if openAIReq.AnthropicVendorFields != nil { - anthVendorFields := openAIReq.AnthropicVendorFields - if anthVendorFields.Thinking != nil { - params.Thinking = *anthVendorFields.Thinking - } + if openAIReq.Thinking != nil { + params.Thinking = *getThinkingConfigParamUnion(openAIReq.Thinking) } return params, nil diff --git a/internal/translator/openai_gcpanthropic_test.go b/internal/translator/openai_gcpanthropic_test.go index 4c07994ffe..3044eac5f8 100644 --- a/internal/translator/openai_gcpanthropic_test.go +++ b/internal/translator/openai_gcpanthropic_test.go @@ -312,9 +312,11 @@ func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_RequestBody(t *testing.T Model: claudeTestModel, Messages: []openai.ChatCompletionMessageParamUnion{}, MaxTokens: ptr.To(int64(100)), - AnthropicVendorFields: &openai.AnthropicVendorFields{ - Thinking: &anthropic.ThinkingConfigParamUnion{ - OfEnabled: &anthropic.ThinkingConfigEnabledParam{}, + Thinking: &openai.ThinkingUnion{ + OfEnabled: &openai.ThinkingEnabled{ + BudgetTokens: 100, + Type: "enabled", + IncludeThoughts: true, }, }, } @@ -335,9 +337,9 @@ func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_RequestBody(t *testing.T Model: claudeTestModel, Messages: []openai.ChatCompletionMessageParamUnion{}, MaxTokens: ptr.To(int64(100)), - AnthropicVendorFields: &openai.AnthropicVendorFields{ - Thinking: &anthropic.ThinkingConfigParamUnion{ - OfDisabled: &anthropic.ThinkingConfigDisabledParam{}, + Thinking: &openai.ThinkingUnion{ + OfDisabled: &openai.ThinkingDisabled{ + Type: "disabled", }, }, } diff --git a/internal/translator/openai_gcpvertexai.go b/internal/translator/openai_gcpvertexai.go index ae9d402183..c42def23bb 100644 --- a/internal/translator/openai_gcpvertexai.go +++ b/internal/translator/openai_gcpvertexai.go @@ -391,6 +391,33 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) convertGCPChunkToOpenAI( } } +// NewGenerationConfigThinkingConfig converts a ThinkingUnion to GenerationConfigThinkingConfig. +// It maps the values from the populated field of the union to the target struct. +func getGenerationConfigThinkingConfig(tu *openai.ThinkingUnion) *genai.ThinkingConfig { + if tu == nil { + return nil + } + + result := &genai.ThinkingConfig{} + + if tu.OfEnabled != nil { + + result.IncludeThoughts = tu.OfEnabled.IncludeThoughts + + // Convert int64 to int32, + //nolint:gosec // G115: BudgetTokens is known to be within int32 range. + budget := int32(tu.OfEnabled.BudgetTokens) + result.ThinkingBudget = &budget + } else if tu.OfDisabled != nil { + // If thinking is disabled, the target config should have default values. + // The `omitempty` tags will ensure they aren't marshaled. + result.IncludeThoughts = false + result.ThinkingBudget = nil + } + + return result +} + // openAIMessageToGeminiMessage converts an OpenAI ChatCompletionRequest to a GCP Gemini GenerateContentRequest. func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) openAIMessageToGeminiMessage(openAIReq *openai.ChatCompletionRequest, requestModel internalapi.RequestModel) (*gcp.GenerateContentRequest, error) { // Convert OpenAI messages to Gemini Contents and SystemInstruction. @@ -427,6 +454,9 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) openAIMessageToGeminiMes GenerationConfig: generationConfig, SystemInstruction: systemInstruction, } + if openAIReq.Thinking != nil { + gcr.GenerationConfig.ThinkingConfig = getGenerationConfigThinkingConfig(openAIReq.Thinking) + } // Apply vendor-specific fields after standard OpenAI-to-Gemini translation. // Vendor fields take precedence over translated fields when conflicts occur. @@ -450,9 +480,6 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) applyVendorSpecificField if gcr.GenerationConfig == nil { gcr.GenerationConfig = &genai.GenerationConfig{} } - if vendorGenConfig.ThinkingConfig != nil { - gcr.GenerationConfig.ThinkingConfig = vendorGenConfig.ThinkingConfig - } if vendorGenConfig.MediaResolution != "" && mediaResolutionAvailable(requestModel) { gcr.GenerationConfig.MediaResolution = vendorGenConfig.MediaResolution } diff --git a/internal/translator/openai_gcpvertexai_test.go b/internal/translator/openai_gcpvertexai_test.go index 1e3edb7a93..7a16024c6d 100644 --- a/internal/translator/openai_gcpvertexai_test.go +++ b/internal/translator/openai_gcpvertexai_test.go @@ -548,12 +548,11 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T) }, }, }, - GCPVertexAIVendorFields: &openai.GCPVertexAIVendorFields{ - GenerationConfig: &openai.GCPVertexAIGenerationConfig{ - ThinkingConfig: &genai.ThinkingConfig{ - IncludeThoughts: true, - ThinkingBudget: ptr.To(int32(1000)), - }, + Thinking: &openai.ThinkingUnion{ + OfEnabled: &openai.ThinkingEnabled{ + IncludeThoughts: true, + BudgetTokens: 1000, + Type: "enabled", }, }, },