diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go index e2ede0993d..de1e21284f 100644 --- a/internal/apischema/openai/openai.go +++ b/internal/apischema/openai/openai.go @@ -869,6 +869,7 @@ type WebSearchLocation struct { type ThinkingUnion struct { OfEnabled *ThinkingEnabled `json:",omitzero,inline"` OfDisabled *ThinkingDisabled `json:",omitzero,inline"` + OfAdaptive *ThinkingAdaptive `json:",omitzero,inline"` } type ThinkingEnabled struct { @@ -887,6 +888,10 @@ type ThinkingDisabled struct { Type string `json:"type,"` } +type ThinkingAdaptive struct { + Type string `json:"type,"` +} + // MarshalJSON implements the json.Marshaler interface for ThinkingUnion. func (t *ThinkingUnion) MarshalJSON() ([]byte, error) { if t.OfEnabled != nil { @@ -895,6 +900,9 @@ func (t *ThinkingUnion) MarshalJSON() ([]byte, error) { if t.OfDisabled != nil { return json.Marshal(t.OfDisabled) } + if t.OfAdaptive != nil { + return json.Marshal(t.OfAdaptive) + } // If both are nil, return an empty object or an error, depending on your desired behavior. return []byte(`{}`), nil } @@ -923,6 +931,12 @@ func (t *ThinkingUnion) UnmarshalJSON(data []byte) error { return err } t.OfDisabled = &disabled + case "adaptive": + var adaptive ThinkingAdaptive + if err := json.Unmarshal(data, &adaptive); err != nil { + return err + } + t.OfAdaptive = &adaptive default: return fmt.Errorf("invalid thinking union type: %s", typeVal) } diff --git a/internal/apischema/openai/union_test.go b/internal/apischema/openai/union_test.go index 83e8283687..1b37277ff0 100644 --- a/internal/apischema/openai/union_test.go +++ b/internal/apischema/openai/union_test.go @@ -9,6 +9,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/json" ) // TestUnmarshalJSONNestedUnion tests the completion API prompt parsing. @@ -285,3 +287,110 @@ func TestUnmarshalJSONEmbeddingInput_Errors(t *testing.T) { }) } } + +func TestThinkingUnion_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + data string + expect ThinkingUnion + }{ + { + name: "enabled", + data: `{"type":"enabled","budget_tokens":1024}`, + expect: ThinkingUnion{ + OfEnabled: &ThinkingEnabled{Type: "enabled", BudgetTokens: 1024}, + }, + }, + { + name: "disabled", + data: `{"type":"disabled"}`, + expect: ThinkingUnion{ + OfDisabled: &ThinkingDisabled{Type: "disabled"}, + }, + }, + { + name: "adaptive", + data: `{"type":"adaptive"}`, + expect: ThinkingUnion{ + OfAdaptive: &ThinkingAdaptive{Type: "adaptive"}, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var got ThinkingUnion + err := json.Unmarshal([]byte(tc.data), &got) + require.NoError(t, err) + require.Equal(t, tc.expect, got) + }) + } +} + +func TestThinkingUnion_UnmarshalJSON_Errors(t *testing.T) { + tests := []struct { + name string + data string + expectedErr string + }{ + { + name: "missing type field", + data: `{"budget_tokens":1024}`, + expectedErr: "thinking config does not have a type", + }, + { + name: "invalid type value", + data: `{"type":"unknown"}`, + expectedErr: "invalid thinking union type: unknown", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var got ThinkingUnion + err := json.Unmarshal([]byte(tc.data), &got) + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectedErr) + }) + } +} + +func TestThinkingUnion_MarshalJSON(t *testing.T) { + tests := []struct { + name string + input ThinkingUnion + expect string + }{ + { + name: "enabled", + input: ThinkingUnion{ + OfEnabled: &ThinkingEnabled{Type: "enabled", BudgetTokens: 1024}, + }, + expect: `{"budget_tokens":1024,"type":"enabled"}`, + }, + { + name: "disabled", + input: ThinkingUnion{ + OfDisabled: &ThinkingDisabled{Type: "disabled"}, + }, + expect: `{"type":"disabled"}`, + }, + { + name: "adaptive", + input: ThinkingUnion{ + OfAdaptive: &ThinkingAdaptive{Type: "adaptive"}, + }, + expect: `{"type":"adaptive"}`, + }, + { + name: "all nil returns empty object", + input: ThinkingUnion{}, + expect: `{}`, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := json.Marshal(&tc.input) + require.NoError(t, err) + require.JSONEq(t, tc.expect, string(got)) + }) + } +}