Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions router-tests/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func TestMCP(t *testing.T) {
})
})

t.Run("Tool name collision with built-in tool uses prefixed name when OmitToolNamePrefix is enabled", func(t *testing.T) {
t.Run("Tool name collision with built-in tool skips colliding operation when OmitToolNamePrefix is enabled", func(t *testing.T) {
testenv.Run(t, &testenv.Config{
MCPOperationsPath: "testdata/mcp_operations_collision",
MCP: config.MCPConfiguration{
Expand All @@ -234,8 +234,8 @@ func TestMCP(t *testing.T) {
toolNames[i] = tool.Name
}

assert.Contains(t, toolNames, "get_schema") // built-in tool (ExposeSchema=true)
assert.Contains(t, toolNames, "execute_operation_get_schema") // collision uses prefix
assert.Contains(t, toolNames, "get_schema") // built-in tool (ExposeSchema=true)
assert.NotContains(t, toolNames, "execute_operation_get_schema") // colliding operation is skipped, not prefixed
})
})

Expand Down
16 changes: 12 additions & 4 deletions router/pkg/mcpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ import (
"go.uber.org/zap"
)

// reservedToolNames contains tool names that are internally registered by the MCP server
// and must not be used by operations when omitToolNamePrefix is enabled.
var reservedToolNames = []string{
"get_schema",
"execute_graphql",
"get_operation_info",
}

// requestHeadersKey is a custom context key for storing request headers.
type requestHeadersKey struct{}

Expand Down Expand Up @@ -394,6 +402,7 @@ func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error {
}

s.server.DeleteTools(s.registeredTools...)
s.registeredTools = nil

if err := s.registerTools(); err != nil {
return fmt.Errorf("failed to register tools: %w", err)
Expand Down Expand Up @@ -540,13 +549,12 @@ func (s *GraphQLSchemaServer) registerTools() error {
toolName := operationToolName
if !s.omitToolNamePrefix {
toolName = fmt.Sprintf("execute_operation_%s", operationToolName)
} else if slices.Contains(s.registeredTools, operationToolName) {
s.logger.Warn("Operation name collides with built-in MCP tool, using prefixed name",
} else if slices.Contains(s.registeredTools, operationToolName) || slices.Contains(reservedToolNames, operationToolName) {
s.logger.Error("Skipping operation due to tool name collision",
zap.String("operation", op.Name),
zap.String("conflicting_tool", operationToolName),
zap.String("using_name", fmt.Sprintf("execute_operation_%s", operationToolName)),
)
Comment thread
asoorm marked this conversation as resolved.
toolName = fmt.Sprintf("execute_operation_%s", operationToolName)
continue
}
tool := mcp.NewToolWithRawSchema(
toolName,
Expand Down
194 changes: 194 additions & 0 deletions router/pkg/mcpserver/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package mcpserver

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
"github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)

const testSchema = `
schema {
query: Query
}

type Query {
employee(id: ID!): Employee
employees: [Employee!]!
}

type Employee {
id: ID!
name: String!
}
`

const findEmployeeOp = `
query FindEmployee($id: ID!) {
employee(id: $id) {
id
name
}
}
`

const listEmployeesOp = `
query ListEmployees {
employees {
id
name
}
}
`

const getOperationInfoOp = `
query GetOperationInfo {
employees {
id
name
}
}
`

func writeOperationFiles(t *testing.T, dir string, files map[string]string) {
t.Helper()
for filename, content := range files {
err := os.WriteFile(filepath.Join(dir, filename), []byte(content), 0644)
require.NoError(t, err)
}
}

func TestReload_NoToolDuplication(t *testing.T) {
core, logs := observer.New(zapcore.DebugLevel)
logger := zap.New(core)

tempDir := t.TempDir()
writeOperationFiles(t, tempDir, map[string]string{
"FindEmployee.graphql": findEmployeeOp,
"ListEmployees.graphql": listEmployeesOp,
})

schemaDoc, report := astparser.ParseGraphqlDocumentString(testSchema)
require.False(t, report.HasErrors())
err := asttransform.MergeDefinitionWithBaseSchema(&schemaDoc)
require.NoError(t, err)

srv, err := NewGraphQLSchemaServer(
"http://localhost:4000/graphql",
WithLogger(logger),
WithOperationsDir(tempDir),
WithOmitToolNamePrefix(true),
)
require.NoError(t, err)

// First load
err = srv.Reload(&schemaDoc)
require.NoError(t, err)

firstLoadTools := make([]string, len(srv.registeredTools))
copy(firstLoadTools, srv.registeredTools)

// Second load (simulates config reload)
err = srv.Reload(&schemaDoc)
require.NoError(t, err)

// registeredTools should be identical after reload — no duplicates
assert.Equal(t, firstLoadTools, srv.registeredTools,
"registered tools should be identical after reload, no duplicates")

// Verify no collision errors were logged
collisionLogs := logs.FilterMessage("Skipping operation due to tool name collision")
assert.Equal(t, 0, collisionLogs.Len(),
"no tool name collision errors should be logged on reload")
}

func TestReload_ReservedToolNameCollision(t *testing.T) {
core, logs := observer.New(zapcore.DebugLevel)
logger := zap.New(core)

// Create an operation whose snake_case name will be "get_operation_info",
// which collides with the reserved tool name.
tempDir := t.TempDir()
writeOperationFiles(t, tempDir, map[string]string{
"GetOperationInfo.graphql": getOperationInfoOp,
"ListEmployees.graphql": listEmployeesOp,
})

schemaDoc, report := astparser.ParseGraphqlDocumentString(testSchema)
require.False(t, report.HasErrors())
err := asttransform.MergeDefinitionWithBaseSchema(&schemaDoc)
require.NoError(t, err)

srv, err := NewGraphQLSchemaServer(
"http://localhost:4000/graphql",
WithLogger(logger),
WithOperationsDir(tempDir),
WithOmitToolNamePrefix(true),
)
require.NoError(t, err)

err = srv.Reload(&schemaDoc)
require.NoError(t, err)

// The operation "GetOperationInfo" (snake: "get_operation_info") should be skipped
// because it collides with the reserved tool name.
collisionLogs := logs.FilterMessage("Skipping operation due to tool name collision")
assert.Equal(t, 1, collisionLogs.Len(),
"expected exactly one collision error for reserved tool name")

if collisionLogs.Len() > 0 {
entry := collisionLogs.All()[0]
assert.Equal(t, zapcore.ErrorLevel, entry.Level)
assert.Equal(t, "get_operation_info", entry.ContextMap()["conflicting_tool"])
}

assert.ElementsMatch(t, []string{"get_schema", "list_employees", "get_operation_info"}, srv.registeredTools)
}

func TestReload_PrefixModeAvoidsReservedNameCollision(t *testing.T) {
core, logs := observer.New(zapcore.DebugLevel)
logger := zap.New(core)

// "GetOperationInfo" snake_cases to "get_operation_info" which is a reserved name.
// With the prefix enabled, it becomes "execute_operation_get_operation_info" and no collision occurs.
tempDir := t.TempDir()
writeOperationFiles(t, tempDir, map[string]string{
"GetOperationInfo.graphql": getOperationInfoOp,
"ListEmployees.graphql": listEmployeesOp,
})

schemaDoc, report := astparser.ParseGraphqlDocumentString(testSchema)
require.False(t, report.HasErrors())
err := asttransform.MergeDefinitionWithBaseSchema(&schemaDoc)
require.NoError(t, err)

srv, err := NewGraphQLSchemaServer(
"http://localhost:4000/graphql",
WithLogger(logger),
WithOperationsDir(tempDir),
WithOmitToolNamePrefix(false),
)
require.NoError(t, err)

err = srv.Reload(&schemaDoc)
require.NoError(t, err)

// No collisions because the prefix disambiguates from the reserved name
collisionLogs := logs.FilterMessage("Skipping operation due to tool name collision")
assert.Equal(t, 0, collisionLogs.Len(),
"no collisions expected with tool name prefix enabled")

assert.ElementsMatch(t, []string{
"get_schema",
"execute_operation_get_operation_info",
"execute_operation_list_employees",
"get_operation_info",
}, srv.registeredTools)
}
Loading