diff --git a/lib/client/db/mcp/errors.go b/lib/client/db/mcp/errors.go new file mode 100644 index 0000000000000..8ab0112e9332a --- /dev/null +++ b/lib/client/db/mcp/errors.go @@ -0,0 +1,87 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package mcp + +import ( + "errors" + "io" + "strings" + + "github.com/gravitational/trace" + + apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/lib/client/mcp" +) + +// ExtenralErrorRetriever returns an external error that might have happened. +// +// MCP servers don't have knowledge of other processes that might fail during +// their execution, such as authentication failures. This provider can be used +// to give them the necessary context to provide more accurate user messages. +type ExternalErrorRetriever interface { + // RetrieveError retrieves the error if any. + RetrieveError() error +} + +// FormatErrorMessage formats the database MCP error messages. +// format. +func FormatErrorMessage(retreiver ExternalErrorRetriever, err error) error { + if retreiver != nil { + err = trace.NewAggregate(retreiver.RetrieveError(), err) + } + + switch { + case errors.Is(err, apiclient.ErrClientCredentialsHaveExpired): + return trace.BadParameter(ReloginRequiredErrorMessage) + case strings.Contains(err.Error(), "connection reset by peer") || errors.Is(err, io.ErrClosedPipe): + return trace.BadParameter(LocalProxyConnectionErrorMessage) + } + + return err +} + +const ( + // ReloginRequiredErrorMessage is the message returned to the MCP client + // when the tsh session expired. + ReloginRequiredErrorMessage = `It looks like your Teleport session expired, +you must relogin (using "tsh login" on a terminal) before continue using this +tool. After that, there is no need to update or relaunch the MCP client - just +try using it again.` + // LocalProxyConnectionErrorMessage is the message returned to the MCP client when + // the database client cannot connect to the local proxy. + LocalProxyConnectionErrorMessage = `Teleport MCP server is having issue while +establishing the database connection. You can verify the MCP logs for more +details on what is causing this issue. After identifying and fixing the issue +a restart on the MCP client might be necessary.` + // EmptyDatabasesListErrorMessage is the message returned to the MCP client when + // the started database server is serving no databases. + EmptyDatabasesListErrorMessage = `There are no active Teleport databases available +for use on the MCP server. You can check the MCP server logs to see if any +database was not included due to an error. You can also verify that the list +of databases on the MCP command is correct.` +) + +var ( + // WrongDatabaseURIFormatError is the message returned to the MCP client + // when it sends a malformed database resource URI. + WrongDatabaseURIFormatError = trace.BadParameter("Malformed database resource URI. Database resources must follow the format: %q", mcp.SampleDatabaseResource) + // DatabaseNotFoundError is the message returned to the MCP client when the + // requested database is not available as MCP resource. + DatabaseNotFoundError = trace.NotFound(`Database not found. Only registered databases +can be used. Ask the user to attach the database resource or list the available +resources with %q tool`, listDatabasesToolName) +) diff --git a/lib/client/db/mcp/mcp.go b/lib/client/db/mcp/mcp.go new file mode 100644 index 0000000000000..43baf36144747 --- /dev/null +++ b/lib/client/db/mcp/mcp.go @@ -0,0 +1,107 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package mcp + +import ( + "context" + "log/slog" + "net" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/client/mcp" +) + +// NewServerConfig configuration passed to the server constructors. +type NewServerConfig struct { + Logger *slog.Logger + RootServer *RootServer + Databases []*Database +} + +// NewServerFunc the MCP server constructor function definition. +type NewServerFunc func(context.Context, *NewServerConfig) (Server, error) + +// Server represents a MCP server. +type Server interface { + // Close closes the server. + Close(context.Context) error +} + +// Registry represents the available databases MCP servers per protocol and +// their constructors. +type Registry map[string]NewServerFunc + +// IsSupported returns if the database protocol is supported by any MCP server +// available. +func (m Registry) IsSupported(protocol string) bool { + _, ok := m[protocol] + return ok +} + +// LookupFunc is the function used to resolve database address. Follows the +// net.Resolver.LookupAddr format. +type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) + +// DialContextFunc is a function used to dial the database. Follows the +// net.Dialer.DialContext format. +type DialContextFunc func(ctx context.Context, network string, addr string) (net.Conn, error) + +// Database the database served by an MCP server. +type Database struct { + // DB contains all information from the database. + DB types.Database + // ClusterName is the cluster name where the database is located. + ClusterName string + // Addr is the address the MCP server used to create a new database + // connection. + Addr string + // DatabaseUser is the database username used on the connections. + DatabaseUser string + // DatabaseName is the database name used on the connections. + DatabaseName string + // ExternalErrorRetriever used to retrieve any external error that might + // have happened while connecting/communicating with the database. + ExternalErrorRetriever ExternalErrorRetriever + // LookupFunc is the lookup function to resolve database address. + LookupFunc LookupFunc + // DialContextFunc is the dial function used to connect to the database. + DialContextFunc DialContextFunc +} + +// ResourceURI returns the database MCP resource URI. +func (d Database) ResourceURI() mcp.ResourceURI { + return mcp.NewDatabaseResourceURI(d.ClusterName, d.DB.GetName()) +} + +// DatabaseResource MCP resource representation of a Teleport database. +type DatabaseResource struct { + types.Metadata + // URI is the MCP URI resource. + URI string `json:"uri"` + // Protocol is the database protocol. + Protocol string `json:"protocol"` + // ClusterName is the cluster the database is. + ClusterName string `json:"cluster_name"` +} + +// ToolName generates a database access tool name. +func ToolName(protocol, name string) string { + return ToolPrefix + protocol + "_" + name +} + +// ToolPrefix is the default tool prefix for every MCP tool. +const ToolPrefix = "teleport_" diff --git a/lib/client/db/mcp/server.go b/lib/client/db/mcp/server.go new file mode 100644 index 0000000000000..b26e8bd0eaabe --- /dev/null +++ b/lib/client/db/mcp/server.go @@ -0,0 +1,154 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package mcp + +import ( + "context" + "fmt" + "io" + "log/slog" + "sync" + + "github.com/ghodss/yaml" + "github.com/gravitational/trace" + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + + "github.com/gravitational/teleport" +) + +// listDatabasesTool is the MCP tool that list all databases being served +// (from all protocols). +var listDatabasesTool = mcp.NewTool(listDatabasesToolName, + mcp.WithDescription("List database resources available to be used with Teleport tools."), +) + +// RootServer database access root MCP server. It includes common MCP tools and +// resources across different databases and serves as a root server where +// database-specific MCP servers register their tools. +type RootServer struct { + *mcpserver.MCPServer + + mu sync.RWMutex + logger *slog.Logger + availableDatabases map[string]*Database +} + +// NewRootServer initializes a new root MCP server. +func NewRootServer(logger *slog.Logger) *RootServer { + server := &RootServer{ + MCPServer: mcpserver.NewMCPServer(serverName, teleport.Version), + logger: logger, + availableDatabases: make(map[string]*Database), + } + server.AddTool(listDatabasesTool, server.ListDatabases) + + return server +} + +// ListDatabases tool function used to list all available/served databases. +func (s *RootServer) ListDatabases(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if len(s.availableDatabases) == 0 { + return mcp.NewToolResultError(EmptyDatabasesListErrorMessage), nil + } + + var res []mcp.Content + for _, db := range s.availableDatabases { + contents, err := encodeDatabaseResource(db) + if err != nil { + s.logger.ErrorContext(ctx, "error while list databases", "error", err) + return mcp.NewToolResultError(FormatErrorMessage(nil, err).Error()), nil + } + res = append(res, mcp.EmbeddedResource{Type: "resource", Resource: contents}) + } + + return &mcp.CallToolResult{ + Content: res, + }, nil +} + +// GetDatabaseResource resource handler for databases. +func (s *RootServer) GetDatabaseResource(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + db, ok := s.availableDatabases[request.Params.URI] + if !ok { + return nil, trace.NotFound("Database is %q not available as MCP resource", request.Params.URI) + } + + encodedDb, err := encodeDatabaseResource(db) + if err != nil { + return nil, trace.Wrap(err) + } + + return []mcp.ResourceContents{encodedDb}, nil +} + +// RegisterDatabase register a database on the root server. This make it +// available as a MCP resource. +// +// TODO(gabrielcorado): support dynamically registering/deregistering databases +// after the server starts. +func (s *RootServer) RegisterDatabase(db *Database) { + s.mu.Lock() + defer s.mu.Unlock() + + uri := db.ResourceURI().String() + s.availableDatabases[uri] = db + s.AddResource(mcp.NewResource(uri, fmt.Sprintf("%s Datatabase", db.DB.GetName()), mcp.WithMIMEType(databaseResourceMIMEType)), s.GetDatabaseResource) +} + +// ServeStdio starts serving the root MCP using STDIO transport. +func (s *RootServer) ServeStdio(ctx context.Context, in io.Reader, out io.Writer) error { + return trace.Wrap(mcpserver.NewStdioServer(s.MCPServer).Listen(ctx, in, out)) +} + +func buildDatabaseResource(db *Database) DatabaseResource { + return DatabaseResource{ + Metadata: db.DB.GetMetadata(), + URI: db.ResourceURI().String(), + Protocol: db.DB.GetProtocol(), + ClusterName: db.ClusterName, + } +} + +func encodeDatabaseResource(db *Database) (mcp.ResourceContents, error) { + resource := buildDatabaseResource(db) + out, err := yaml.Marshal(resource) + if err != nil { + return nil, trace.Wrap(err) + } + + return mcp.TextResourceContents{ + URI: resource.URI, + MIMEType: databaseResourceMIMEType, + Text: string(out), + }, nil +} + +const ( + // serverName is the database MCP server name. + serverName = "teleport_databases" + // listDatabasesTool is the list databases tool name. + listDatabasesToolName = ToolPrefix + "list_databases" + // databaseResourceMIMEType is the MIME type of the database resources. + databaseResourceMIMEType = "application/yaml" +) diff --git a/lib/client/db/mcp/server_test.go b/lib/client/db/mcp/server_test.go new file mode 100644 index 0000000000000..007067d3507ea --- /dev/null +++ b/lib/client/db/mcp/server_test.go @@ -0,0 +1,176 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package mcp + +import ( + "log/slog" + "slices" + "strings" + "testing" + + "github.com/ghodss/yaml" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" +) + +func TestRegisterDatabase(t *testing.T) { + server := NewRootServer(slog.New(slog.DiscardHandler)) + databases := []*Database{ + buildDatabase(t, "first"), + buildDatabase(t, "second"), + buildDatabase(t, "third"), + } + // sort databases by name to ensure the same order every test. + slices.SortFunc(databases, func(a, b *Database) int { + return strings.Compare(a.ResourceURI().String(), b.ResourceURI().String()) + }) + + for _, db := range databases { + server.RegisterDatabase(db) + } + + clt := buildClient(t, server) + t.Run("Resources", func(t *testing.T) { + listResult, err := clt.ListResources(t.Context(), mcp.ListResourcesRequest{}) + require.NoError(t, err) + require.Len(t, listResult.Resources, len(databases)) + + // sort the result using the same field as databases to avoid flaky + // test. + slices.SortFunc(listResult.Resources, func(a, b mcp.Resource) int { + return strings.Compare(a.URI, b.URI) + }) + + for i, r := range listResult.Resources { + req := mcp.ReadResourceRequest{} + req.Params.URI = r.URI + readResult, err := clt.ReadResource(t.Context(), req) + require.NoError(t, err) + require.Len(t, readResult.Contents, 1) + assertDatabaseResource(t, databases[i], readResult.Contents[0]) + } + }) + + t.Run("Tool", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Name = listDatabasesToolName + res, err := clt.CallTool(t.Context(), req) + require.NoError(t, err) + require.False(t, res.IsError) + require.Len(t, res.Content, len(databases)) + + for _, c := range res.Content { + require.IsType(t, mcp.EmbeddedResource{}, c) + require.IsType(t, mcp.TextResourceContents{}, c.(mcp.EmbeddedResource).Resource) + } + + // Although we're not sorting by the URI directly, the only field that + // is different across the databases is their name and URI (which would + // cause them to have the same order). So here we sort by the YAML + // contents to avoid having to decode. + slices.SortFunc(res.Content, func(a, b mcp.Content) int { + resourceA := a.(mcp.EmbeddedResource).Resource.(mcp.TextResourceContents) + resourceB := b.(mcp.EmbeddedResource).Resource.(mcp.TextResourceContents) + return strings.Compare(resourceA.Text, resourceB.Text) + }) + + for i, c := range res.Content { + content := c.(mcp.EmbeddedResource) + assertDatabaseResource(t, databases[i], content.Resource) + } + }) +} + +func TestEmptyDatabasesServer(t *testing.T) { + server := NewRootServer(slog.New(slog.DiscardHandler)) + + clt := buildClient(t, server) + t.Run("Resources", func(t *testing.T) { + _, err := clt.ListResources(t.Context(), mcp.ListResourcesRequest{}) + require.Error(t, err) + }) + + t.Run("Tool", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Name = listDatabasesToolName + res, err := clt.CallTool(t.Context(), req) + require.NoError(t, err) + require.True(t, res.IsError) + + require.Len(t, res.Content, 1) + content := res.Content[0] + require.IsType(t, mcp.TextContent{}, content) + textError := content.(mcp.TextContent).Text + require.Contains(t, textError, EmptyDatabasesListErrorMessage, "expected empty databases error but got: %s", textError) + }) +} + +func assertDatabaseResource(t *testing.T, db *Database, resource mcp.ResourceContents) { + t.Helper() + require.IsType(t, mcp.TextResourceContents{}, resource) + contents := resource.(mcp.TextResourceContents) + var database DatabaseResource + require.Equal(t, databaseResourceMIMEType, contents.MIMEType) + require.NoError(t, yaml.Unmarshal([]byte(contents.Text), &database)) + require.Empty(t, cmp.Diff(buildDatabaseResource(db), database, cmpopts.IgnoreFields(types.Metadata{}, "Namespace"))) +} + +func buildDatabase(t *testing.T, name string) *Database { + t.Helper() + + db, err := types.NewDatabaseV3(types.Metadata{ + Name: name, + Labels: map[string]string{"env": "test"}, + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + }) + require.NoError(t, err) + + return &Database{ + DB: db, + ClusterName: "root", + Addr: "localhost:5555", + } +} + +func buildClient(t *testing.T, server *RootServer) *mcpclient.Client { + t.Helper() + + clt, err := mcpclient.NewInProcessClient(server.MCPServer) + require.NoError(t, err) + t.Cleanup(func() { clt.Close() }) + require.NoError(t, clt.Start(t.Context())) + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = clt.Initialize(t.Context(), initRequest) + require.NoError(t, err) + require.NoError(t, clt.Ping(t.Context())) + return clt +} diff --git a/lib/client/db/postgres/mcp/mcp.go b/lib/client/db/postgres/mcp/mcp.go new file mode 100644 index 0000000000000..6a6b8a19dc10b --- /dev/null +++ b/lib/client/db/postgres/mcp/mcp.go @@ -0,0 +1,259 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package mcp + +import ( + "context" + "encoding/json" + "log/slog" + "net" + "time" + + "github.com/gravitational/trace" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/mark3labs/mcp-go/mcp" + + dbmcp "github.com/gravitational/teleport/lib/client/db/mcp" + clientmcp "github.com/gravitational/teleport/lib/client/mcp" + "github.com/gravitational/teleport/lib/defaults" +) + +// queryTool is the run query MCP tool definition. +var queryTool = mcp.NewTool(dbmcp.ToolName(defaults.ProtocolPostgres, "query"), + mcp.WithDescription("Execute SQL query against PostgreSQL database connected using Teleport"), + mcp.WithString(queryToolDatabaseParam, + mcp.Required(), + mcp.Description("Teleport database resource URI where the query will be executed"), + ), + mcp.WithString(queryToolQueryParam, + mcp.Required(), + mcp.Description("PostgresSQL SQL query to execute"), + ), +) + +type database struct { + *dbmcp.Database + pool *pgxpool.Pool +} + +// Server handles PostgreSQL-specific MCP tools requests. +type Server struct { + logger *slog.Logger + databases map[string]*database +} + +// NewServer initializes a PostgreSQL MCP server, creating the database +// configurations and registering Server tools into the root server. +func NewServer(ctx context.Context, cfg *dbmcp.NewServerConfig) (dbmcp.Server, error) { + s := &Server{logger: cfg.Logger, databases: make(map[string]*database)} + + for _, db := range cfg.Databases { + if db.DatabaseUser == "" || db.DatabaseName == "" { + return nil, trace.BadParameter("database %q is missing the username and database name", db.DB.GetName()) + } + + connCfg, err := buildConnConfig(db) + if err != nil { + return nil, trace.Wrap(err) + } + + pool, err := pgxpool.NewWithConfig(ctx, connCfg) + if err != nil { + return nil, trace.BadParameter("failed to parse database %q connection config: %s", db.DB.GetName(), err) + } + + s.databases[db.ResourceURI().String()] = &database{ + Database: db, + pool: pool, + } + } + + cfg.RootServer.AddTool(queryTool, s.RunQuery) + return s, nil +} + +// Close implements dbmcp.Server. +func (s *Server) Close(context.Context) error { + for _, db := range s.databases { + db.pool.Close() + } + + return nil +} + +// RunQueryResult is the run query tool result. +type RunQueryResult struct { + // Data contains the data returned from the query. It can be empty in case + // the query doesn't return any data. + Data []map[string]any `json:"data"` + // RowsCount number of rows affected by the query or returned as data. + RowsCount int `json:"rowsCount"` + // ErrorMessage if the query wasn't successful, this field contains the + // error message. + ErrorMessage string `json:"error,omitempty"` +} + +// RunQuery tool function used to execute queries on databases. +func (s *Server) RunQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + uri, err := request.RequireString(queryToolDatabaseParam) + if err != nil { + return s.wrapErrorResult(ctx, nil, trace.Wrap(err)) + } + + sql, err := request.RequireString(queryToolQueryParam) + if err != nil { + return s.wrapErrorResult(ctx, nil, trace.Wrap(err)) + } + + db, err := s.getDatabase(uri) + if err != nil { + return s.wrapErrorResult(ctx, nil, err) + } + + // TODO(gabrielcorado): ensure the connection used is consistent for the + // session, making most of its queries to be present in a single audit + // session/recording. + rows, err := db.pool.Query(ctx, sql) + if err != nil { + return s.wrapErrorResult(ctx, db.ExternalErrorRetriever, err) + } + + // Returned rows are being closed by this function. + result, err := buildQueryResult(rows) + if err != nil { + return s.wrapErrorResult(ctx, db.ExternalErrorRetriever, err) + } + + return mcp.NewToolResultText(result), nil +} + +func (s *Server) wrapErrorResult(ctx context.Context, externalRetriever dbmcp.ExternalErrorRetriever, toolErr error) (*mcp.CallToolResult, error) { + s.logger.ErrorContext(ctx, "error while querying database", "error", toolErr) + out, err := json.Marshal(RunQueryResult{ErrorMessage: dbmcp.FormatErrorMessage(externalRetriever, toolErr).Error()}) + return mcp.NewToolResultError(string(out)), trace.Wrap(err) +} + +// buildQueryResult takes a the response from pgx and converts into a JSON +// format (which will be returned to LLMs). +func buildQueryResult(rows pgx.Rows) (string, error) { + // Just ensure the rows is always closed. It is safe if this is called + // multiple times. + defer rows.Close() + + var data []map[string]any + columns := rows.FieldDescriptions() + + for rows.Next() { + values, err := rows.Values() + if err != nil { + return "", trace.Wrap(err) + } + + item := make(map[string]any, len(values)) + for i, v := range values { + item[columns[i].Name] = v + } + + data = append(data, item) + } + + // Close the rows to finish consuming it. Depending on the its type + // we can only collect the command tag after rows is closed. + rows.Close() + commandTag := rows.CommandTag() + + // Initialize the slice so the resulting JSON will have an empty array + // instead of null. + if len(data) == 0 && commandTag.Select() { + data = []map[string]any{} + } + + out, err := json.Marshal(RunQueryResult{ + Data: data, + RowsCount: int(commandTag.RowsAffected()), + }) + return string(out), trace.Wrap(err) +} + +func (s *Server) getDatabase(uri string) (*database, error) { + if !clientmcp.IsDatabaseResourceURI(uri) { + return nil, dbmcp.WrongDatabaseURIFormatError + } + + db, ok := s.databases[uri] + if !ok { + return nil, dbmcp.DatabaseNotFoundError + } + + return db, nil +} + +func buildConnConfig(db *dbmcp.Database) (*pgxpool.Config, error) { + config, err := pgxpool.ParseConfig("postgres://" + db.Addr) + if err != nil { + return nil, trace.Wrap(err) + } + + config.MaxConnIdleTime = connectionIdleTime + config.MaxConns = int32(maxConnections) + + config.ConnConfig.LookupFunc = func(ctx context.Context, host string) ([]string, error) { + return db.LookupFunc(ctx, host) + } + config.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { + return db.DialContextFunc(ctx, network, addr) + } + + config.ConnConfig.User = db.DatabaseUser + config.ConnConfig.Database = db.DatabaseName + config.ConnConfig.ConnectTimeout = defaults.DatabaseConnectTimeout + config.ConnConfig.RuntimeParams = map[string]string{ + applicationNameParamName: applicationNameParamValue, + } + config.ConnConfig.TLSConfig = nil + // Use simple protocol to have a closer behavior to DB REPL and psql. + // + // This also avoids each query being prepared, binded and executed, reducing + // the amount of audit events per query executed. + config.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol + return config, nil +} + +const ( + // queryToolDatabaseParam is the name of the database URI param name from + // query tool. + queryToolDatabaseParam = "database" + // queryToolQueryParam is the name of the query param name from query tool. + queryToolQueryParam = "query" + + // applicationNameParamName defines the application name parameter name. + // + // https://www.postgresql.org/docs/17/libpq-connect.html#LIBPQ-CONNECT-APPLICATION-NAME + applicationNameParamName = "application_name" + // applicationNameParamValue defines the application name parameter value. + applicationNameParamValue = "teleport-mcp" + // connectionIdleTime is the max connection idle time before it gets closed + // automatically. + connectionIdleTime = 1 * time.Minute + // maxConnections defines the max number of concurrent connections the pool + // can have. + // + // Given the current MCP usage, the clients will most likely do one query at + // time, even on multiple sessions. + maxConnections = 1 +) diff --git a/lib/client/db/postgres/mcp/mcp_test.go b/lib/client/db/postgres/mcp/mcp_test.go new file mode 100644 index 0000000000000..efdf69f5804e6 --- /dev/null +++ b/lib/client/db/postgres/mcp/mcp_test.go @@ -0,0 +1,234 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package mcp + +import ( + "context" + "encoding/json" + "log/slog" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/types" + dbmcp "github.com/gravitational/teleport/lib/client/db/mcp" + clientmcp "github.com/gravitational/teleport/lib/client/mcp" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils/listener" +) + +func TestFormatResult(t *testing.T) { + for name, tc := range map[string]struct { + rows pgx.Rows + expectedResult string + }{ + "query results": { + rows: newMockRows("SELECT 2", []string{"name", "age"}, [][]any{{"Alice", 30}, {"Bob", 31}}), + expectedResult: `{"data":[{"age":30,"name":"Alice"},{"age":31,"name":"Bob"}],"rowsCount":2}`, + }, + "empty query results": { + rows: newMockRows("SELECT 0", []string{}, [][]any{}), + expectedResult: `{"data":[],"rowsCount":0}`, + }, + "non-data results": { + rows: newMockRows("INSERT 1", []string{}, [][]any{}), + expectedResult: `{"data":null,"rowsCount":1}`, + }, + } { + t.Run(name, func(t *testing.T) { + res, err := buildQueryResult(tc.rows) + require.NoError(t, err) + require.Equal(t, tc.expectedResult, res) + }) + } +} + +func TestFormatErrors(t *testing.T) { + // Dummy listener that always drop connections. + listener := listener.NewInMemoryListener() + t.Cleanup(func() { listener.Close() }) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + _ = conn.Close() + } + }() + + dbName := "local" + db, err := types.NewDatabaseV3(types.Metadata{ + Name: dbName, + Labels: map[string]string{"env": "test"}, + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + URI: "localhost:5432", + }) + require.NoError(t, err) + dbURI := clientmcp.NewDatabaseResourceURI("root", dbName).String() + + for name, tc := range map[string]struct { + databaseURI string + databases []*dbmcp.Database + externalErrorRetriever dbmcp.ExternalErrorRetriever + expectErrorMessage require.ValueAssertionFunc + }{ + "database not found": { + databaseURI: "teleport://clusters/root/databases/not-found", + expectErrorMessage: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + require.Equal(t, dbmcp.DatabaseNotFoundError.Error(), i1) + }, + }, + "malformed database uri": { + databaseURI: "not-found", + expectErrorMessage: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + require.Equal(t, dbmcp.WrongDatabaseURIFormatError.Error(), i1) + }, + }, + "local proxy rejects connection": { + databaseURI: dbURI, + databases: []*dbmcp.Database{ + &dbmcp.Database{ + DB: db, + ClusterName: "root", + DatabaseUser: "postgres", + DatabaseName: "postgres", + Addr: listener.Addr().String(), + LookupFunc: func(_ context.Context, _ string) (addrs []string, err error) { + return []string{"memory"}, nil + }, + DialContextFunc: listener.DialContext, + }, + }, + expectErrorMessage: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + require.Equal(t, dbmcp.LocalProxyConnectionErrorMessage, i1) + }, + }, + "relogin error": { + databaseURI: dbURI, + databases: []*dbmcp.Database{ + &dbmcp.Database{ + DB: db, + ClusterName: "root", + DatabaseUser: "postgres", + DatabaseName: "postgres", + Addr: listener.Addr().String(), + ExternalErrorRetriever: &mockErrorRetriever{err: client.ErrClientCredentialsHaveExpired}, + LookupFunc: func(_ context.Context, _ string) (addrs []string, err error) { + return []string{"memory"}, nil + }, + DialContextFunc: listener.DialContext, + }, + }, + expectErrorMessage: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) { + require.Equal(t, dbmcp.ReloginRequiredErrorMessage, i1) + }, + }, + } { + t.Run(name, func(t *testing.T) { + logger := slog.New(slog.DiscardHandler) + rootServer := dbmcp.NewRootServer(logger) + srv, err := NewServer(t.Context(), &dbmcp.NewServerConfig{ + Logger: logger, + RootServer: rootServer, + Databases: tc.databases, + }) + require.NoError(t, err) + t.Cleanup(func() { srv.Close(t.Context()) }) + + pgSrv := srv.(*Server) + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]any{ + queryToolDatabaseParam: tc.databaseURI, + queryToolQueryParam: "SELECT 1", + } + runResult, err := pgSrv.RunQuery(t.Context(), req) + require.NoError(t, err) + + require.True(t, runResult.IsError) + require.Len(t, runResult.Content, 1) + require.IsType(t, mcp.TextContent{}, runResult.Content[0]) + + content := runResult.Content[0].(mcp.TextContent) + var res RunQueryResult + require.NoError(t, json.Unmarshal([]byte(content.Text), &res), "expected result to be in JSON format") + require.Empty(t, res.Data) + tc.expectErrorMessage(t, res.ErrorMessage) + }) + } +} + +func newMockRows(commandTag string, fields []string, rows [][]any) pgx.Rows { + var fds []pgconn.FieldDescription + for _, fieldName := range fields { + fds = append(fds, pgconn.FieldDescription{Name: fieldName}) + } + return &mockRows{ + commandTag: commandTag, + descriptions: fds, + rows: rows, + } +} + +type mockRows struct { + pgx.Rows + + started bool + cursor int + + commandTag string + descriptions []pgconn.FieldDescription + rows [][]any +} + +func (mr *mockRows) FieldDescriptions() []pgconn.FieldDescription { + return mr.descriptions +} + +func (mr *mockRows) Next() bool { + if !mr.started { + mr.started = true + return len(mr.rows) > 0 + } + + mr.cursor += 1 + return len(mr.rows) > mr.cursor +} + +func (mr *mockRows) Values() ([]any, error) { + return mr.rows[mr.cursor], nil +} + +func (mr *mockRows) CommandTag() pgconn.CommandTag { + return pgconn.NewCommandTag(mr.commandTag) +} + +func (mr *mockRows) Close() {} + +type mockErrorRetriever struct { + err error +} + +func (mr *mockErrorRetriever) RetrieveError() error { + return mr.err +} diff --git a/lib/client/local_proxy_middleware.go b/lib/client/local_proxy_middleware.go index 1f6f7bf9622c1..9f4fd19670d67 100644 --- a/lib/client/local_proxy_middleware.go +++ b/lib/client/local_proxy_middleware.go @@ -52,6 +52,9 @@ type CertChecker struct { cert tls.Certificate certMu sync.Mutex + + err error + errMu sync.Mutex } var _ alpnproxy.LocalProxyMiddleware = (*CertChecker)(nil) @@ -142,15 +145,19 @@ func (c *CertChecker) SetCert(cert tls.Certificate) { // GetOrIssueCert gets the CertChecker's certificate, or issues a new // certificate if the it is invalid (e.g. expired) or missing. -func (c *CertChecker) GetOrIssueCert(ctx context.Context) (tls.Certificate, error) { +func (c *CertChecker) GetOrIssueCert(ctx context.Context) (cert tls.Certificate, err error) { c.certMu.Lock() defer c.certMu.Unlock() + defer func() { + c.setError(err) + }() + if err := c.checkCert(); err == nil { return c.cert, nil } - cert, err := c.certIssuer.IssueCert(ctx) + cert, err = c.certIssuer.IssueCert(ctx) if err != nil { return tls.Certificate{}, trace.Wrap(err) } @@ -170,6 +177,13 @@ func (c *CertChecker) GetOrIssueCert(ctx context.Context) (tls.Certificate, erro return c.cert, nil } +// RetrieveError retrieves the happened on while retrieving certificates. +func (c *CertChecker) RetrieveError() error { + c.errMu.Lock() + defer c.errMu.Unlock() + return c.err +} + func (c *CertChecker) checkCert() error { leaf, err := utils.TLSCertLeaf(c.cert) if err != nil { @@ -184,6 +198,12 @@ func (c *CertChecker) checkCert() error { return trace.Wrap(c.certIssuer.CheckCert(leaf)) } +func (c *CertChecker) setError(err error) { + c.errMu.Lock() + defer c.errMu.Unlock() + c.err = err +} + // CertIssuer checks and issues certs. type CertIssuer interface { // CheckCert checks that an existing certificate is valid. diff --git a/lib/client/local_proxy_middleware_test.go b/lib/client/local_proxy_middleware_test.go index cfb79e335601e..48d99c859c39d 100644 --- a/lib/client/local_proxy_middleware_test.go +++ b/lib/client/local_proxy_middleware_test.go @@ -44,10 +44,12 @@ func TestCertChecker(t *testing.T) { // certChecker should issue a new cert on first request. cert, err := certChecker.GetOrIssueCert(ctx) require.NoError(t, err) + require.NoError(t, certChecker.RetrieveError()) // subsequent calls should return the same cert. sameCert, err := certChecker.GetOrIssueCert(ctx) require.NoError(t, err) + require.NoError(t, certChecker.RetrieveError()) require.Equal(t, cert, sameCert) // If the current cert expires it should be reissued. @@ -56,6 +58,7 @@ func TestCertChecker(t *testing.T) { cert, err = certChecker.GetOrIssueCert(ctx) require.NoError(t, err) + require.NoError(t, certChecker.RetrieveError()) require.NotEqual(t, cert, expiredCert) // If the current cert fails certIssuer checks, a new one should be issued. @@ -64,12 +67,20 @@ func TestCertChecker(t *testing.T) { cert, err = certChecker.GetOrIssueCert(ctx) require.NoError(t, err) + require.NoError(t, certChecker.RetrieveError()) require.NotEqual(t, cert, badCert) // If issuing a new cert fails, an error is returned. certIssuer.issueErr = trace.BadParameter("failed to issue cert") _, err = certChecker.GetOrIssueCert(ctx) require.ErrorIs(t, err, certIssuer.issueErr, "expected error %v but got %v", certIssuer.issueErr, err) + require.ErrorIs(t, certChecker.RetrieveError(), err, "expected retrieve error to be the same get error but got: %v", certChecker.RetrieveError()) + + // If the problem is solved, the error is clean up. + certIssuer.issueErr = nil + _, err = certChecker.GetOrIssueCert(ctx) + require.NoError(t, err) + require.NoError(t, certChecker.RetrieveError()) } func TestLocalCertGenerator(t *testing.T) { diff --git a/lib/client/mcp/uri.go b/lib/client/mcp/uri.go new file mode 100644 index 0000000000000..a932a67c776ec --- /dev/null +++ b/lib/client/mcp/uri.go @@ -0,0 +1,153 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package mcp + +import ( + "net/url" + "strings" + + "github.com/gravitational/trace" + "github.com/ucarion/urlpath" +) + +var ( + // clusterURITemplate is the base cluster template. + clusterURITemplate = urlpath.New("/clusters/:cluster/*") + // databaseURITemplate template used to parse database resource URIs. + databaseURITemplate = urlpath.New("/clusters/:cluster/databases/:dbName") +) + +const ( + // resourceScheme scheme used by Teleport MCP resources. + resourceScheme = "teleport" + + // databaseNameQueryParamName is the query param name used for database + // name. + databaseNameQueryParamName = "dbName" + // databaseUserQueryParamName is the query param name used for database + // user. + databaseUserQueryParamName = "dbUser" +) + +// ResourceURI is a Teleport MCP resource URI. +// +// Query parameters are not covered on the MCP spec but we use them to provide +// additional information about the resource connection. For example, if the +// resource requires a "username" value, this value is provided using the query +// params. +// +// https://modelcontextprotocol.io/docs/concepts/resources#resource-uris +type ResourceURI struct { + url url.URL +} + +// ParseResourceURI parses a raw resource URI into a Teleport URI. +func ParseResourceURI(uri string) (*ResourceURI, error) { + parsedURL, err := url.Parse(uri) + if err != nil { + return nil, trace.BadParameter("invalid resource URI format: %s", err) + } + + if parsedURL.Scheme != resourceScheme { + return nil, trace.BadParameter("invalid URI scheme, must be %q", resourceScheme) + } + + return &ResourceURI{url: *parsedURL}, nil +} + +// NewDatabaseResourceURI creates a new database resource URI. +func NewDatabaseResourceURI(cluster, databaseName string) ResourceURI { + pathWithHost, _ := databaseURITemplate.Build(urlpath.Match{ + Params: map[string]string{ + "cluster": cluster, + "dbName": databaseName, + }, + }) + + return ResourceURI{ + url: url.URL{ + Scheme: resourceScheme, + Path: strings.TrimPrefix(pathWithHost, "/"), + }, + } +} + +// GetDatabaseServiceName returns the Teleport cluster name. +func (u ResourceURI) GetClusterName() string { + if match, ok := clusterURITemplate.Match(u.path()); ok { + return match.Params["cluster"] + } + + return "" +} + +// GetDatabaseServiceName returns the database service name of the resource. +// Returns empty if the resource is not a database. +func (u ResourceURI) GetDatabaseServiceName() string { + if match, ok := databaseURITemplate.Match(u.path()); ok { + return match.Params["dbName"] + } + + return "" +} + +// GetDatabaseUser returns the database username param of the resource. +// Returns empty if the resource is not a database. +func (u ResourceURI) GetDatabaseUser() string { + return u.url.Query().Get(databaseUserQueryParamName) +} + +// GetDatabaseName returns the database name param of the resource. +// Returns empty if the resource is not a database. +func (u ResourceURI) GetDatabaseName() string { + return u.url.Query().Get(databaseNameQueryParamName) +} + +// IsDatabase returns true if the resource is a database. +func (u ResourceURI) IsDatabase() bool { + return u.GetDatabaseServiceName() != "" +} + +// String returns the string representation of the resource URI (excluding the +// query params). +func (u ResourceURI) String() string { + c := u.url + c.RawQuery = "" + return c.String() +} + +// path returns the resource URI full path. We must include the hostname as the +// templates will also include them on the matching. +func (u ResourceURI) path() string { + return "/" + u.url.Hostname() + u.url.Path +} + +// IsDatabase returns true if the URI is a database resource. +func IsDatabaseResourceURI(uri string) bool { + parsed, err := ParseResourceURI(uri) + if err != nil { + return false + } + + return parsed.IsDatabase() +} + +var ( + // SampleDatabaseResource contains a sample full resource URI. This can be + // used on descriptions to show how a database resource URI looks like. + SampleDatabaseResource = NewDatabaseResourceURI("example-cluster", "myDatabase") +) diff --git a/lib/client/mcp/uri_test.go b/lib/client/mcp/uri_test.go new file mode 100644 index 0000000000000..30cdc64a2a357 --- /dev/null +++ b/lib/client/mcp/uri_test.go @@ -0,0 +1,94 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package mcp + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDatabaseResourceURI(t *testing.T) { + for name, tc := range map[string]struct { + uri string + expectError bool + expectedDatabase bool + expectedServiceName string + expectedDatabaseName string + expectedDatabaseUser string + expectedClusterName string + }{ + "valid database": { + uri: "teleport://clusters/default/databases/pg?dbName=database&dbUser=user", + expectedDatabase: true, + expectedServiceName: "pg", + expectedDatabaseName: "database", + expectedDatabaseUser: "user", + expectedClusterName: "default", + }, + "valid database without params": { + uri: "teleport://clusters/default/databases/pg", + expectedDatabase: true, + expectedServiceName: "pg", + expectedDatabaseName: "", + expectedDatabaseUser: "", + expectedClusterName: "default", + }, + "random resource": { + uri: "teleport://clusters/default/random/random-resource", + expectedDatabase: false, + expectedServiceName: "", + expectedDatabaseName: "", + expectedDatabaseUser: "", + expectedClusterName: "default", + }, + "generated uri": { + uri: NewDatabaseResourceURI("default", "db").String(), + expectedDatabase: true, + expectedServiceName: "db", + expectedDatabaseName: "", + expectedDatabaseUser: "", + expectedClusterName: "default", + }, + "invalid schema": { + uri: "http://databases/database", + expectError: true, + }, + "invalid uri": { + uri: "random-value", + expectError: true, + }, + } { + t.Run(name, func(t *testing.T) { + uri, err := ParseResourceURI(tc.uri) + if tc.expectError { + require.Error(t, err) + return + } + + require.NotNil(t, uri) + fmt.Println(tc.uri) + require.Equal(t, tc.expectedDatabase, IsDatabaseResourceURI(tc.uri)) + require.Equal(t, tc.expectedDatabase, uri.IsDatabase()) + require.Equal(t, tc.expectedServiceName, uri.GetDatabaseServiceName()) + require.Equal(t, tc.expectedDatabaseName, uri.GetDatabaseName()) + require.Equal(t, tc.expectedDatabaseUser, uri.GetDatabaseUser()) + require.Equal(t, tc.expectedClusterName, uri.GetClusterName()) + }) + } +} diff --git a/lib/utils/cli.go b/lib/utils/cli.go index a025a301a350b..6a8ca6067ead6 100644 --- a/lib/utils/cli.go +++ b/lib/utils/cli.go @@ -54,6 +54,8 @@ const ( LoggingForDaemon LoggingPurpose = iota // LoggingForCLI configures logging for user face utilities (tctl, tsh). LoggingForCLI + // LoggingForMCP configures logging for MCP servers. + LoggingForMCP ) // LoggingFormat defines the possible logging output formats. @@ -100,7 +102,7 @@ func IsTerminal(w io.Writer) bool { } // InitLogger configures the global logger for a given purpose / verbosity level -func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption) error { +func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption) (*slog.Logger, error) { var o logOpts for _, opt := range opts { @@ -110,23 +112,28 @@ func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption) // If debug or trace logging is not enabled for CLIs, // then discard all log output. if purpose == LoggingForCLI && level > slog.LevelDebug { - slog.SetDefault(slog.New(slog.DiscardHandler)) - return nil + logger := slog.New(slog.DiscardHandler) + slog.SetDefault(logger) + return logger, nil } var output string - if o.osLogSubsystem != "" { + switch { + case o.osLogSubsystem != "": output = logutils.LogOutputOSLog + case purpose == LoggingForMCP: + output = logutils.LogOutputMCP + o.format = LogFormatJSON } - _, _, err := logutils.Initialize(logutils.Config{ + logger, _, err := logutils.Initialize(logutils.Config{ Severity: level.String(), Format: o.format, EnableColors: IsTerminal(os.Stderr), Output: output, OSLogSubsystem: o.osLogSubsystem, }) - return trace.Wrap(err) + return logger, trace.Wrap(err) } var initTestLoggerOnce = sync.Once{} diff --git a/lib/utils/listener/memory.go b/lib/utils/listener/memory.go new file mode 100644 index 0000000000000..5f7795a69952d --- /dev/null +++ b/lib/utils/listener/memory.go @@ -0,0 +1,109 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package listener + +import ( + "context" + "errors" + "io" + "net" + "sync" +) + +// InMemoryListener is a in-memory implementation of a net.Listener. +type InMemoryListener struct { + connCh chan net.Conn + closeCh chan struct{} + closeOnce sync.Once +} + +// Accept implements net.Listener. +func (m *InMemoryListener) Accept() (net.Conn, error) { + select { + case <-m.closeCh: + return nil, io.EOF + default: + } + + for { + select { + case conn := <-m.connCh: + return conn, nil + case <-m.closeCh: + return nil, io.EOF + } + } +} + +// Addr implements net.Listener. +func (m *InMemoryListener) Addr() net.Addr { + return defaultMemoryAddr +} + +// Close implements net.Listener. +func (m *InMemoryListener) Close() error { + m.closeOnce.Do(func() { close(m.closeCh) }) + return nil +} + +// DialContext dials the memory listener, creating a new net.Conn. +// +// This function satisfies net.Dialer.DialContext signature. +func (m *InMemoryListener) DialContext(ctx context.Context, _ string, _ string) (net.Conn, error) { + select { + case <-m.closeCh: + return nil, ErrListenerClosed + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + serverConn, clientConn := net.Pipe() + + select { + case m.connCh <- serverConn: + case <-ctx.Done(): + // In this case the connection was not accepted in time by the server + // and the dial context is done. To avoid having the server using an + // orphned connection we should close it. + _ = serverConn.Close() + return nil, ctx.Err() + } + return clientConn, nil +} + +// ErrListenerClosed is the error returned by dial when the listener is closed. +var ErrListenerClosed = errors.New("in-memory listener closed") + +// NewInMemoryListener initializes a new in-memory listener. +func NewInMemoryListener() *InMemoryListener { + return &InMemoryListener{ + connCh: make(chan net.Conn), + closeCh: make(chan struct{}), + } +} + +var _ net.Listener = (*InMemoryListener)(nil) + +type memoryAddr string + +func (m memoryAddr) Network() string { return string(m) } +func (m memoryAddr) String() string { return string(m) } + +var defaultMemoryAddr = memoryAddr("memory") + +var _ net.Addr = (*memoryAddr)(nil) diff --git a/lib/utils/listener/memory_test.go b/lib/utils/listener/memory_test.go new file mode 100644 index 0000000000000..024a96031de16 --- /dev/null +++ b/lib/utils/listener/memory_test.go @@ -0,0 +1,132 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package listener + +import ( + "context" + "io" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMemoryListenerClient(t *testing.T) { + var wg sync.WaitGroup + expectedMessage := "hello from server" + + listener := NewInMemoryListener() + t.Cleanup(func() { listener.Close() }) + + wg.Add(1) + go func() { + defer wg.Done() + + conn, err := listener.Accept() + if err != nil { + return + } + t.Cleanup(func() { conn.Close() }) + + _, _ = conn.Write([]byte(expectedMessage)) + }() + + // To avoid blocking in case the server is not working correctly, wrap + // the client connection into a eventually loop. + require.EventuallyWithT(t, func(collect *assert.CollectT) { + conn, err := listener.DialContext(t.Context(), "", "") + require.NoError(collect, err) + + buf := make([]byte, len(expectedMessage)) + n, err := conn.Read(buf[0:]) + require.NoError(collect, err) + require.Equal(collect, len(expectedMessage), n) + require.Equal(collect, expectedMessage, string(buf[:n])) + }, 50*time.Millisecond, 10*time.Millisecond) + + require.Eventually(t, func() bool { + wg.Wait() + return true + }, 50*time.Millisecond, 10*time.Millisecond) +} + +func TestMemoryListenerServer(t *testing.T) { + var wg sync.WaitGroup + expectedMessage := "hello from client" + + listener := NewInMemoryListener() + t.Cleanup(func() { listener.Close() }) + + wg.Add(1) + go func() { + defer wg.Done() + + conn, err := listener.DialContext(t.Context(), "", "") + if err != nil { + return + } + + _, _ = conn.Write([]byte(expectedMessage)) + }() + + // To avoid blocking in case the client is not working correctly, wrap + // the server accept connection into a eventually loop. + require.EventuallyWithT(t, func(collect *assert.CollectT) { + conn, err := listener.Accept() + require.NoError(collect, err) + + buf := make([]byte, len(expectedMessage)) + n, err := conn.Read(buf[0:]) + require.NoError(collect, err) + require.Equal(collect, len(expectedMessage), n) + require.Equal(collect, expectedMessage, string(buf[:n])) + }, 50*time.Millisecond, 10*time.Millisecond) + + // Close the listener and expect subsequent accept calls to return error. + listener.Close() + require.EventuallyWithT(t, func(collect *assert.CollectT) { + _, err := listener.Accept() + require.Error(collect, err) + require.ErrorIs(collect, err, io.EOF) + }, 50*time.Millisecond, 10*time.Millisecond) + + require.Eventually(t, func() bool { + wg.Wait() + return true + }, 50*time.Millisecond, 10*time.Millisecond) +} + +func TestMemoryListenerDialTimeout(t *testing.T) { + listener := NewInMemoryListener() + t.Cleanup(func() { listener.Close() }) + + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + _, err := listener.DialContext(ctx, "", "") + require.Error(collect, err) + require.ErrorIs(collect, err, context.DeadlineExceeded) + }, 100*time.Millisecond, 10*time.Millisecond) + + require.Never(t, func() bool { + _, _ = listener.Accept() + return true + }, 50*time.Millisecond, 10*time.Millisecond, "expected server to not have received connections") +} diff --git a/lib/utils/log/log.go b/lib/utils/log/log.go index 3e68086500e29..48f945c12162c 100644 --- a/lib/utils/log/log.go +++ b/lib/utils/log/log.go @@ -57,6 +57,10 @@ const ( LogOutputSyslog = "syslog" // LogOutputOSLog represents os_log, the unified logging system on macOS, as the destination for logs. LogOutputOSLog = "os_log" + // LogOutputMCP defines to where the MCP command logs will be directed to. + // The stdout is exclusively used as the MCP server transport, leaving only + // stderr available. + LogOutputMCP = "stderr" ) // Initialize configures the default global logger based on the diff --git a/tool/tsh/common/logger.go b/tool/tsh/common/logger.go index df3e97595d259..c3450a5acc51f 100644 --- a/tool/tsh/common/logger.go +++ b/tool/tsh/common/logger.go @@ -38,7 +38,7 @@ const ( // It is called twice, first soon after launching tsh before argv is parsed and then again after // kingpin parses argv. This makes it possible to debug early startup functionality, particularly // command aliases. -func initLogger(cf *CLIConf, opts loggingOpts) error { +func initLogger(cf *CLIConf, purpose utils.LoggingPurpose, opts loggingOpts) (*slog.Logger, error) { cf.OSLog = opts.osLog cf.Debug = opts.debug || opts.osLog @@ -49,7 +49,8 @@ func initLogger(cf *CLIConf, opts loggingOpts) error { level = slog.LevelDebug } - return trace.Wrap(utils.InitLogger(utils.LoggingForCLI, level, initLoggerOpts...)) + logger, err := utils.InitLogger(purpose, level, initLoggerOpts...) + return logger, trace.Wrap(err) } type loggingOpts struct { diff --git a/tool/tsh/common/mcp.go b/tool/tsh/common/mcp.go new file mode 100644 index 0000000000000..60f08d445233d --- /dev/null +++ b/tool/tsh/common/mcp.go @@ -0,0 +1,31 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import "github.com/alecthomas/kingpin/v2" + +type mcpCommands struct { + dbStart *mcpDBStartCommand +} + +func newMCPCommands(app *kingpin.Application) *mcpCommands { + mcp := app.Command("mcp", "View and control proxied MCP servers.") + db := mcp.Command("db", "Database access for MCP servers.") + return &mcpCommands{ + dbStart: newMCPDBCommand(db), + } +} diff --git a/tool/tsh/common/mcp_db.go b/tool/tsh/common/mcp_db.go new file mode 100644 index 0000000000000..699ae979c3ca5 --- /dev/null +++ b/tool/tsh/common/mcp_db.go @@ -0,0 +1,234 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "context" + "log/slog" + + "github.com/alecthomas/kingpin/v2" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/client" + dbmcp "github.com/gravitational/teleport/lib/client/db/mcp" + pgmcp "github.com/gravitational/teleport/lib/client/db/postgres/mcp" + "github.com/gravitational/teleport/lib/client/mcp" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/srv/alpnproxy" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/listener" +) + +// mcpDBStartCommand implements `tsh mcp db start` command. +type mcpDBStartCommand struct { + *kingpin.CmdClause + + databaseURIs []string +} + +func newMCPDBCommand(parent *kingpin.CmdClause) *mcpDBStartCommand { + cmd := &mcpDBStartCommand{ + CmdClause: parent.Command("start", "Start a local MCP server for database access").Hidden(), + } + + cmd.Arg("uris", "List of database MCP resource URIs that will be served by the server").Required().StringsVar(&cmd.databaseURIs) + return cmd +} + +func (c *mcpDBStartCommand) run(cf *CLIConf) error { + logger, err := initLogger(cf, utils.LoggingForMCP, parseLoggingOptsFromEnvAndArgv(cf)) + if err != nil { + return trace.Wrap(err) + } + + registry := defaultDBMCPRegistry + if cf.databaseMCPRegistryOverride != nil { + registry = cf.databaseMCPRegistryOverride + } + + tc, err := makeClient(cf) + if err != nil { + return trace.Wrap(err) + } + + // Avoid any input request on the command execution. This is required, + // otherwise the MCP clients will be stuck waiting for a response. + tc.NonInteractive = false + + configuredDatabases := map[string]struct{}{} + uris := make([]*mcp.ResourceURI, len(c.databaseURIs)) + for i, rawURI := range c.databaseURIs { + uri, err := mcp.ParseResourceURI(rawURI) + if err != nil { + return trace.Wrap(err) + } + + if !uri.IsDatabase() { + return trace.BadParameter("%q resource must be a database", rawURI) + } + + // TODO(gabrielcorado): support databases from different clusters. + if uri.GetClusterName() != tc.SiteName { + return trace.BadParameter("Databases must be from the same cluster (%q). %q is from a different cluster.", tc.SiteName, rawURI) + } + + if _, ok := configuredDatabases[uri.String()]; ok { + return trace.BadParameter("Database %q was configured twice. MCP servers only support serving a database service only once.", uri.String()) + } + + configuredDatabases[uri.String()] = struct{}{} + uris[i] = uri + } + + server := dbmcp.NewRootServer(logger) + allDatabases, closeLocalProxies, err := c.prepareDatabases(cf, tc, registry, uris, logger, server) + if err != nil { + return trace.Wrap(err) + } + defer closeLocalProxies() + + for protocol, newServerFunc := range registry { + databases := allDatabases[protocol] + if len(databases) == 0 { + continue + } + + srv, err := newServerFunc(cf.Context, &dbmcp.NewServerConfig{ + Logger: logger, + RootServer: server, + Databases: databases, + }) + if err != nil { + return trace.Wrap(err) + } + defer srv.Close(cf.Context) + } + + return trace.Wrap(server.ServeStdio(cf.Context, cf.Stdin(), cf.Stdout())) +} + +// closeLocalProxyFunc function used to close local proxy listeners. +type closeLocalProxyFunc func() error + +// prepareDatabases based on the available MCP servers, initialize the database +// local proxy and generate the MCP database. +func (c *mcpDBStartCommand) prepareDatabases( + cf *CLIConf, + tc *client.TeleportClient, + registry dbmcp.Registry, + uris []*mcp.ResourceURI, + logger *slog.Logger, + server *dbmcp.RootServer, +) (map[string][]*dbmcp.Database, closeLocalProxyFunc, error) { + var ( + ctx = cf.Context + dbsPerProtocol = make(map[string][]*dbmcp.Database) + closeFuncs []closeLocalProxyFunc + ) + + for _, uri := range uris { + serviceName := uri.GetDatabaseServiceName() + dbUser := uri.GetDatabaseUser() + dbName := uri.GetDatabaseName() + + route := tlsca.RouteToDatabase{ + ServiceName: serviceName, + Username: dbUser, + Database: dbName, + } + + info, err := getDatabaseInfo(cf, tc, []tlsca.RouteToDatabase{route}) + if err != nil { + logger.InfoContext(ctx, "failed to retrieve database information", "database", serviceName, "error", err) + continue + } + + db, err := info.GetDatabase(ctx, tc) + if err != nil { + logger.InfoContext(ctx, "failed to load database information", "database", serviceName, "error", err) + continue + } + + if !registry.IsSupported(db.GetProtocol()) { + logger.InfoContext(ctx, "database protocol unsupported, skipping it", "database", serviceName, "protocol", db.GetProtocol()) + continue + } + + route.Protocol = db.GetProtocol() + cc := client.NewDBCertChecker(tc, route, nil, client.WithTTL(tc.KeyTTL)) + // This avoids having the middleware to refresh the certificate if there + // is a certificate available on disk. + cert, err := loadDBCertificate(tc, route.ServiceName) + if err == nil { + cc.SetCert(cert) + } + + listener := listener.NewInMemoryListener() + lp, err := alpnproxy.NewLocalProxy( + makeBasicLocalProxyConfig(ctx, tc, listener, tc.InsecureSkipVerify), + alpnproxy.WithDatabaseProtocol(route.Protocol), + alpnproxy.WithMiddleware(cc), + alpnproxy.WithClusterCAsIfConnUpgrade(ctx, tc.RootClusterCACertPool), + ) + if err != nil { + _ = listener.Close() + logger.ErrorContext(ctx, "failed to start local proxy for database, skipping it", "database", db.GetName(), "error", err) + continue + } + go func() { + defer lp.Close() + if err = lp.Start(ctx); err != nil { + logger.WarnContext(ctx, "failed to start local ALPN proxy", "error", err) + } + }() + + mcpDB := &dbmcp.Database{ + DB: db, + ClusterName: uri.GetClusterName(), + DatabaseUser: dbUser, + DatabaseName: dbName, + Addr: listener.Addr().String(), + ExternalErrorRetriever: cc, + // Since we're using in-memory listener we don't need to resolve the + // address. + LookupFunc: func(ctx context.Context, host string) (addrs []string, err error) { + return []string{listener.Addr().String()}, nil + }, + DialContextFunc: listener.DialContext, + } + dbsPerProtocol[db.GetProtocol()] = append(dbsPerProtocol[db.GetProtocol()], mcpDB) + server.RegisterDatabase(mcpDB) + closeFuncs = append(closeFuncs, listener.Close) + } + + return dbsPerProtocol, func() error { + var errs []error + for _, closeFunc := range closeFuncs { + errs = append(errs, closeFunc()) + } + + return trace.NewAggregate(errs...) + }, nil +} + +var ( + // defaultDBMCPRegistry is the default database access MCP servers registry. + defaultDBMCPRegistry = map[string]dbmcp.NewServerFunc{ + defaults.ProtocolPostgres: pgmcp.NewServer, + } +) diff --git a/tool/tsh/common/mcp_db_test.go b/tool/tsh/common/mcp_db_test.go new file mode 100644 index 0000000000000..123657d212841 --- /dev/null +++ b/tool/tsh/common/mcp_db_test.go @@ -0,0 +1,223 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "context" + "io" + "testing" + "time" + + mcpclient "github.com/mark3labs/mcp-go/client" + mcptransport "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + dbmcp "github.com/gravitational/teleport/lib/client/db/mcp" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/service/servicecfg" + testserver "github.com/gravitational/teleport/tool/teleport/testenv" +) + +func TestMCPDBCommand(t *testing.T) { + tmpHomePath := t.TempDir() + connector := mockConnector(t) + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetDatabaseUsers([]string{"postgres"}) + alice.SetDatabaseNames([]string{"postgres"}) + alice.SetRoles([]string{"access"}) + + authProcess := testserver.MakeTestServer( + t, + testserver.WithClusterName(t, "root"), + testserver.WithBootstrap(connector, alice), + testserver.WithConfig(func(cfg *servicecfg.Config) { + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + { + Name: "postgres1", + Protocol: defaults.ProtocolPostgres, + URI: "external-pg:5432", + }, + { + Name: "postgres2", + Protocol: defaults.ProtocolPostgres, + URI: "external-pg:5432", + }, + { + Name: "mysql-local", + Protocol: defaults.ProtocolMySQL, + URI: "external-mysql:3306", + }, + } + }), + ) + + authServer := authProcess.GetAuthServer() + require.NotNil(t, authServer) + + proxyAddr, err := authProcess.ProxyWebAddr() + require.NoError(t, err) + + err = Run(t.Context(), []string{ + "login", "--insecure", "--debug", "--proxy", proxyAddr.String(), + }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, alice, connector.GetName())) + require.NoError(t, err) + + stdin, writer := io.Pipe() + reader, stdout := io.Pipe() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + executionCh := make(chan error) + go func() { + executionCh <- Run(ctx, []string{ + "mcp", + "db", + "start", + "teleport://clusters/root/databases/postgres1?dbUser=postgres&dbName=postgres", + "teleport://clusters/root/databases/postgres2?dbUser=postgres&dbName=postgres", + }, setHomePath(tmpHomePath), func(c *CLIConf) error { + c.overrideStdin = stdin + c.OverrideStdout = stdout + // MCP server logs are going to be discarded. + c.overrideStderr = io.Discard + c.databaseMCPRegistryOverride = map[string]dbmcp.NewServerFunc{ + defaults.ProtocolPostgres: func(ctx context.Context, nsc *dbmcp.NewServerConfig) (dbmcp.Server, error) { + return &testDatabaseMCP{}, nil + }, + } + return nil + }) + }() + + clt := mcpclient.NewClient(mcptransport.NewIO(reader, writer, nil /* logging */)) + require.NoError(t, clt.Start(t.Context())) + + req := mcp.InitializeRequest{} + req.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + req.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + _, err = clt.Initialize(t.Context(), req) + require.NoError(collect, err) + require.NoError(collect, clt.Ping(t.Context())) + }, time.Second, 100*time.Millisecond) + + // Stop the MCP server command and wait until it is finshed. + cancel() + select { + case err := <-executionCh: + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + case <-time.After(10 * time.Second): + require.Fail(t, "expected the execution to be completed") + } +} + +func TestMCPDBCommandFailures(t *testing.T) { + tmpHomePath := t.TempDir() + connector := mockConnector(t) + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetDatabaseUsers([]string{"postgres"}) + alice.SetDatabaseNames([]string{"postgres"}) + alice.SetRoles([]string{"access"}) + clusterName := "root" + + authProcess := testserver.MakeTestServer( + t, + testserver.WithClusterName(t, clusterName), + testserver.WithBootstrap(connector, alice), + testserver.WithConfig(func(cfg *servicecfg.Config) { + cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) + cfg.Databases.Enabled = true + cfg.Databases.Databases = []servicecfg.Database{ + { + Name: "postgres1", + Protocol: defaults.ProtocolPostgres, + URI: "external-pg:5432", + }, + { + Name: "postgres2", + Protocol: defaults.ProtocolPostgres, + URI: "external-pg:5432", + }, + { + Name: "mysql-local", + Protocol: defaults.ProtocolMySQL, + URI: "external-mysql:3306", + }, + } + }), + ) + + authServer := authProcess.GetAuthServer() + require.NotNil(t, authServer) + + proxyAddr, err := authProcess.ProxyWebAddr() + require.NoError(t, err) + + withMockedMCPServers := func(c *CLIConf) error { + c.databaseMCPRegistryOverride = map[string]dbmcp.NewServerFunc{ + defaults.ProtocolPostgres: func(ctx context.Context, nsc *dbmcp.NewServerConfig) (dbmcp.Server, error) { + return &testDatabaseMCP{}, nil + }, + } + return nil + } + + err = Run(t.Context(), []string{ + "login", "--insecure", "--debug", "--proxy", proxyAddr.String(), + }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, alice, connector.GetName())) + require.NoError(t, err) + + t.Run("different clusters", func(t *testing.T) { + err := Run(t.Context(), []string{ + "mcp", + "db", + "start", + "teleport://clusters/root/databases/postgres1?dbUser=postgres&dbName=postgres", + "teleport://clusters/other/databases/postgres2?dbUser=postgres&dbName=postgres", + }, setHomePath(tmpHomePath), withMockedMCPServers) + require.Error(t, err) + }) + + t.Run("duplicated databases", func(t *testing.T) { + err := Run(t.Context(), []string{ + "mcp", + "db", + "start", + "teleport://clusters/root/databases/postgres1?dbUser=postgres&dbName=postgres", + "teleport://clusters/root/databases/postgres1?dbUser=readonly&dbName=postgres", + }, setHomePath(tmpHomePath), withMockedMCPServers) + require.Error(t, err) + }) +} + +// testDatabaseMCP is a noop database MCP server. +type testDatabaseMCP struct{} + +func (s *testDatabaseMCP) Close(_ context.Context) error { return nil } diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 550811255519b..c6612cd7f1ba6 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -79,6 +79,7 @@ import ( benchmarkdb "github.com/gravitational/teleport/lib/benchmark/db" "github.com/gravitational/teleport/lib/client" dbprofile "github.com/gravitational/teleport/lib/client/db" + dbmcp "github.com/gravitational/teleport/lib/client/db/mcp" "github.com/gravitational/teleport/lib/client/identityfile" "github.com/gravitational/teleport/lib/defaults" dtauthn "github.com/gravitational/teleport/lib/devicetrust/authn" @@ -618,6 +619,10 @@ type CLIConf struct { // atomic here is overkill as the CLIConf is generally consumed sequentially. However, occasionally // we need concurrency safety, such as for [forEachProfileParallel]. clientStoreSet int32 + + // databaseMCPRegistryOverride overrides database access MCP servers + // registry. used in tests. + databaseMCPRegistryOverride dbmcp.Registry } // Stdout returns the stdout writer. @@ -785,7 +790,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { // run early to enable debug logging if env var is set. // this makes it possible to debug early startup functionality, particularly command aliases. - if err := initLogger(&cf, parseLoggingOptsFromEnv()); err != nil { + if _, err := initLogger(&cf, utils.LoggingForCLI, parseLoggingOptsFromEnv()); err != nil { printInitLoggerError(err) } @@ -1340,6 +1345,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { gitCmd := newGitCommands(app) pivCmd := newPIVCommands(app) + mcpCmd := newMCPCommands(app) + if runtime.GOOS == constants.WindowsOS { bench.Hidden() } @@ -1430,7 +1437,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { // Enable debug logging if requested by --debug. // If TELEPORT_DEBUG was set and --debug/--no-debug was not passed, debug logs were already // enabled by a prior call to initLogger. - if err := initLogger(&cf, parseLoggingOptsFromEnvAndArgv(&cf)); err != nil { + if _, err := initLogger(&cf, utils.LoggingForCLI, parseLoggingOptsFromEnvAndArgv(&cf)); err != nil { printInitLoggerError(err) } @@ -1745,6 +1752,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { err = gitCmd.clone.run(&cf) case pivCmd.agent.FullCommand(): err = pivCmd.agent.run(&cf) + case mcpCmd.dbStart.FullCommand(): + err = mcpCmd.dbStart.run(&cf) default: // Handle commands that might not be available. switch {