Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ type WebSearchLocation struct {
type ThinkingUnion struct {
OfEnabled *ThinkingEnabled `json:",omitzero,inline"`
OfDisabled *ThinkingDisabled `json:",omitzero,inline"`
OfAdaptive *ThinkingAdaptive `json:",omitzero,inline"`
}

type ThinkingEnabled struct {
Expand All @@ -887,6 +888,10 @@ type ThinkingDisabled struct {
Type string `json:"type,"`
}

type ThinkingAdaptive struct {
Type string `json:"type,"`
}

// MarshalJSON implements the json.Marshaler interface for ThinkingUnion.
func (t *ThinkingUnion) MarshalJSON() ([]byte, error) {
if t.OfEnabled != nil {
Expand All @@ -895,6 +900,9 @@ func (t *ThinkingUnion) MarshalJSON() ([]byte, error) {
if t.OfDisabled != nil {
return json.Marshal(t.OfDisabled)
}
if t.OfAdaptive != nil {
return json.Marshal(t.OfAdaptive)
}
// If both are nil, return an empty object or an error, depending on your desired behavior.
return []byte(`{}`), nil
}
Expand Down Expand Up @@ -923,6 +931,12 @@ func (t *ThinkingUnion) UnmarshalJSON(data []byte) error {
return err
}
t.OfDisabled = &disabled
case "adaptive":
var adaptive ThinkingAdaptive
if err := json.Unmarshal(data, &adaptive); err != nil {
return err
}
t.OfAdaptive = &adaptive
default:
return fmt.Errorf("invalid thinking union type: %s", typeVal)
}
Expand Down
109 changes: 109 additions & 0 deletions internal/apischema/openai/union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"testing"

"github.com/stretchr/testify/require"

"github.com/envoyproxy/ai-gateway/internal/json"
)

// TestUnmarshalJSONNestedUnion tests the completion API prompt parsing.
Expand Down Expand Up @@ -285,3 +287,110 @@ func TestUnmarshalJSONEmbeddingInput_Errors(t *testing.T) {
})
}
}

func TestThinkingUnion_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
data string
expect ThinkingUnion
}{
{
name: "enabled",
data: `{"type":"enabled","budget_tokens":1024}`,
expect: ThinkingUnion{
OfEnabled: &ThinkingEnabled{Type: "enabled", BudgetTokens: 1024},
},
},
{
name: "disabled",
data: `{"type":"disabled"}`,
expect: ThinkingUnion{
OfDisabled: &ThinkingDisabled{Type: "disabled"},
},
},
{
name: "adaptive",
data: `{"type":"adaptive"}`,
expect: ThinkingUnion{
OfAdaptive: &ThinkingAdaptive{Type: "adaptive"},
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var got ThinkingUnion
err := json.Unmarshal([]byte(tc.data), &got)
require.NoError(t, err)
require.Equal(t, tc.expect, got)
})
}
}

func TestThinkingUnion_UnmarshalJSON_Errors(t *testing.T) {
tests := []struct {
name string
data string
expectedErr string
}{
{
name: "missing type field",
data: `{"budget_tokens":1024}`,
expectedErr: "thinking config does not have a type",
},
{
name: "invalid type value",
data: `{"type":"unknown"}`,
expectedErr: "invalid thinking union type: unknown",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var got ThinkingUnion
err := json.Unmarshal([]byte(tc.data), &got)
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectedErr)
})
}
}

func TestThinkingUnion_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input ThinkingUnion
expect string
}{
{
name: "enabled",
input: ThinkingUnion{
OfEnabled: &ThinkingEnabled{Type: "enabled", BudgetTokens: 1024},
},
expect: `{"budget_tokens":1024,"type":"enabled"}`,
},
{
name: "disabled",
input: ThinkingUnion{
OfDisabled: &ThinkingDisabled{Type: "disabled"},
},
expect: `{"type":"disabled"}`,
},
{
name: "adaptive",
input: ThinkingUnion{
OfAdaptive: &ThinkingAdaptive{Type: "adaptive"},
},
expect: `{"type":"adaptive"}`,
},
{
name: "all nil returns empty object",
input: ThinkingUnion{},
expect: `{}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := json.Marshal(&tc.input)
require.NoError(t, err)
require.JSONEq(t, tc.expect, string(got))
})
}
}
Loading