diff --git a/lib/client/db/mcp/server.go b/lib/client/db/mcp/server.go index d27270d54ef77..19924b9a289c1 100644 --- a/lib/client/db/mcp/server.go +++ b/lib/client/db/mcp/server.go @@ -111,7 +111,7 @@ func (s *RootServer) RegisterDatabase(db *Database) { s.mu.Lock() defer s.mu.Unlock() - uri := db.ResourceURI().String() + uri := db.ResourceURI().WithoutParams().String() s.availableDatabases[uri] = db s.AddResource(mcp.NewResource(uri, fmt.Sprintf("%s Database", db.DB.GetName()), mcp.WithMIMEType(databaseResourceMIMEType)), s.GetDatabaseResource) } @@ -124,7 +124,7 @@ func (s *RootServer) ServeStdio(ctx context.Context, in io.Reader, out io.Writer func buildDatabaseResource(db *Database) DatabaseResource { return DatabaseResource{ Metadata: db.DB.GetMetadata(), - URI: db.ResourceURI().String(), + URI: db.ResourceURI().WithoutParams().String(), Protocol: db.DB.GetProtocol(), ClusterName: db.ClusterName, } diff --git a/lib/client/db/mcp/server_test.go b/lib/client/db/mcp/server_test.go index 007067d3507ea..f581f7462b1d1 100644 --- a/lib/client/db/mcp/server_test.go +++ b/lib/client/db/mcp/server_test.go @@ -42,7 +42,7 @@ func TestRegisterDatabase(t *testing.T) { } // sort databases by name to ensure the same order every test. slices.SortFunc(databases, func(a, b *Database) int { - return strings.Compare(a.ResourceURI().String(), b.ResourceURI().String()) + return strings.Compare(a.ResourceURI().WithoutParams().String(), b.ResourceURI().WithoutParams().String()) }) for _, db := range databases { diff --git a/lib/client/db/postgres/mcp/mcp.go b/lib/client/db/postgres/mcp/mcp.go index 6a6b8a19dc10b..1624d11ba4a76 100644 --- a/lib/client/db/postgres/mcp/mcp.go +++ b/lib/client/db/postgres/mcp/mcp.go @@ -77,7 +77,7 @@ func NewServer(ctx context.Context, cfg *dbmcp.NewServerConfig) (dbmcp.Server, e return nil, trace.BadParameter("failed to parse database %q connection config: %s", db.DB.GetName(), err) } - s.databases[db.ResourceURI().String()] = &database{ + s.databases[db.ResourceURI().WithoutParams().String()] = &database{ Database: db, pool: pool, } diff --git a/lib/client/db/postgres/mcp/mcp_test.go b/lib/client/db/postgres/mcp/mcp_test.go index 91debaf500cd1..db959c2ace1fb 100644 --- a/lib/client/db/postgres/mcp/mcp_test.go +++ b/lib/client/db/postgres/mcp/mcp_test.go @@ -85,7 +85,7 @@ func TestFormatErrors(t *testing.T) { URI: "localhost:5432", }) require.NoError(t, err) - dbURI := clientmcp.NewDatabaseResourceURI("root", dbName).String() + dbURI := clientmcp.NewDatabaseResourceURI("root", dbName).WithoutParams().String() for name, tc := range map[string]struct { databaseURI string diff --git a/lib/client/mcp/claude/config.go b/lib/client/mcp/claude/config.go index 90578db684838..e462621d3b2d1 100644 --- a/lib/client/mcp/claude/config.go +++ b/lib/client/mcp/claude/config.go @@ -136,10 +136,20 @@ func (c *Config) GetMCPServers() map[string]MCPServer { return maps.Clone(c.mcpServers) } -// PutMCPServer adds a new MCP server or replace an existing one. +// PutMCPServer adds a new MCP server or replaces an existing one. func (c *Config) PutMCPServer(serverName string, server MCPServer) (err error) { c.mcpServers[serverName] = server - c.configData, err = sjson.SetBytes(c.configData, c.mcpServerJSONPath(serverName), server) + + // We require a custom marshal to improve MCP Resources URI readability when + // it includes query params. By default the encoding/json escapes some + // characters like `&` causing the final URI to be harder to read. + var b bytes.Buffer + enc := json.NewEncoder(&b) + enc.SetEscapeHTML(false) + if err := enc.Encode(server); err != nil { + return trace.Wrap(err) + } + c.configData, err = sjson.SetRawBytes(c.configData, c.mcpServerJSONPath(serverName), b.Bytes()) return trace.Wrap(err) } diff --git a/lib/client/mcp/claude/config_test.go b/lib/client/mcp/claude/config_test.go index 0ced8a55a5942..51f4587fd631d 100644 --- a/lib/client/mcp/claude/config_test.go +++ b/lib/client/mcp/claude/config_test.go @@ -248,6 +248,29 @@ func Test_formatJSON(t *testing.T) { } } +// TestPrettyResourceURIs given a MCP server that includes a Resource URI as +// arguments it must encode and output those URIs in a readable format. +func TestReadableResourceURIs(t *testing.T) { + for name, uri := range map[string]string{ + "uri with query params": "teleport://clusters/root/databases/pg", + "uri without query params": "teleport://clusters/root/databases/pg?dbName=postgres&dbUser=readonly", + "random uri with params": "teleport://random?hello=world&random=resource", + } { + t.Run(name, func(t *testing.T) { + config := NewConfig() + mcpServer := MCPServer{ + Command: "command", + Args: []string{uri}, + } + require.NoError(t, config.PutMCPServer("test", mcpServer)) + + var buf bytes.Buffer + require.NoError(t, config.Write(&buf, FormatJSONCompact)) + require.Contains(t, buf.String(), uri) + }) + } +} + func requireFileWithData(t *testing.T, path string, want string) { t.Helper() read, err := os.ReadFile(path) diff --git a/lib/client/mcp/uri.go b/lib/client/mcp/uri.go index a932a67c776ec..c9337096d73a2 100644 --- a/lib/client/mcp/uri.go +++ b/lib/client/mcp/uri.go @@ -69,8 +69,39 @@ func ParseResourceURI(uri string) (*ResourceURI, error) { return &ResourceURI{url: *parsedURL}, nil } -// NewDatabaseResourceURI creates a new database resource URI. -func NewDatabaseResourceURI(cluster, databaseName string) ResourceURI { +// databaseParams represents the connect params for the database resource. +type databaseParams struct { + // user is the user to log in as. + user string + // name is the name to log in to. + name string +} + +// databaseParam is a param function used for setting database connect params. +type databaseParam func(*databaseParams) + +// WithDatabaseUser configures database params with database user. +func WithDatabaseUser(user string) databaseParam { + return func(dp *databaseParams) { + dp.user = user + } +} + +// WithDatabaseUser configures database params with database name. +func WithDatabaseName(name string) databaseParam { + return func(dp *databaseParams) { + dp.name = name + } +} + +// NewDatabaseResourceURI creates a new database resource URI with connect +// params. +func NewDatabaseResourceURI(cluster string, databaseName string, opts ...databaseParam) ResourceURI { + params := &databaseParams{} + for _, opt := range opts { + opt(params) + } + pathWithHost, _ := databaseURITemplate.Build(urlpath.Match{ Params: map[string]string{ "cluster": cluster, @@ -78,10 +109,19 @@ func NewDatabaseResourceURI(cluster, databaseName string) ResourceURI { }, }) + values := url.Values{} + if params.user != "" { + values.Add(databaseUserQueryParamName, params.user) + } + if params.name != "" { + values.Add(databaseNameQueryParamName, params.name) + } + return ResourceURI{ url: url.URL{ - Scheme: resourceScheme, - Path: strings.TrimPrefix(pathWithHost, "/"), + Scheme: resourceScheme, + Path: strings.TrimPrefix(pathWithHost, "/"), + RawQuery: values.Encode(), }, } } @@ -122,12 +162,21 @@ func (u ResourceURI) IsDatabase() bool { return u.GetDatabaseServiceName() != "" } -// String returns the string representation of the resource URI (excluding the -// query params). +// String returns the string representation of the resource URI. func (u ResourceURI) String() string { - c := u.url - c.RawQuery = "" - return c.String() + return u.url.String() +} + +// WithoutParams returns a copy of the resource without additional parameters. +func (u ResourceURI) WithoutParams() ResourceURI { + copyURL := u.url + copyURL.RawQuery = "" + return ResourceURI{url: copyURL} +} + +// Equal returns true if both resources represent the same Teleport resource. +func (u ResourceURI) Equal(b ResourceURI) bool { + return u.String() == b.String() } // path returns the resource URI full path. We must include the hostname as the diff --git a/lib/client/mcp/uri_test.go b/lib/client/mcp/uri_test.go index 30cdc64a2a357..cf8c3df4f5d72 100644 --- a/lib/client/mcp/uri_test.go +++ b/lib/client/mcp/uri_test.go @@ -57,8 +57,16 @@ func TestDatabaseResourceURI(t *testing.T) { expectedDatabaseUser: "", expectedClusterName: "default", }, - "generated uri": { - uri: NewDatabaseResourceURI("default", "db").String(), + "generated uri with params": { + uri: NewDatabaseResourceURI("default", "db", WithDatabaseUser("user"), WithDatabaseName("name")).String(), + expectedDatabase: true, + expectedServiceName: "db", + expectedDatabaseName: "name", + expectedDatabaseUser: "user", + expectedClusterName: "default", + }, + "generated uri without params": { + uri: NewDatabaseResourceURI("default", "db", WithDatabaseUser("user"), WithDatabaseName("name")).WithoutParams().String(), expectedDatabase: true, expectedServiceName: "db", expectedDatabaseName: "", @@ -92,3 +100,45 @@ func TestDatabaseResourceURI(t *testing.T) { }) } } + +func TestEqualResourceURI(t *testing.T) { + randomType, err := ParseResourceURI("teleport://random/name") + require.NoError(t, err) + + for name, tc := range map[string]struct { + a ResourceURI + b ResourceURI + expectedResult bool + }{ + "same resources": { + a: NewDatabaseResourceURI("cluster", "pg"), + b: NewDatabaseResourceURI("cluster", "pg"), + expectedResult: true, + }, + "same resources, different params": { + a: NewDatabaseResourceURI("cluster", "pg", WithDatabaseUser("readonly"), WithDatabaseName("postgres")).WithoutParams(), + b: NewDatabaseResourceURI("cluster", "pg", WithDatabaseUser("rw"), WithDatabaseName("random")).WithoutParams(), + expectedResult: true, + }, + "same resource type, different resources": { + a: NewDatabaseResourceURI("cluster", "pg"), + b: NewDatabaseResourceURI("cluster", "random"), + expectedResult: false, + }, + "different resource type, same name": { + a: *randomType, + b: NewDatabaseResourceURI("cluster", "pg"), + expectedResult: false, + }, + "same resources compare params": { + a: NewDatabaseResourceURI("cluster", "pg", WithDatabaseUser("rw"), WithDatabaseName("postgres")), + b: NewDatabaseResourceURI("cluster", "pg", WithDatabaseUser("rw"), WithDatabaseName("postgres")), + expectedResult: true, + }, + } { + t.Run(name, func(t *testing.T) { + require.Equal(t, tc.expectedResult, tc.a.Equal(tc.b)) + require.Equal(t, tc.expectedResult, tc.b.Equal(tc.a)) + }) + } +} diff --git a/tool/tsh/common/help.go b/tool/tsh/common/help.go index 6ae4a5e8411e5..092b78279f238 100644 --- a/tool/tsh/common/help.go +++ b/tool/tsh/common/help.go @@ -87,4 +87,16 @@ Examples: Search MCP servers with labels and add to the specified JSON file $ tsh mcp config --labels env=dev --client-config=my-config.json` + + mcpDBConfigHelp = ` +Examples: + Print sample configuration for exposing database as MCP server + $ tsh mcp db config --db-user=mydbuser --db-name=mydbname my-db-resource + + Add the database configuration to Claude Desktop + $ tsh mcp db config --db-user=mydbuser --db-name=mydbname --client-config=claude my-db-resource + + Add the database configuration to the specified JSON file + $ tsh mcp db config --db-user=mydbuser --db-name=mydbname --client-config=my-config.json my-db-resource +` ) diff --git a/tool/tsh/common/mcp.go b/tool/tsh/common/mcp.go index e9675ee9a2a1e..a24c0b2cb2bc0 100644 --- a/tool/tsh/common/mcp.go +++ b/tool/tsh/common/mcp.go @@ -30,7 +30,8 @@ import ( ) type mcpCommands struct { - dbStart *mcpDBStartCommand + dbStart *mcpDBStartCommand + dbConfig *mcpDBConfigCommand config *mcpConfigCommand list *mcpListCommand @@ -41,7 +42,8 @@ func newMCPCommands(app *kingpin.Application, cf *CLIConf) *mcpCommands { mcp := app.Command("mcp", "View and control proxied MCP servers.") db := mcp.Command("db", "Database access for MCP servers.") return &mcpCommands{ - dbStart: newMCPDBCommand(db), + dbStart: newMCPDBCommand(db, cf), + dbConfig: newMCPDBconfigCommand(db, cf), list: newMCPListCommand(mcp, cf), config: newMCPConfigCommand(mcp, cf), @@ -114,6 +116,7 @@ a config file compatible with the "mcpServer" mapping.`) // claudeConfig defines a subset of functions from claude.Config. type claudeConfig interface { PutMCPServer(string, claude.MCPServer) error + GetMCPServers() map[string]claude.MCPServer } func makeLocalMCPServer(cf *CLIConf, args []string) claude.MCPServer { diff --git a/tool/tsh/common/mcp_db.go b/tool/tsh/common/mcp_db.go index 87627ebf3dd2c..a47bf0eeaaad3 100644 --- a/tool/tsh/common/mcp_db.go +++ b/tool/tsh/common/mcp_db.go @@ -18,15 +18,22 @@ package common import ( "context" + "fmt" "log/slog" + "maps" + "text/template" "github.com/alecthomas/kingpin/v2" "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/client/proto" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/client" dbmcp "github.com/gravitational/teleport/lib/client/db/mcp" pgmcp "github.com/gravitational/teleport/lib/client/db/postgres/mcp" "github.com/gravitational/teleport/lib/client/mcp" + "github.com/gravitational/teleport/lib/client/mcp/claude" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/tlsca" @@ -38,30 +45,32 @@ import ( type mcpDBStartCommand struct { *kingpin.CmdClause + cf *CLIConf databaseURIs []string } -func newMCPDBCommand(parent *kingpin.CmdClause) *mcpDBStartCommand { +func newMCPDBCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpDBStartCommand { cmd := &mcpDBStartCommand{ CmdClause: parent.Command("start", "Start a local MCP server for database access.").Hidden(), + cf: cf, } cmd.Arg("uris", "List of database MCP resource URIs that will be served by the server.").Required().StringsVar(&cmd.databaseURIs) return cmd } -func (c *mcpDBStartCommand) run(cf *CLIConf) error { - logger, err := initLogger(cf, utils.LoggingForMCP, getLoggingOptsForMCPServer(cf)) +func (c *mcpDBStartCommand) run() error { + logger, err := initLogger(c.cf, utils.LoggingForMCP, getLoggingOptsForMCPServer(c.cf)) if err != nil { return trace.Wrap(err) } registry := defaultDBMCPRegistry - if cf.databaseMCPRegistryOverride != nil { - registry = cf.databaseMCPRegistryOverride + if c.cf.databaseMCPRegistryOverride != nil { + registry = c.cf.databaseMCPRegistryOverride } - tc, err := makeClient(cf) + tc, err := makeClient(c.cf) if err != nil { return trace.Wrap(err) } @@ -87,16 +96,16 @@ func (c *mcpDBStartCommand) run(cf *CLIConf) error { return trace.BadParameter("Databases must be from the same cluster (%q). %q is from a different cluster.", tc.SiteName, rawURI) } - if _, ok := configuredDatabases[uri.String()]; ok { - return trace.BadParameter("Database %q was configured twice. MCP servers only support serving a database service only once.", uri.String()) + if _, ok := configuredDatabases[uri.WithoutParams().String()]; ok { + return trace.BadParameter("Database %q was configured twice. MCP servers only support serving a database service only once.", uri.GetDatabaseName()) } - configuredDatabases[uri.String()] = struct{}{} + configuredDatabases[uri.WithoutParams().String()] = struct{}{} uris[i] = uri } server := dbmcp.NewRootServer(logger) - allDatabases, closeLocalProxies, err := c.prepareDatabases(cf, tc, registry, uris, logger, server) + allDatabases, closeLocalProxies, err := c.prepareDatabases(c.cf, tc, registry, uris, logger, server) if err != nil { return trace.Wrap(err) } @@ -108,7 +117,7 @@ func (c *mcpDBStartCommand) run(cf *CLIConf) error { continue } - srv, err := newServerFunc(cf.Context, &dbmcp.NewServerConfig{ + srv, err := newServerFunc(c.cf.Context, &dbmcp.NewServerConfig{ Logger: logger, RootServer: server, Databases: databases, @@ -116,10 +125,10 @@ func (c *mcpDBStartCommand) run(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - defer srv.Close(cf.Context) + defer srv.Close(c.cf.Context) } - return trace.Wrap(server.ServeStdio(cf.Context, cf.Stdin(), cf.Stdout())) + return trace.Wrap(server.ServeStdio(c.cf.Context, c.cf.Stdin(), c.cf.Stdout())) } // closeLocalProxyFunc function used to close local proxy listeners. @@ -226,9 +235,220 @@ func (c *mcpDBStartCommand) prepareDatabases( }, nil } +// databasesGetter is the interface used to retrieve available +// databases using filters. +type databasesGetter interface { + // ListDatabases returns all registered databases. + ListDatabases(ctx context.Context, customFilter *proto.ListResourcesRequest) ([]types.Database, error) +} + +// mcpDBConfigCommand implements `tsh mcp db config` command. +type mcpDBConfigCommand struct { + *kingpin.CmdClause + + clientConfig mcpClientConfigFlags + ctx context.Context + cf *CLIConf + siteName string + overwriteEnv bool + + // databasesGetter used to retrieve databases information. Can be mocked in + // tests. + databasesGetter databasesGetter +} + +func newMCPDBconfigCommand(parent *kingpin.CmdClause, cf *CLIConf) *mcpDBConfigCommand { + cmd := &mcpDBConfigCommand{ + CmdClause: parent.Command("config", "Print client configuration details."), + ctx: cf.Context, + cf: cf, + } + + cmd.Flag("db-user", "Database user to log in as.").Short('u').StringVar(&cf.DatabaseUser) + cmd.Flag("db-name", "Database name to log in to.").Short('n').StringVar(&cf.DatabaseName) + cmd.Flag("overwrite", "Overwrites command and environment variable from the config file.").BoolVar(&cmd.overwriteEnv) + cmd.Arg("name", "Database service name.").StringVar(&cf.DatabaseService) + cmd.clientConfig.addToCmd(cmd.CmdClause) + cmd.Alias(mcpDBConfigHelp) + return cmd +} + +// TODO(gabrielcorado): support generating config for multiple databases at once. +func (m *mcpDBConfigCommand) run() error { + if m.databasesGetter == nil { + tc, err := makeClient(m.cf) + if err != nil { + return trace.Wrap(err) + } + + m.databasesGetter = tc + m.siteName = tc.SiteName + } + + databases, err := m.databasesGetter.ListDatabases(m.ctx, &proto.ListResourcesRequest{ + Namespace: apidefaults.Namespace, + ResourceType: types.KindDatabaseServer, + PredicateExpression: makeDiscoveredNameOrNamePredicate(m.cf.DatabaseService), + // TODO(gabrielcorado): support requesting access. + UseSearchAsRoles: false, + }) + if err != nil { + return trace.Wrap(err) + } + + db, err := chooseOneDatabase(m.cf, databases) + if err != nil { + return trace.Wrap(err) + } + + // TODO(gabrielcorado): support having the flags empty and assume the values + // based on the role and database. + if m.cf.DatabaseUser == "" || m.cf.DatabaseName == "" { + return trace.BadParameter("You must specify --db-user and --db-name flags used to connect to the database") + } + + dbURI := mcp.NewDatabaseResourceURI(m.siteName, db.GetName(), mcp.WithDatabaseUser(m.cf.DatabaseUser), mcp.WithDatabaseName(m.cf.DatabaseName)) + switch { + case m.clientConfig.isSet(): + return trace.Wrap(m.updateClientConfig(dbURI)) + default: + return trace.Wrap(m.printJSONWithHint(dbURI)) + } +} + +func (m *mcpDBConfigCommand) printJSONWithHint(dbURI mcp.ResourceURI) error { + config := claude.NewConfig() + // Since the database is being added to a "fresh" config file the database + // will always be new and we can ignore the additional message as well. + if _, _, err := m.addDatabaseToConfig(config, dbURI); err != nil { + return trace.Wrap(err) + } + + w := m.cf.Stdout() + if _, err := fmt.Fprintln(w, "Here is a sample JSON configuration for launching Teleport MCP servers:"); err != nil { + return trace.Wrap(err) + } + if err := config.Write(w, claude.FormatJSONOption(m.clientConfig.jsonFormat)); err != nil { + return trace.Wrap(err) + } + if _, err := fmt.Fprintf(w, ` +If you already have an entry for %q server, add the following database resource URI to the command arguments list: +%s + +`, mcpDBConfigName, dbURI.String()); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(m.clientConfig.printHint(w)) +} + +// TODO(gabrielcorado): support updating multiple databases at once. +func (m *mcpDBConfigCommand) updateClientConfig(dbURI mcp.ResourceURI) error { + config, err := m.clientConfig.loadConfig() + if err != nil { + return trace.Wrap(err) + } + preexistentDB, commandChanged, err := m.addDatabaseToConfig(config, dbURI) + if err != nil { + return trace.Wrap(err) + } + + if err := config.Save(claude.FormatJSONOption(m.clientConfig.jsonFormat)); err != nil { + return trace.Wrap(err) + } + + templateData := struct { + Name string + ConfigPath string + ConfigName string + PreexistentDB bool + EnvChanged bool + OverwriteEnv bool + }{ + Name: dbURI.GetDatabaseServiceName(), + ConfigPath: config.Path(), + ConfigName: mcpDBConfigName, + PreexistentDB: preexistentDB, + EnvChanged: commandChanged, + OverwriteEnv: m.overwriteEnv, + } + + return trace.Wrap(mcpDBConfigMessageTemplate.Execute(m.cf.Stdout(), templateData)) +} + +// addDatabaseToConfig adds the provided database, merging with existent +// databases configured. This function returns a additional message to be +// displayed to users. +func (m *mcpDBConfigCommand) addDatabaseToConfig(config claudeConfig, dbURI mcp.ResourceURI) (bool, bool, error) { + var ( + dbs []string + updated bool + envChanged bool + server = makeLocalMCPServer(m.cf, nil /* args */) + ) + if existentServer, ok := config.GetMCPServers()[mcpDBConfigName]; ok { + // For most common cases we want to keep the environment variables + // unchanged. However, in case users want a "fresh start" they can + // provide a flag so we overwrite them with default values. + if !maps.Equal(server.Envs, existentServer.Envs) { + envChanged = true + if !m.overwriteEnv { + server.Envs = existentServer.Envs + } + } + + for _, arg := range existentServer.Args { + // We're only interested in resources, any flags or other command + // parts will be discarded. + uri, err := mcp.ParseResourceURI(arg) + if err != nil { + continue + } + + if !uri.IsDatabase() { + return false, false, trace.BadParameter("resource %q on config is not a database", uri.String()) + } + + if uri.WithoutParams().Equal(dbURI.WithoutParams()) { + dbs = append(dbs, dbURI.String()) + updated = true + } else { + dbs = append(dbs, uri.String()) + } + } + } + + if !updated { + dbs = append(dbs, dbURI.String()) + } + + server.Args = append([]string{"mcp", "db", "start"}, dbs...) + return updated, envChanged, trace.Wrap(config.PutMCPServer(mcpDBConfigName, server)) +} + var ( // defaultDBMCPRegistry is the default database access MCP servers registry. defaultDBMCPRegistry = map[string]dbmcp.NewServerFunc{ defaults.ProtocolPostgres: pgmcp.NewServer, } ) + +// mcpDBConfigName is the configuration name that is managed by the config +// command. +const mcpDBConfigName = "teleport-databases" + +// mcpDBConfigMessageTemplate is the MCP db config message template. +var mcpDBConfigMessageTemplate = template.Must(template.New("").Funcs(template.FuncMap{ + "quote": func(s string) string { return fmt.Sprintf("%q", s) }, +}).Parse(`{{ if .PreexistentDB -}}Updated{{ else }}Added{{ end }} database {{ .Name | quote }} on the client configuration at: +{{ .ConfigPath }} + +Teleport database access MCP server is named {{ .ConfigName | quote }} in this configuration. + +You may need to restart your client to reload these new configurations. + +{{- if (and (.EnvChanged) (not .OverwriteEnv)) }} + +Environment variables have changed, but existing values will be preserved. +To overwrite them, rerun this command with the --overwrite flag. +{{- end }} +`)) diff --git a/tool/tsh/common/mcp_db_test.go b/tool/tsh/common/mcp_db_test.go index 123657d212841..5faf7c61ad479 100644 --- a/tool/tsh/common/mcp_db_test.go +++ b/tool/tsh/common/mcp_db_test.go @@ -17,19 +17,27 @@ package common import ( + "bytes" "context" "io" + "path/filepath" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" mcpclient "github.com/mark3labs/mcp-go/client" mcptransport "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" dbmcp "github.com/gravitational/teleport/lib/client/db/mcp" + clientmcp "github.com/gravitational/teleport/lib/client/mcp" + "github.com/gravitational/teleport/lib/client/mcp/claude" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/service/servicecfg" testserver "github.com/gravitational/teleport/tool/teleport/testenv" @@ -217,7 +225,199 @@ func TestMCPDBCommandFailures(t *testing.T) { }) } +func TestMCPDBConfigCommand(t *testing.T) { + clusterName := "root" + db0, err := types.NewDatabaseV3(types.Metadata{ + Name: "pg", + }, types.DatabaseSpecV3{ + Protocol: "protocol", + URI: "uri", + }) + require.NoError(t, err) + db1, err := types.NewDatabaseV3(types.Metadata{ + Name: "another", + }, types.DatabaseSpecV3{ + Protocol: "protocol", + URI: "uri", + }) + require.NoError(t, err) + + dbURI0 := clientmcp.NewDatabaseResourceURI(clusterName, db0.GetName(), clientmcp.WithDatabaseUser("readonly"), clientmcp.WithDatabaseName("dbname")) + dbURI0Updated := clientmcp.NewDatabaseResourceURI(clusterName, db0.GetName(), clientmcp.WithDatabaseUser("rw"), clientmcp.WithDatabaseName("anotherdb")) + dbURI1 := clientmcp.NewDatabaseResourceURI(clusterName, db1.GetName(), clientmcp.WithDatabaseUser("rw"), clientmcp.WithDatabaseName("dbname")) + + for name, tc := range map[string]struct { + cf *CLIConf + overwriteEnv bool + databasesGetter databasesGetter + assertError require.ErrorAssertionFunc + initialDatabases []string + expectedDatabases []string + initialEnv map[string]string + expectedEnv map[string]string + }{ + "add database to empty config": { + cf: &CLIConf{ + DatabaseService: dbURI0.GetDatabaseServiceName(), + DatabaseUser: dbURI0.GetDatabaseUser(), + DatabaseName: dbURI0.GetDatabaseName(), + }, + databasesGetter: &mockDatabasesGetter{dbs: []types.Database{db0, db1}}, + assertError: require.NoError, + expectedDatabases: []string{dbURI0.String()}, + expectedEnv: map[string]string{}, + }, + "append database to config": { + cf: &CLIConf{ + DatabaseService: dbURI1.GetDatabaseServiceName(), + DatabaseUser: dbURI1.GetDatabaseUser(), + DatabaseName: dbURI1.GetDatabaseName(), + }, + databasesGetter: &mockDatabasesGetter{dbs: []types.Database{db0, db1}}, + assertError: require.NoError, + initialDatabases: []string{dbURI0.String()}, + expectedDatabases: []string{dbURI0.String(), dbURI1.String()}, + expectedEnv: map[string]string{}, + }, + "update existent database": { + cf: &CLIConf{ + DatabaseService: dbURI0Updated.GetDatabaseServiceName(), + DatabaseUser: dbURI0Updated.GetDatabaseUser(), + DatabaseName: dbURI0Updated.GetDatabaseName(), + }, + databasesGetter: &mockDatabasesGetter{dbs: []types.Database{db0, db1}}, + assertError: require.NoError, + initialDatabases: []string{dbURI0.String(), dbURI1.String()}, + expectedDatabases: []string{dbURI0Updated.String(), dbURI1.String()}, + expectedEnv: map[string]string{}, + }, + "database not found": { + cf: &CLIConf{ + DatabaseService: dbURI0.GetDatabaseServiceName(), + DatabaseUser: dbURI0.GetDatabaseUser(), + DatabaseName: dbURI0.GetDatabaseName(), + }, + databasesGetter: &mockDatabasesGetter{err: trace.NotFound("database not found")}, + assertError: require.Error, + }, + "missing connection params": { + cf: &CLIConf{ + DatabaseService: dbURI0Updated.GetDatabaseServiceName(), + }, + databasesGetter: &mockDatabasesGetter{dbs: []types.Database{db0}}, + assertError: require.Error, + }, + "keep current environment setting": { + cf: &CLIConf{ + DatabaseService: dbURI0.GetDatabaseServiceName(), + DatabaseUser: dbURI0.GetDatabaseUser(), + DatabaseName: dbURI0.GetDatabaseName(), + DebugSetByUser: true, + Debug: true, + }, + databasesGetter: &mockDatabasesGetter{dbs: []types.Database{db0, db1}}, + assertError: require.NoError, + initialDatabases: []string{dbURI0.String()}, + expectedDatabases: []string{dbURI0.String()}, + initialEnv: map[string]string{"test": "hello"}, + expectedEnv: map[string]string{"test": "hello"}, + }, + "reset environment setting": { + cf: &CLIConf{ + DatabaseService: dbURI0.GetDatabaseServiceName(), + DatabaseUser: dbURI0.GetDatabaseUser(), + DatabaseName: dbURI0.GetDatabaseName(), + }, + overwriteEnv: true, + databasesGetter: &mockDatabasesGetter{dbs: []types.Database{db0, db1}}, + assertError: require.NoError, + initialDatabases: []string{dbURI0.String()}, + expectedDatabases: []string{dbURI0.String()}, + initialEnv: map[string]string{"test": "hello"}, + expectedEnv: map[string]string{}, + }, + } { + t.Run(name, func(t *testing.T) { + configPath := setupMockDBMCPConfig(t, tc.cf, tc.initialDatabases, tc.initialEnv) + var buf bytes.Buffer + tc.cf.Context = context.Background() + tc.cf.Proxy = "proxy:3080" + tc.cf.HomePath = t.TempDir() + tc.cf.OverrideStdout = &buf + mustCreateEmptyProfile(t, tc.cf) + + cmd := &mcpDBConfigCommand{ + clientConfig: mcpClientConfigFlags{ + clientConfig: configPath, + jsonFormat: string(claude.FormatJSONPretty), + }, + cf: tc.cf, + ctx: t.Context(), + siteName: clusterName, + databasesGetter: tc.databasesGetter, + overwriteEnv: tc.overwriteEnv, + } + + err := cmd.run() + tc.assertError(t, err) + if err != nil { + return + } + + jsonConfig, err := claude.LoadConfigFromFile(configPath) + require.NoError(t, err) + mcpCmd, ok := jsonConfig.GetMCPServers()[mcpDBConfigName] + require.True(t, ok, "expected configuration to include database access server definition, but got nothing") + require.Empty(t, cmp.Diff(mcpCmd.Args, tc.expectedDatabases, cmpopts.EquateEmpty(), cmpopts.IgnoreSliceElements(func(arg string) bool { + // Only assert database resources on the args. + _, err := clientmcp.ParseResourceURI(arg) + return err != nil + }))) + require.Empty(t, cmp.Diff(mcpCmd.Envs, tc.expectedEnv, cmpopts.EquateEmpty(), cmpopts.IgnoreMapEntries(func(key string, _ string) bool { + // Ignore default fields, only look for additional ones. + switch key { + case types.HomeEnvVar, debugEnvVar, osLogEnvVar: + return true + default: + return false + } + }))) + }) + } +} + +func setupMockDBMCPConfig(t *testing.T, cf *CLIConf, databasesURIs []string, additionalEnv map[string]string) string { + t.Helper() + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + config, err := claude.LoadConfigFromFile(configPath) + require.NoError(t, err) + require.NoError(t, config.PutMCPServer("local-everything", claude.MCPServer{ + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-everything"}, + })) + if len(databasesURIs) > 0 { + srv := makeLocalMCPServer(cf, append([]string{"mcp", "db", "start"}, databasesURIs...)) + for name, value := range additionalEnv { + srv.AddEnv(name, value) + } + require.NoError(t, config.PutMCPServer(mcpDBConfigName, srv)) + } + require.NoError(t, config.Save(claude.FormatJSONPretty)) + return config.Path() +} + // testDatabaseMCP is a noop database MCP server. type testDatabaseMCP struct{} func (s *testDatabaseMCP) Close(_ context.Context) error { return nil } + +// mockDatabaseGetter is a fetch databases mock. +type mockDatabasesGetter struct { + dbs []types.Database + err error +} + +func (m *mockDatabasesGetter) ListDatabases(_ context.Context, _ *proto.ListResourcesRequest) ([]types.Database, error) { + return m.dbs, m.err +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index d76fed163de01..e5d37a556cb39 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -1804,7 +1804,9 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { case pivCmd.agent.FullCommand(): err = pivCmd.agent.run(&cf) case mcpCmd.dbStart.FullCommand(): - err = mcpCmd.dbStart.run(&cf) + err = mcpCmd.dbStart.run() + case mcpCmd.dbConfig.FullCommand(): + err = mcpCmd.dbConfig.run() case mcpCmd.connect.FullCommand(): err = mcpCmd.connect.run() case mcpCmd.list.FullCommand():