diff --git a/mcp/tools.go b/mcp/tools.go index 80ae50911..42e888d52 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -653,6 +653,31 @@ func (tis ToolArgumentsSchema) MarshalJSON() ([]byte, error) { return json.Marshal(m) } +// UnmarshalJSON implements the json.Unmarshaler interface for ToolArgumentsSchema. +// It handles both "$defs" (JSON Schema 2019-09+) and "definitions" (JSON Schema draft-07) +// by reading either field and storing it in the Defs field. +func (tis *ToolArgumentsSchema) UnmarshalJSON(data []byte) error { + // Use a temporary type to avoid infinite recursion + type Alias ToolArgumentsSchema + aux := &struct { + Definitions map[string]any `json:"definitions,omitempty"` + *Alias + }{ + Alias: (*Alias)(tis), + } + + if err := json.Unmarshal(data, aux); err != nil { + return err + } + + // If $defs wasn't provided but definitions was, use definitions + if tis.Defs == nil && aux.Definitions != nil { + tis.Defs = aux.Definitions + } + + return nil +} + type ToolAnnotation struct { // Human-readable title for the tool Title string `json:"title,omitempty"` diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 16aeb5dfc..ef472e477 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -1549,3 +1549,136 @@ func TestToolMetaMarshalingOmitsWhenNil(t *testing.T) { // Check that _meta field is not present assert.NotContains(t, result, "_meta", "Tool without Meta should not include _meta field") } + +func TestToolArgumentsSchema_UnmarshalWithDefinitions(t *testing.T) { + // Test that "definitions" (JSON Schema draft-07) is properly unmarshaled into Defs field + jsonData := `{ + "type": "object", + "properties": { + "operation": { + "$ref": "#/definitions/operation_type" + } + }, + "required": ["operation"], + "definitions": { + "operation_type": { + "type": "string", + "enum": ["create", "read", "update", "delete"] + } + } + }` + + var schema ToolArgumentsSchema + err := json.Unmarshal([]byte(jsonData), &schema) + assert.NoError(t, err) + + // Verify the schema was properly unmarshaled + assert.Equal(t, "object", schema.Type) + assert.Contains(t, schema.Properties, "operation") + assert.Equal(t, []string{"operation"}, schema.Required) + + // Most importantly: verify that "definitions" was read into Defs field + assert.NotNil(t, schema.Defs) + assert.Contains(t, schema.Defs, "operation_type") + + operationType, ok := schema.Defs["operation_type"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "string", operationType["type"]) + assert.NotNil(t, operationType["enum"]) +} + +func TestToolArgumentsSchema_UnmarshalWithDefs(t *testing.T) { + // Test that "$defs" (JSON Schema 2019-09+) is properly unmarshaled into Defs field + jsonData := `{ + "type": "object", + "properties": { + "operation": { + "$ref": "#/$defs/operation_type" + } + }, + "required": ["operation"], + "$defs": { + "operation_type": { + "type": "string", + "enum": ["create", "read", "update", "delete"] + } + } + }` + + var schema ToolArgumentsSchema + err := json.Unmarshal([]byte(jsonData), &schema) + assert.NoError(t, err) + + // Verify the schema was properly unmarshaled + assert.Equal(t, "object", schema.Type) + assert.Contains(t, schema.Properties, "operation") + assert.Equal(t, []string{"operation"}, schema.Required) + + // Verify that "$defs" was read into Defs field + assert.NotNil(t, schema.Defs) + assert.Contains(t, schema.Defs, "operation_type") + + operationType, ok := schema.Defs["operation_type"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "string", operationType["type"]) + assert.NotNil(t, operationType["enum"]) +} + +func TestToolArgumentsSchema_UnmarshalPrefersDefs(t *testing.T) { + // Test that if both "$defs" and "definitions" are present, "$defs" takes precedence + jsonData := `{ + "type": "object", + "$defs": { + "from_defs": { + "type": "string" + } + }, + "definitions": { + "from_definitions": { + "type": "integer" + } + } + }` + + var schema ToolArgumentsSchema + err := json.Unmarshal([]byte(jsonData), &schema) + assert.NoError(t, err) + + // $defs should take precedence + assert.Contains(t, schema.Defs, "from_defs") + assert.NotContains(t, schema.Defs, "from_definitions") +} + +func TestToolArgumentsSchema_MarshalRoundTrip(t *testing.T) { + // Test that marshaling and unmarshaling preserves definitions + original := ToolArgumentsSchema{ + Type: "object", + Properties: map[string]any{ + "field": map[string]any{ + "$ref": "#/$defs/my_type", + }, + }, + Required: []string{"field"}, + Defs: map[string]any{ + "my_type": map[string]any{ + "type": "string", + "enum": []string{"a", "b", "c"}, + }, + }, + } + + // Marshal + data, err := json.Marshal(original) + assert.NoError(t, err) + + // Unmarshal + var unmarshaled ToolArgumentsSchema + err = json.Unmarshal(data, &unmarshaled) + assert.NoError(t, err) + + // Verify round-trip + assert.Equal(t, original.Type, unmarshaled.Type) + assert.Equal(t, original.Required, unmarshaled.Required) + assert.NotNil(t, unmarshaled.Defs) + assert.Contains(t, unmarshaled.Defs, "my_type") +}