diff --git a/client/sse.go b/client/sse.go index 1c0d62c4..788aff82 100644 --- a/client/sse.go +++ b/client/sse.go @@ -462,12 +462,7 @@ func (c *SSEMCPClient) GetPrompt( return nil, err } - var result mcp.GetPromptResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil + return mcp.ParseGetPromptResult(response) } func (c *SSEMCPClient) ListTools( @@ -496,12 +491,7 @@ func (c *SSEMCPClient) CallTool( return nil, err } - var result mcp.CallToolResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil + return mcp.ParseCallToolResult(response) } func (c *SSEMCPClient) SetLevel( diff --git a/client/sse_test.go b/client/sse_test.go index 81f0d057..6bff34f3 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -25,7 +25,7 @@ func TestSSEMCPClient(t *testing.T) { mcp.WithString("parameter-1", mcp.Description("A string tool parameter")), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ - Content: []interface{}{ + Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string), diff --git a/client/stdio.go b/client/stdio.go index 02f16469..da9f0025 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -388,12 +388,7 @@ func (c *StdioMCPClient) GetPrompt( return nil, err } - var result mcp.GetPromptResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil + return mcp.ParseGetPromptResult(response) } func (c *StdioMCPClient) ListTools( @@ -422,12 +417,7 @@ func (c *StdioMCPClient) CallTool( return nil, err } - var result mcp.CallToolResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil + return mcp.ParseCallToolResult(response) } func (c *StdioMCPClient) SetLevel( diff --git a/examples/everything/main.go b/examples/everything/main.go index 3634d2f7..7639e488 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -280,7 +280,7 @@ func (s *MCPServer) handleEchoTool( return nil, fmt.Errorf("invalid message argument") } return &mcp.CallToolResult{ - Content: []interface{}{ + Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: fmt.Sprintf("Echo: %s", message), @@ -301,7 +301,7 @@ func (s *MCPServer) handleAddTool( } sum := a + b return &mcp.CallToolResult{ - Content: []interface{}{ + Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: fmt.Sprintf("The sum of %f and %f is %f.", a, b, sum), @@ -330,7 +330,7 @@ func (s *MCPServer) handleSendNotification( } return &mcp.CallToolResult{ - Content: []interface{}{ + Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: "notification sent successfully", @@ -373,7 +373,7 @@ func (s *MCPServer) handleLongRunningOperationTool( } return &mcp.CallToolResult{ - Content: []interface{}{ + Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: fmt.Sprintf( @@ -412,7 +412,7 @@ func (s *MCPServer) handleGetTinyImageTool( request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ - Content: []interface{}{ + Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: "This is a tiny image:", diff --git a/mcp/prompts.go b/mcp/prompts.go index bf40839b..bc12a729 100644 --- a/mcp/prompts.go +++ b/mcp/prompts.go @@ -77,18 +77,8 @@ const ( // This is similar to `SamplingMessage`, but also supports the embedding of // resources from the MCP server. type PromptMessage struct { - Role Role `json:"role"` - Content interface{} `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource -} - -// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. -// -// It is up to the client how best to render embedded resources for the -// benefit of the LLM and/or the user. -type EmbeddedResource struct { - Annotated - Type string `json:"type"` - Resource ResourceContents `json:"resource"` + Role Role `json:"role"` + Content Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource } // PromptListChangedNotification is an optional notification from the server diff --git a/mcp/tools.go b/mcp/tools.go index f80d6739..6edb2d4f 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -25,7 +25,7 @@ type ListToolsResult struct { // should be reported as an MCP error response. type CallToolResult struct { Result - Content []interface{} `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource + Content []Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). diff --git a/mcp/types.go b/mcp/types.go index aabf0243..771e9878 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -585,6 +585,10 @@ type Annotated struct { } `json:"annotations,omitempty"` } +type Content interface { + isContent() +} + // TextContent represents text provided to or from an LLM. // It must have Type set to "text". type TextContent struct { @@ -594,6 +598,8 @@ type TextContent struct { Text string `json:"text"` } +func (TextContent) isContent() {} + // ImageContent represents an image provided to or from an LLM. // It must have Type set to "image". type ImageContent struct { @@ -605,6 +611,20 @@ type ImageContent struct { MIMEType string `json:"mimeType"` } +func (ImageContent) isContent() {} + +// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. +// +// It is up to the client how best to render embedded resources for the +// benefit of the LLM and/or the user. +type EmbeddedResource struct { + Annotated + Type string `json:"type"` + Resource ResourceContents `json:"resource"` +} + +func (EmbeddedResource) isContent() {} + // ModelPreferences represents the server's preferences for model selection, // requested of the client during sampling. // diff --git a/mcp/utils.go b/mcp/utils.go index 07afbb83..c425b9df 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -1,6 +1,9 @@ package mcp -import "fmt" +import ( + "encoding/json" + "fmt" +) // ClientRequest types var _ ClientRequest = &PingRequest{} @@ -180,7 +183,7 @@ func NewLoggingMessageNotification( } // Helper function to create a new PromptMessage -func NewPromptMessage(role Role, content interface{}) PromptMessage { +func NewPromptMessage(role Role, content Content) PromptMessage { return PromptMessage{ Role: role, Content: content, @@ -215,7 +218,7 @@ func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { // NewToolResultText creates a new CallToolResult with a text content func NewToolResultText(text string) *CallToolResult { return &CallToolResult{ - Content: []interface{}{ + Content: []Content{ TextContent{ Type: "text", Text: text, @@ -227,7 +230,7 @@ func NewToolResultText(text string) *CallToolResult { // NewToolResultError creates a new CallToolResult that indicates an error func NewToolResultError(errText string) *CallToolResult { return &CallToolResult{ - Content: []interface{}{ + Content: []Content{ TextContent{ Type: "text", Text: errText, @@ -240,7 +243,7 @@ func NewToolResultError(errText string) *CallToolResult { // NewToolResultImage creates a new CallToolResult with both text and image content func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { return &CallToolResult{ - Content: []interface{}{ + Content: []Content{ TextContent{ Type: "text", Text: text, @@ -260,7 +263,7 @@ func NewToolResultResource( resource ResourceContents, ) *CallToolResult { return &CallToolResult{ - Content: []interface{}{ + Content: []Content{ TextContent{ Type: "text", Text: text, @@ -304,8 +307,7 @@ func NewReadResourceResult(text string) *ReadResourceResult { return &ReadResourceResult{ Contents: []interface{}{ TextResourceContents{ - ResourceContents: ResourceContents{}, - Text: text, + Text: text, }, }, } @@ -364,3 +366,180 @@ func NewInitializeResult( func FormatNumberResult(value float64) *CallToolResult { return NewToolResultText(fmt.Sprintf("%.2f", value)) } + +func ExtractString(data map[string]any, key string) string { + if value, ok := data[key]; ok { + if str, ok := value.(string); ok { + return str + } + } + return "" +} + +func ExtractMap(data map[string]any, key string) map[string]any { + if value, ok := data[key]; ok { + if m, ok := value.(map[string]any); ok { + return m + } + } + return nil +} + +func ParseContent(contentMap map[string]any) (Content, error) { + contentType := ExtractString(contentMap, "type") + + switch contentType { + case "text": + text := ExtractString(contentMap, "text") + if text == "" { + return nil, fmt.Errorf("text is missing") + } + return NewTextContent(text), nil + + case "image": + data := ExtractString(contentMap, "data") + mimeType := ExtractString(contentMap, "mimeType") + if data == "" || mimeType == "" { + return nil, fmt.Errorf("image data or mimeType is missing") + } + return NewImageContent(data, mimeType), nil + + case "resource": + resourceMap := ExtractMap(contentMap, "resource") + if resourceMap == nil { + return nil, fmt.Errorf("resource is missing") + } + + uri := ExtractString(resourceMap, "uri") + mimeType := ExtractString(resourceMap, "mimeType") + text := ExtractString(resourceMap, "text") + + if uri == "" || mimeType == "" { + return nil, fmt.Errorf("resource uri or mimeType is missing") + } + + if text != "" { + return NewEmbeddedResource( + ResourceContents{ + URI: uri, + MIMEType: mimeType, + }, + ), nil + } + } + + return nil, fmt.Errorf("unsupported content type: %s", contentType) +} + +func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) { + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + result := GetPromptResult{} + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = metaMap + } + } + + description, ok := jsonContent["description"] + if ok { + if descriptionStr, ok := description.(string); ok { + result.Description = descriptionStr + } + } + + messages, ok := jsonContent["messages"] + if ok { + messagesArr, ok := messages.([]any) + if !ok { + return nil, fmt.Errorf("messages is not an array") + } + + for _, message := range messagesArr { + messageMap, ok := message.(map[string]any) + if !ok { + return nil, fmt.Errorf("message is not an object") + } + + // Extract role + roleStr := ExtractString(messageMap, "role") + if roleStr == "" || (roleStr != string(RoleAssistant) && roleStr != string(RoleUser)) { + return nil, fmt.Errorf("unsupported role: %s", roleStr) + } + + // Extract content + contentMap, ok := messageMap["content"].(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + // Append processed message + result.Messages = append(result.Messages, NewPromptMessage(Role(roleStr), content)) + + } + } + + return &result, nil +} + +func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result CallToolResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = metaMap + } + } + + isError, ok := jsonContent["isError"] + if ok { + if isErrorBool, ok := isError.(bool); ok { + result.IsError = isErrorBool + } + } + + contents, ok := jsonContent["content"] + if !ok { + return nil, fmt.Errorf("content is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("content is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + result.Content = append(result.Content, content) + } + + return &result, nil +} diff --git a/server/server_test.go b/server/server_test.go index ff2bf299..7830f121 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -3,10 +3,11 @@ package server import ( "context" "encoding/json" - "github.com/mark3labs/mcp-go/mcp" - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" ) func TestMCPServer_NewMCPServer(t *testing.T) { @@ -693,7 +694,7 @@ func createTestServer() *MCPServer { }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ - Content: []interface{}{ + Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: "test result",