Skip to content

Commit 19fe996

Browse files
committed
feat(MCPServer): send log messages to specific client
1 parent 077f546 commit 19fe996

File tree

4 files changed

+282
-7
lines changed

4 files changed

+282
-7
lines changed

mcp/types.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ const (
6767
// MethodNotificationToolsListChanged notifies when the list of available tools changes.
6868
// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/
6969
MethodNotificationToolsListChanged = "notifications/tools/list_changed"
70+
71+
// MethodNotificationMessage notifies when severs send log messages.
72+
// https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#log-message-notifications
73+
MethodNotificationMessage MCPMethod = "notifications/message"
7074
)
7175

7276
type URITemplate struct {
@@ -734,6 +738,30 @@ const (
734738
LoggingLevelEmergency LoggingLevel = "emergency"
735739
)
736740

741+
var (
742+
// Map logging level constants to numerical codes as specified in RFC-5424
743+
levelToSeverity = func() map[LoggingLevel]int {
744+
return map[LoggingLevel]int {
745+
LoggingLevelEmergency: 0,
746+
LoggingLevelAlert: 1,
747+
LoggingLevelCritical: 2,
748+
LoggingLevelError: 3,
749+
LoggingLevelWarning: 4,
750+
LoggingLevelNotice: 5,
751+
LoggingLevelInfo: 6,
752+
LoggingLevelDebug: 7,
753+
}
754+
}()
755+
)
756+
757+
// Allows is a helper function that decides a message could be sent to client or not according to the logging level
758+
func (subscribedLevel LoggingLevel) Allows(currentLevel LoggingLevel) (bool, error) {
759+
if _, ok := levelToSeverity[currentLevel]; !ok {
760+
return false, fmt.Errorf("illegal message logging level:%s", currentLevel)
761+
}
762+
return levelToSeverity[subscribedLevel] >= levelToSeverity[currentLevel], nil
763+
}
764+
737765
/* Sampling */
738766

739767
// CreateMessageRequest is a request from the server to sample an LLM via the

mcp/utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func NewLoggingMessageNotification(
171171
) LoggingMessageNotification {
172172
return LoggingMessageNotification{
173173
Notification: Notification{
174-
Method: "notifications/message",
174+
Method: string(MethodNotificationMessage),
175175
},
176176
Params: struct {
177177
Level LoggingLevel `json:"level"`

server/session.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,69 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error
343343

344344
return nil
345345
}
346+
347+
// SendLogMessageToClient sends a log message to the current client
348+
func(s *MCPServer) SendLogMessageToClient(ctx context.Context, msg mcp.LoggingMessageNotification) error {
349+
if s.capabilities.logging == nil || !(*s.capabilities.logging) {
350+
return fmt.Errorf("server does not support emitting log message notifications")
351+
}
352+
353+
clientSession := ClientSessionFromContext(ctx)
354+
if clientSession == nil || !clientSession.Initialized() {
355+
return ErrSessionNotInitialized
356+
}
357+
358+
logSession, ok := clientSession.(SessionWithLogging)
359+
if !ok {
360+
return ErrSessionDoesNotSupportLogging
361+
}
362+
363+
// Servers send notifications containing severity levels, optional logger names, and arbitrary JSON-serializable data.
364+
// see <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging>
365+
if msg.Params.Level == "" || msg.Params.Data == nil {
366+
return fmt.Errorf("invalid log message without level or data")
367+
}
368+
369+
clientLogLevel := logSession.GetLogLevel()
370+
allowed, err := clientLogLevel.Allows(msg.Params.Level)
371+
if err != nil {
372+
return err
373+
}
374+
if !allowed {
375+
return fmt.Errorf("message level(%s) is lower than client level(%s)", msg.Params.Level, clientLogLevel)
376+
}
377+
378+
notification := mcp.JSONRPCNotification{
379+
JSONRPC: mcp.JSONRPC_VERSION,
380+
Notification: mcp.Notification{
381+
Method: msg.Method,
382+
Params: mcp.NotificationParams{
383+
AdditionalFields: map[string]any{
384+
"level": msg.Params.Level,
385+
"data": msg.Params.Data,
386+
"logger": msg.Params.Logger,
387+
},
388+
},
389+
},
390+
}
391+
392+
select {
393+
case logSession.NotificationChannel() <- notification:
394+
return nil
395+
default:
396+
// Channel is blocked, if there's an error hook, use it
397+
if s.hooks != nil && len(s.hooks.OnError) > 0 {
398+
err := ErrNotificationChannelBlocked
399+
// Copy hooks pointer to local variable to avoid race condition
400+
hooks := s.hooks
401+
go func(sessionID string, hooks *Hooks) {
402+
// Use the error hook to report the blocked channel
403+
hooks.onError(ctx, nil, "notification", map[string]any{
404+
"method": msg.Method,
405+
"sessionID": sessionID,
406+
}, fmt.Errorf("failed to send a log message, notification channel blocked for session %s: %w", sessionID, err))
407+
}(logSession.SessionID(), hooks)
408+
}
409+
return ErrNotificationChannelBlocked
410+
}
411+
}

server/session_test.go

Lines changed: 187 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ type sessionTestClientWithLogging struct {
104104
sessionID string
105105
notificationChannel chan mcp.JSONRPCNotification
106106
initialized bool
107-
loggingLevel atomic.Value
107+
loggingLevel atomic.Value
108108
}
109109

110110
func (f *sessionTestClientWithLogging) SessionID() string {
@@ -136,9 +136,9 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
136136

137137
// Verify that all implementations satisfy their respective interfaces
138138
var (
139-
_ ClientSession = (*sessionTestClient)(nil)
140-
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
141-
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
139+
_ ClientSession = (*sessionTestClient)(nil)
140+
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
141+
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
142142
)
143143

144144
func TestSessionWithTools_Integration(t *testing.T) {
@@ -1039,6 +1039,187 @@ func TestMCPServer_SetLevel(t *testing.T) {
10391039

10401040
// Check logging level
10411041
if session.GetLogLevel() != mcp.LoggingLevelCritical {
1042-
t.Errorf("Expected critical level, got %v", session.GetLogLevel())
1042+
t.Errorf("Expected critical level, got %s", session.GetLogLevel())
10431043
}
1044-
}
1044+
}
1045+
1046+
func TestMCPServer_SendLogMessageToClientDisabled(t *testing.T) {
1047+
// Create server without logging capability
1048+
server := NewMCPServer("test-server", "1.0.0")
1049+
1050+
// Create and initialize a session
1051+
sessionChan := make(chan mcp.JSONRPCNotification, 10)
1052+
session := &sessionTestClientWithLogging{
1053+
sessionID: "session-1",
1054+
notificationChannel: sessionChan,
1055+
}
1056+
session.Initialize()
1057+
1058+
// Mock a request context
1059+
ctx := server.WithContext(context.Background(), session)
1060+
1061+
// Try to send a log message to client when capability is disabled
1062+
require.Error(t, server.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelCritical, "test logger", "test data")))
1063+
}
1064+
1065+
func TestMCPServer_SendLogMessageToClient(t *testing.T) {
1066+
// Prepare a log message
1067+
logMsg := mcp.NewLoggingMessageNotification(
1068+
mcp.LoggingLevelAlert,
1069+
"test logger",
1070+
"test data",
1071+
)
1072+
1073+
tests := []struct {
1074+
name string
1075+
contextPrepare func(context.Context, *MCPServer) context.Context
1076+
validate func(*testing.T, context.Context, *MCPServer)
1077+
}{
1078+
{
1079+
name: "no active session",
1080+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
1081+
return ctx
1082+
},
1083+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
1084+
require.Error(t, srv.SendLogMessageToClient(ctx, logMsg))
1085+
},
1086+
},
1087+
{
1088+
name: "uninit session",
1089+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
1090+
logSession := &sessionTestClientWithLogging{
1091+
sessionID: "test",
1092+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
1093+
initialized: false,
1094+
}
1095+
return srv.WithContext(ctx, logSession)
1096+
},
1097+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
1098+
require.Error(t, srv.SendLogMessageToClient(ctx, logMsg))
1099+
_, ok := ClientSessionFromContext(ctx).(*sessionTestClientWithLogging)
1100+
require.True(t, ok, "session not found or of incorrect type")
1101+
},
1102+
},
1103+
{
1104+
name: "session not supports logging",
1105+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
1106+
logSession := &sessionTestClientWithTools{
1107+
sessionID: "test",
1108+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
1109+
initialized: false,
1110+
}
1111+
return srv.WithContext(ctx, logSession)
1112+
},
1113+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
1114+
require.Error(t, srv.SendLogMessageToClient(ctx, logMsg))
1115+
},
1116+
},
1117+
{
1118+
name: "invalid log messages without level or data",
1119+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
1120+
logSession := &sessionTestClientWithLogging{
1121+
sessionID: "test",
1122+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
1123+
initialized: false,
1124+
}
1125+
logSession.Initialize()
1126+
return srv.WithContext(ctx, logSession)
1127+
},
1128+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
1129+
// Invalid message without level
1130+
require.Error(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification("", "test logger", "test data")))
1131+
// Invalid message with illegal level
1132+
require.Error(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevel("invalid level"), "test logger", "test data")))
1133+
// Invalid message without data
1134+
require.Error(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelCritical, "test logger", nil)))
1135+
},
1136+
},
1137+
{
1138+
name: "active session",
1139+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
1140+
logSession := &sessionTestClientWithLogging{
1141+
sessionID: "test",
1142+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
1143+
initialized: false,
1144+
}
1145+
logSession.Initialize()
1146+
return srv.WithContext(ctx, logSession)
1147+
},
1148+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
1149+
for range 10 {
1150+
require.NoError(t, srv.SendLogMessageToClient(ctx, logMsg))
1151+
}
1152+
session, ok := ClientSessionFromContext(ctx).(*sessionTestClientWithLogging)
1153+
require.True(t, ok, "session not found or of incorrect type")
1154+
for range 10 {
1155+
select {
1156+
case msg := <-session.notificationChannel:
1157+
assert.Equal(t, string(mcp.MethodNotificationMessage), msg.Method)
1158+
assert.Equal(t, mcp.LoggingLevelAlert, msg.Params.AdditionalFields["level"])
1159+
assert.Equal(t, "test logger", msg.Params.AdditionalFields["logger"])
1160+
assert.Equal(t, "test data", msg.Params.AdditionalFields["data"])
1161+
default:
1162+
t.Errorf("log message not sent")
1163+
}
1164+
}
1165+
},
1166+
},
1167+
{
1168+
name: "session with blocked channel",
1169+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
1170+
logSession := &sessionTestClientWithLogging{
1171+
sessionID: "test",
1172+
notificationChannel: make(chan mcp.JSONRPCNotification, 1),
1173+
initialized: false,
1174+
}
1175+
logSession.Initialize()
1176+
return srv.WithContext(ctx, logSession)
1177+
},
1178+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
1179+
require.NoError(t, srv.SendLogMessageToClient(ctx, logMsg))
1180+
require.Error(t, srv.SendLogMessageToClient(ctx, logMsg))
1181+
},
1182+
},
1183+
{
1184+
name: "send log messages of different levels",
1185+
contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context {
1186+
logSession := &sessionTestClientWithLogging{
1187+
sessionID: "test",
1188+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
1189+
initialized: false,
1190+
}
1191+
logSession.Initialize()
1192+
// Set client log level to "Error"
1193+
logSession.SetLogLevel(mcp.LoggingLevelError)
1194+
return srv.WithContext(ctx, logSession)
1195+
},
1196+
validate: func(t *testing.T, ctx context.Context, srv *MCPServer) {
1197+
// Log messages of higher level than client level could be sent
1198+
require.NoError(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelEmergency, "test logger", "")))
1199+
require.NoError(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelAlert, "test logger", "")))
1200+
require.NoError(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelCritical, "test logger", "")))
1201+
require.NoError(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelError, "test logger", "")))
1202+
1203+
// Log messages of lower level than client level could not be sent
1204+
require.Error(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelWarning, "test logger", "")))
1205+
require.Error(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelNotice, "test logger", "")))
1206+
require.Error(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelInfo, "test logger", "")))
1207+
require.Error(t, srv.SendLogMessageToClient(ctx, mcp.NewLoggingMessageNotification(mcp.LoggingLevelDebug, "test logger", "")))
1208+
1209+
logSession, ok := ClientSessionFromContext(ctx).(*sessionTestClientWithLogging)
1210+
require.True(t, ok, "session not found or of incorrect type")
1211+
1212+
// Confirm four log messages were received
1213+
require.Equal(t, len(logSession.notificationChannel), 4)
1214+
},
1215+
},
1216+
}
1217+
for _, tt := range tests {
1218+
t.Run(tt.name, func(t *testing.T) {
1219+
server := NewMCPServer("test-server", "1.0.0", WithLogging())
1220+
ctx := tt.contextPrepare(context.Background(), server)
1221+
1222+
tt.validate(t, ctx, server)
1223+
})
1224+
}
1225+
}

0 commit comments

Comments
 (0)