Skip to content

Commit 43c8148

Browse files
committed
fixes
1 parent 88a4847 commit 43c8148

File tree

8 files changed

+164
-20
lines changed

8 files changed

+164
-20
lines changed

client/client.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,9 +508,10 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra
508508
if contentMap, ok := params.Messages[i].Content.(map[string]any); ok {
509509
// Parse the content map into a proper Content type
510510
content, err := mcp.ParseContent(contentMap)
511-
if err == nil {
512-
params.Messages[i].Content = content
511+
if err != nil {
512+
return nil, fmt.Errorf("failed to parse content for message %d: %w", i, err)
513513
}
514+
params.Messages[i].Content = content
514515
}
515516
}
516517

client/transport/streamable_http.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,13 @@ func (c *StreamableHTTP) IsOAuthEnabled() bool {
594594
func (c *StreamableHTTP) listenForever(ctx context.Context) {
595595
c.logger.Infof("listening to server forever")
596596
for {
597-
// Use the original context for continuous listening - no timeout
597+
// Use the original context for continuous listening - no per-iteration timeout
598+
// The SSE connection itself will detect disconnections via the underlying HTTP transport,
599+
// and the context cancellation will propagate from the parent to stop listening gracefully.
600+
// We don't add an artificial timeout here because:
601+
// 1. Persistent SSE connections are meant to stay open indefinitely
602+
// 2. Network-level timeouts and keep-alives handle connection health
603+
// 3. Context cancellation (user-initiated or system shutdown) provides clean shutdown
598604
err := c.createGETConnectionToServer(ctx)
599605
if errors.Is(err, ErrGetMethodNotAllowed) {
600606
// server does not support listening

mcp/utils.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,12 @@ func ToBoolPtr(b bool) *bool {
945945
// GetTextFromContent extracts text from a Content interface that might be a TextContent struct
946946
// or a map[string]any that was unmarshaled from JSON. This is useful when dealing with content
947947
// that comes from different transport layers that may handle JSON differently.
948+
//
949+
// This function uses fallback behavior for non-text content - it returns a string representation
950+
// via fmt.Sprintf for any content that cannot be extracted as text. This is a lossy operation
951+
// intended for convenience in logging and display scenarios.
952+
//
953+
// For strict type validation, use ParseContent() instead, which returns an error for invalid content.
948954
func GetTextFromContent(content any) string {
949955
switch c := content.(type) {
950956
case TextContent:

server/stdio.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,18 @@ func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
606606
if err := json.Unmarshal(response.Result, &result); err != nil {
607607
samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
608608
} else {
609-
samplingResp.result = &result
609+
// Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent)
610+
if contentMap, ok := result.Content.(map[string]any); ok {
611+
content, err := mcp.ParseContent(contentMap)
612+
if err != nil {
613+
samplingResp.err = fmt.Errorf("failed to parse sampling response content: %w", err)
614+
} else {
615+
result.Content = content
616+
samplingResp.result = &result
617+
}
618+
} else {
619+
samplingResp.result = &result
620+
}
610621
}
611622
}
612623

server/streamable_http.go

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,21 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
235235

236236
// --- internal methods ---
237237

238+
// getOrCreateSession retrieves an existing persistent session or creates a new ephemeral one.
239+
// Persistent sessions are used for bidirectional communication (sampling, elicitation) and are
240+
// stored in activeSessions. Ephemeral sessions are created per-request for stateless operation.
241+
func (s *StreamableHTTPServer) getOrCreateSession(sessionID string) *streamableHttpSession {
242+
// Check if a persistent session exists (created by continuous listening GET connection)
243+
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
244+
if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
245+
return persistentSession
246+
}
247+
}
248+
249+
// Create ephemeral session if no persistent session exists
250+
return newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels)
251+
}
252+
238253
func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
239254
// post request carry request/notification message
240255

@@ -309,18 +324,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
309324
}
310325
}
311326

312-
// Check if a persistent session exists (for sampling support), otherwise create ephemeral session
313-
var session *streamableHttpSession
314-
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
315-
if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
316-
session = persistentSession
317-
}
318-
}
319-
320-
// Create ephemeral session if no persistent session exists
321-
if session == nil {
322-
session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels)
323-
}
327+
// Get or create session (reuses persistent sessions for bidirectional communication)
328+
session := s.getOrCreateSession(sessionID)
324329

325330
// Set the client context before handling the message
326331
ctx := s.server.WithContext(r.Context(), session)
@@ -431,15 +436,16 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
431436
sessionID = uuid.New().String()
432437
}
433438

434-
// Check if session already exists, if so reuse it for sampling
439+
// Check if session already exists (reuse for persistent bidirectional communication)
440+
sessionInterface, sessionExists := s.activeSessions.Load(sessionID)
435441
var session *streamableHttpSession
436-
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
442+
if sessionExists {
437443
if existingSession, ok := sessionInterface.(*streamableHttpSession); ok {
438444
session = existingSession
439445
}
440446
}
441447

442-
// Create new session if none exists
448+
// Create and register new persistent session if none exists
443449
if session == nil {
444450
session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels)
445451
if err := s.server.RegisterSession(r.Context(), session); err != nil {
@@ -448,7 +454,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
448454
}
449455
defer s.server.UnregisterSession(r.Context(), sessionID)
450456

451-
// Register session for sampling response delivery
457+
// Store session for bidirectional communication (sampling/elicitation response delivery)
452458
s.activeSessions.Store(sessionID, session)
453459
defer s.activeSessions.Delete(sessionID)
454460
}
@@ -923,6 +929,17 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
923929
if err := json.Unmarshal(response.result, &result); err != nil {
924930
return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err)
925931
}
932+
933+
// Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent)
934+
// HTTP transport unmarshals Content as map[string]any, we need to convert it to the proper type
935+
if contentMap, ok := result.Content.(map[string]any); ok {
936+
content, err := mcp.ParseContent(contentMap)
937+
if err != nil {
938+
return nil, fmt.Errorf("failed to parse sampling response content: %w", err)
939+
}
940+
result.Content = content
941+
}
942+
926943
return &result, nil
927944
case <-ctx.Done():
928945
return nil, ctx.Err()

www/docs/pages/clients/advanced-sampling.mdx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@ Learn how to implement MCP clients that can handle sampling requests from server
66

77
Sampling allows MCP clients to respond to LLM completion requests from servers. When a server needs to generate content, answer questions, or perform reasoning tasks, it can send a sampling request to the client, which then processes it using an LLM and returns the result.
88

9+
:::danger[Critical Security Requirement]
10+
Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling#user-interaction-model), sampling implementations **SHOULD** always include a human in the loop with the ability to deny sampling requests.
11+
12+
**You MUST implement approval flows that:**
13+
- Present each sampling request to the user for review before execution
14+
- Allow users to view and edit prompts before sending to the LLM
15+
- Display generated responses for user approval before returning to the server
16+
- Provide clear UI to accept or reject requests at each stage
17+
18+
**Without human approval, your implementation:**
19+
- Allows servers to make unauthorized LLM requests without user consent
20+
- May expose sensitive information through unreviewed prompts
21+
- Creates uncontrolled API costs from automated sampling
22+
- Violates user trust and security best practices
23+
24+
The examples below show basic handler implementation. **You must add approval logic** before using in production.
25+
:::
26+
927
## Implementing a Sampling Handler
1028

1129
Create a sampling handler by implementing the `SamplingHandler` interface:

www/docs/pages/servers/advanced-sampling.mdx

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,25 @@ Learn how to implement MCP servers that can request LLM completions from clients
66

77
Sampling allows MCP servers to request LLM completions from clients, enabling bidirectional communication where servers can leverage client-side LLM capabilities. This is particularly useful for tools that need to generate content, answer questions, or perform reasoning tasks.
88

9+
:::info[User Consent Required]
10+
Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling#user-interaction-model), clients **SHOULD** implement human-in-the-loop approval for sampling requests.
11+
12+
When you request sampling from a client:
13+
- The user will typically be prompted to review and approve your request
14+
- The user may modify your prompts before sending to their LLM
15+
- The user may reject your request entirely
16+
- Response times may be longer due to user interaction
17+
18+
**Design your tools accordingly:**
19+
- Provide clear descriptions of why sampling is needed
20+
- Use descriptive system prompts explaining the purpose
21+
- Handle rejection errors gracefully
22+
- Consider timeouts for user approval delays
23+
- Don't assume immediate or automatic approval
24+
25+
Well-designed sampling requests improve user trust and approval rates.
26+
:::
27+
928
## Enabling Sampling
1029

1130
To enable sampling in your server, call `EnableSampling()` during server setup:

www/docs/pages/transports/http.mdx

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,24 @@ The headers are automatically populated by the transport layer and are available
693693

694694
StreamableHTTP transport now supports bidirectional sampling, allowing servers to request LLM completions from clients. This enables advanced scenarios where servers can leverage client-side LLM capabilities.
695695

696+
:::warning[Security: Human-in-the-Loop Required]
697+
Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling), implementations **SHOULD** always include a human in the loop with the ability to deny sampling requests.
698+
699+
**Your sampling handler implementation MUST:**
700+
- Present sampling requests to users for review before execution
701+
- Allow users to view and edit prompts before sending to the LLM
702+
- Display generated responses for approval before returning to the server
703+
- Provide clear UI to accept or reject sampling requests
704+
705+
Failing to implement approval flows creates serious security and trust risks, including:
706+
- Servers making unauthorized LLM requests on behalf of users
707+
- Exposure of sensitive data through unreviewed prompts
708+
- Uncontrolled API costs from automated sampling
709+
- Lack of user consent for AI interactions
710+
711+
See the [example implementation](#example-with-approval-flow) below for a reference approval pattern.
712+
:::
713+
696714
### Requirements for Sampling
697715

698716
To enable sampling with StreamableHTTP transport, the client **must** use the `WithContinuousListening()` option:
@@ -777,6 +795,54 @@ mcpServer.AddTool(mcp.Tool{
777795
- Without continuous listening, the transport operates in stateless request/response mode only
778796
- Network interruptions may require reconnection and re-establishment of the sampling channel
779797

798+
### Example with Approval Flow
799+
800+
Here's a reference implementation showing proper human-in-the-loop approval:
801+
802+
```go
803+
type ApprovalSamplingHandler struct {
804+
llmClient LLMClient // Your actual LLM client
805+
ui UserInterface // Your UI for presenting requests to users
806+
}
807+
808+
func (h *ApprovalSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
809+
// Step 1: Present the sampling request to the user for review
810+
approved, modifiedRequest, err := h.ui.PresentSamplingRequest(ctx, request)
811+
if err != nil {
812+
return nil, fmt.Errorf("failed to get user approval: %w", err)
813+
}
814+
815+
if !approved {
816+
return nil, fmt.Errorf("user rejected sampling request")
817+
}
818+
819+
// Step 2: Send the approved/modified request to the LLM
820+
response, err := h.llmClient.CreateCompletion(ctx, modifiedRequest)
821+
if err != nil {
822+
return nil, fmt.Errorf("LLM request failed: %w", err)
823+
}
824+
825+
// Step 3: Present the response to the user for final approval
826+
approved, modifiedResponse, err := h.ui.PresentSamplingResponse(ctx, response)
827+
if err != nil {
828+
return nil, fmt.Errorf("failed to get response approval: %w", err)
829+
}
830+
831+
if !approved {
832+
return nil, fmt.Errorf("user rejected sampling response")
833+
}
834+
835+
// Step 4: Return the approved response to the server
836+
return modifiedResponse, nil
837+
}
838+
```
839+
840+
**Key Points:**
841+
- Users must explicitly approve both the request (before sending to LLM) and the response (before returning to server)
842+
- Users can modify prompts or responses before approval
843+
- Rejection at any stage returns an error to the server
844+
- The UI should clearly display what the server is requesting and why
845+
780846
## Next Steps
781847

782848
- **[In-Process Transport](/transports/inprocess)** - Learn about embedded scenarios

0 commit comments

Comments
 (0)