Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 1 addition & 34 deletions internal/server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
)

// apiRouter creates a router that represents the routes under /api
Expand Down Expand Up @@ -64,17 +63,6 @@ func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
span.SetStatus(codes.Error, err.Error())
}
span.End()

status := "success"
if err != nil {
status = "error"
}
s.instrumentation.ToolsetGet.Add(
r.Context(),
1,
metric.WithAttributes(attribute.String("toolbox.name", toolsetName)),
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
)
}()

toolset, ok := s.ResourceMgr.GetToolset(toolsetName)
Expand All @@ -101,18 +89,8 @@ func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
span.SetStatus(codes.Error, err.Error())
}
span.End()

status := "success"
if err != nil {
status = "error"
}
s.instrumentation.ToolGet.Add(
r.Context(),
1,
metric.WithAttributes(attribute.String("toolbox.name", toolName)),
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
)
}()

tool, ok := s.ResourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
Expand Down Expand Up @@ -146,17 +124,6 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
span.SetStatus(codes.Error, err.Error())
}
span.End()

status := "success"
if err != nil {
status = "error"
}
s.instrumentation.ToolInvoke.Add(
r.Context(),
1,
metric.WithAttributes(attribute.String("toolbox.name", toolName)),
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
)
}()

tool, ok := s.ResourceMgr.GetTool(toolName)
Expand Down
142 changes: 113 additions & 29 deletions internal/server/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,45 @@ func (s *stdioSession) Start(ctx context.Context) error {

// readInputStream reads requests/notifications from MCP clients through stdin
func (s *stdioSession) readInputStream(ctx context.Context) error {
sessionStart := time.Now()

// Define attributes for session metrics
// Note: mcp.protocol.version is added dynamically after protocol negotiation
sessionAttrs := []attribute.KeyValue{
attribute.String("network.transport", "pipe"),
attribute.String("network.protocol.name", "stdio"),
}

s.server.instrumentation.McpActiveSessions.Add(ctx, 1, metric.WithAttributes(sessionAttrs...))

var err error
defer func() {
// Build full attributes including mcp.protocol.version if negotiated
fullAttrs := sessionAttrs
if s.protocol != "" {
fullAttrs = append(fullAttrs, attribute.String("mcp.protocol.version", s.protocol))
}

// Decrement active sessions counter
s.server.instrumentation.McpActiveSessions.Add(ctx, -1, metric.WithAttributes(fullAttrs...))

// Record session duration
sessionDuration := time.Since(sessionStart).Seconds()
durationAttrs := make([]attribute.KeyValue, len(fullAttrs))
copy(durationAttrs, fullAttrs)
if err != nil && err != io.EOF {
durationAttrs = append(durationAttrs, attribute.String("error.type", err.Error()))
}
s.server.instrumentation.McpSessionDuration.Record(ctx, sessionDuration, metric.WithAttributes(durationAttrs...))
}()

for {
if err := ctx.Err(); err != nil {
if err = ctx.Err(); err != nil {
return err
}
line, err := s.readLine(ctx)

var line string
line, err = s.readLine(ctx)
if err != nil {
if err == io.EOF {
return nil
Expand All @@ -206,7 +240,9 @@ func (s *stdioSession) readInputStream(ctx context.Context) error {
)
defer span.End()

v, res, err := processMcpMessage(msgCtx, []byte(line), s.server, s.protocol, "", "", nil, "")
var v string
var res any
v, res, err = processMcpMessage(msgCtx, []byte(line), s.server, s.protocol, "", "", nil, "")
if err != nil {
// errors during the processing of message will generate a valid MCP Error response.
// server can continue to run.
Expand Down Expand Up @@ -309,6 +345,7 @@ func mcpRouter(s *Server) (chi.Router, error) {

// sseHandler handles sse initialization and message.
func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
sessionStart := time.Now()
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse",
trace.WithSpanKind(trace.SpanKindServer),
)
Expand All @@ -325,23 +362,34 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")

// Define attributes for session metrics
networkProtocolVersion := fmt.Sprintf("%d.%d", r.ProtoMajor, r.ProtoMinor)
sessionAttrs := []attribute.KeyValue{
attribute.String("network.transport", "tcp"),
attribute.String("network.protocol.name", "http"),
attribute.String("network.protocol.version", networkProtocolVersion),
attribute.String("mcp.protocol.version", "2024-11-05"),
attribute.String("toolset.name", toolsetName),
}

// Increment active sessions counter
s.instrumentation.McpActiveSessions.Add(ctx, 1, metric.WithAttributes(sessionAttrs...))

var err error
defer func() {
// Decrement active sessions counter
s.instrumentation.McpActiveSessions.Add(ctx, -1, metric.WithAttributes(sessionAttrs...))

// Record session duration
sessionDuration := time.Since(sessionStart).Seconds()
durationAttrs := make([]attribute.KeyValue, len(sessionAttrs))
copy(durationAttrs, sessionAttrs)
if err != nil {
span.SetStatus(codes.Error, err.Error())
durationAttrs = append(durationAttrs, attribute.String("error.type", err.Error()))
}
s.instrumentation.McpSessionDuration.Record(ctx, sessionDuration, metric.WithAttributes(durationAttrs...))
span.End()
status := "success"
if err != nil {
status = "error"
}
s.instrumentation.McpSse.Add(
r.Context(),
1,
metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)),
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)),
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
)
}()

flusher, ok := w.(http.Flusher)
Expand Down Expand Up @@ -474,17 +522,6 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
span.SetStatus(codes.Error, err.Error())
}
span.End()

status := "success"
if err != nil {
status = "error"
}
s.instrumentation.McpPost.Add(
r.Context(),
1,
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)),
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
)
}()

networkProtocolVersion := fmt.Sprintf("%d.%d", r.ProtoMajor, r.ProtoMinor)
Expand Down Expand Up @@ -540,6 +577,8 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {

// processMcpMessage processes the messages received from clients
func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header, networkProtocolVersion string) (string, any, error) {
operationStart := time.Now()

logger, err := util.LoggerFromContext(ctx)
if err != nil {
return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
Expand Down Expand Up @@ -590,6 +629,44 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
networkProtocolName = "http"
}

var metricErrorType string
genAIAttrs := &util.GenAIMetricAttrs{
NetworkProtocolName: networkProtocolName,
NetworkProtocolVersion: networkProtocolVersion,
}
ctx = util.WithGenAIMetricAttrs(ctx, genAIAttrs)

// Record operation duration metric on function exit
defer func() {
operationDuration := time.Since(operationStart).Seconds()
durationAttrs := []attribute.KeyValue{
attribute.String("mcp.method.name", baseMessage.Method),
attribute.String("network.transport", networkTransport),
attribute.String("network.protocol.name", networkProtocolName),
attribute.String("toolset.name", toolsetName),
}
if protocolVersion != "" {
durationAttrs = append(durationAttrs, attribute.String("mcp.protocol.version", protocolVersion))
}
if networkProtocolVersion != "" {
durationAttrs = append(durationAttrs, attribute.String("network.protocol.version", networkProtocolVersion))
}
// Add gen_ai attributes populated by method handlers
if genAIAttrs.OperationName != "" {
durationAttrs = append(durationAttrs, attribute.String("gen_ai.operation.name", genAIAttrs.OperationName))
}
if genAIAttrs.ToolName != "" {
durationAttrs = append(durationAttrs, attribute.String("gen_ai.tool.name", genAIAttrs.ToolName))
}
if genAIAttrs.PromptName != "" {
durationAttrs = append(durationAttrs, attribute.String("gen_ai.prompt.name", genAIAttrs.PromptName))
}
if metricErrorType != "" {
durationAttrs = append(durationAttrs, attribute.String("error.type", metricErrorType))
}
s.instrumentation.McpOperationDuration.Record(ctx, operationDuration, metric.WithAttributes(durationAttrs...))
}()

// Set required semantic attributes for span according to OTEL MCP semcov
// ref: https://opentelemetry.io/docs/specs/semconv/gen-ai/mcp/#server
span.SetAttributes(
Expand Down Expand Up @@ -625,14 +702,18 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
return "", nil, err
}

// Add instrumentation to context for use in method handlers
ctx = util.WithInstrumentation(ctx, s.instrumentation)

// Process the method
switch baseMessage.Method {
case mcputil.INITIALIZE:
result, version, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version)
if err != nil {
span.SetStatus(codes.Error, err.Error())
if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok {
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
metricErrorType = rpcErr.Error.String()
span.SetAttributes(attribute.String("error.type", metricErrorType))
}
return "", result, err
}
Expand All @@ -643,25 +724,28 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
if !ok {
err := fmt.Errorf("toolset does not exist")
rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil)
metricErrorType = rpcErr.Error.String()
span.SetStatus(codes.Error, err.Error())
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
span.SetAttributes(attribute.String("error.type", metricErrorType))
return "", rpcErr, err
}
promptset, ok := s.ResourceMgr.GetPromptset(promptsetName)
if !ok {
err := fmt.Errorf("promptset does not exist")
rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil)
metricErrorType = rpcErr.Error.String()
span.SetStatus(codes.Error, err.Error())
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
span.SetAttributes(attribute.String("error.type", metricErrorType))
return "", rpcErr, err
}
result, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header)
if err != nil {
span.SetStatus(codes.Error, err.Error())
// Set error.type based on JSON-RPC error code
if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok {
metricErrorType = rpcErr.Error.String()
span.SetAttributes(attribute.Int("jsonrpc.error.code", rpcErr.Error.Code))
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
span.SetAttributes(attribute.String("error.type", metricErrorType))
}
}
return "", result, err
Expand Down
42 changes: 42 additions & 0 deletions internal/server/mcp/v20241105/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"net/http"
"time"

"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/trace"
)

Expand Down Expand Up @@ -111,12 +113,19 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
attribute.String("gen_ai.tool.name", toolName),
attribute.String("gen_ai.operation.name", "execute_tool"),
)

tool, ok := resourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

// Populate gen_ai attributes for operation duration metric
if genAIAttrs := util.GenAIMetricAttrsFromContext(ctx); genAIAttrs != nil {
genAIAttrs.OperationName = "execute_tool"
genAIAttrs.ToolName = toolName
}

// Get access token
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
Expand Down Expand Up @@ -209,8 +218,34 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

// Get instrumentation for recording tool execution duration
instrumentation, instrumentationErr := util.InstrumentationFromContext(ctx)

// run tool invocation and generate response.
executionStart := time.Now()
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
executionDuration := time.Since(executionStart).Seconds()

// Record tool execution duration metric
if instrumentationErr == nil {
execAttrs := []attribute.KeyValue{
attribute.String("gen_ai.tool.name", toolName),
}
// Add network protocol attributes from context
if genAIAttrs := util.GenAIMetricAttrsFromContext(ctx); genAIAttrs != nil {
if genAIAttrs.NetworkProtocolName != "" {
execAttrs = append(execAttrs, attribute.String("network.protocol.name", genAIAttrs.NetworkProtocolName))
}
if genAIAttrs.NetworkProtocolVersion != "" {
execAttrs = append(execAttrs, attribute.String("network.protocol.version", genAIAttrs.NetworkProtocolVersion))
}
}
if err != nil {
execAttrs = append(execAttrs, attribute.String("error.type", err.Error()))
}
instrumentation.ToolExecutionDuration.Record(ctx, executionDuration, metric.WithAttributes(execAttrs...))
}

if err != nil {
var tbErr util.ToolboxError

Expand Down Expand Up @@ -325,12 +360,19 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
span := trace.SpanFromContext(ctx)
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))

prompt, ok := resourceMgr.GetPrompt(promptName)
if !ok {
err := fmt.Errorf("prompt with name %q does not exist", promptName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

// Populate gen_ai attributes for operation duration metric
if genAIAttrs := util.GenAIMetricAttrsFromContext(ctx); genAIAttrs != nil {
genAIAttrs.OperationName = "get_prompt"
genAIAttrs.PromptName = promptName
}

// Parse the arguments provided in the request.
argValues, err := prompt.ParseArgs(req.Params.Arguments, nil)
if err != nil {
Expand Down
Loading
Loading