diff --git a/lib/client/db/mcp/errors.go b/lib/client/db/mcp/errors.go
new file mode 100644
index 0000000000000..8ab0112e9332a
--- /dev/null
+++ b/lib/client/db/mcp/errors.go
@@ -0,0 +1,87 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package mcp
+
+import (
+ "errors"
+ "io"
+ "strings"
+
+ "github.com/gravitational/trace"
+
+ apiclient "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/lib/client/mcp"
+)
+
+// ExtenralErrorRetriever returns an external error that might have happened.
+//
+// MCP servers don't have knowledge of other processes that might fail during
+// their execution, such as authentication failures. This provider can be used
+// to give them the necessary context to provide more accurate user messages.
+type ExternalErrorRetriever interface {
+ // RetrieveError retrieves the error if any.
+ RetrieveError() error
+}
+
+// FormatErrorMessage formats the database MCP error messages.
+// format.
+func FormatErrorMessage(retreiver ExternalErrorRetriever, err error) error {
+ if retreiver != nil {
+ err = trace.NewAggregate(retreiver.RetrieveError(), err)
+ }
+
+ switch {
+ case errors.Is(err, apiclient.ErrClientCredentialsHaveExpired):
+ return trace.BadParameter(ReloginRequiredErrorMessage)
+ case strings.Contains(err.Error(), "connection reset by peer") || errors.Is(err, io.ErrClosedPipe):
+ return trace.BadParameter(LocalProxyConnectionErrorMessage)
+ }
+
+ return err
+}
+
+const (
+ // ReloginRequiredErrorMessage is the message returned to the MCP client
+ // when the tsh session expired.
+ ReloginRequiredErrorMessage = `It looks like your Teleport session expired,
+you must relogin (using "tsh login" on a terminal) before continue using this
+tool. After that, there is no need to update or relaunch the MCP client - just
+try using it again.`
+ // LocalProxyConnectionErrorMessage is the message returned to the MCP client when
+ // the database client cannot connect to the local proxy.
+ LocalProxyConnectionErrorMessage = `Teleport MCP server is having issue while
+establishing the database connection. You can verify the MCP logs for more
+details on what is causing this issue. After identifying and fixing the issue
+a restart on the MCP client might be necessary.`
+ // EmptyDatabasesListErrorMessage is the message returned to the MCP client when
+ // the started database server is serving no databases.
+ EmptyDatabasesListErrorMessage = `There are no active Teleport databases available
+for use on the MCP server. You can check the MCP server logs to see if any
+database was not included due to an error. You can also verify that the list
+of databases on the MCP command is correct.`
+)
+
+var (
+ // WrongDatabaseURIFormatError is the message returned to the MCP client
+ // when it sends a malformed database resource URI.
+ WrongDatabaseURIFormatError = trace.BadParameter("Malformed database resource URI. Database resources must follow the format: %q", mcp.SampleDatabaseResource)
+ // DatabaseNotFoundError is the message returned to the MCP client when the
+ // requested database is not available as MCP resource.
+ DatabaseNotFoundError = trace.NotFound(`Database not found. Only registered databases
+can be used. Ask the user to attach the database resource or list the available
+resources with %q tool`, listDatabasesToolName)
+)
diff --git a/lib/client/db/mcp/mcp.go b/lib/client/db/mcp/mcp.go
new file mode 100644
index 0000000000000..43baf36144747
--- /dev/null
+++ b/lib/client/db/mcp/mcp.go
@@ -0,0 +1,107 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package mcp
+
+import (
+ "context"
+ "log/slog"
+ "net"
+
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/client/mcp"
+)
+
+// NewServerConfig configuration passed to the server constructors.
+type NewServerConfig struct {
+ Logger *slog.Logger
+ RootServer *RootServer
+ Databases []*Database
+}
+
+// NewServerFunc the MCP server constructor function definition.
+type NewServerFunc func(context.Context, *NewServerConfig) (Server, error)
+
+// Server represents a MCP server.
+type Server interface {
+ // Close closes the server.
+ Close(context.Context) error
+}
+
+// Registry represents the available databases MCP servers per protocol and
+// their constructors.
+type Registry map[string]NewServerFunc
+
+// IsSupported returns if the database protocol is supported by any MCP server
+// available.
+func (m Registry) IsSupported(protocol string) bool {
+ _, ok := m[protocol]
+ return ok
+}
+
+// LookupFunc is the function used to resolve database address. Follows the
+// net.Resolver.LookupAddr format.
+type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
+
+// DialContextFunc is a function used to dial the database. Follows the
+// net.Dialer.DialContext format.
+type DialContextFunc func(ctx context.Context, network string, addr string) (net.Conn, error)
+
+// Database the database served by an MCP server.
+type Database struct {
+ // DB contains all information from the database.
+ DB types.Database
+ // ClusterName is the cluster name where the database is located.
+ ClusterName string
+ // Addr is the address the MCP server used to create a new database
+ // connection.
+ Addr string
+ // DatabaseUser is the database username used on the connections.
+ DatabaseUser string
+ // DatabaseName is the database name used on the connections.
+ DatabaseName string
+ // ExternalErrorRetriever used to retrieve any external error that might
+ // have happened while connecting/communicating with the database.
+ ExternalErrorRetriever ExternalErrorRetriever
+ // LookupFunc is the lookup function to resolve database address.
+ LookupFunc LookupFunc
+ // DialContextFunc is the dial function used to connect to the database.
+ DialContextFunc DialContextFunc
+}
+
+// ResourceURI returns the database MCP resource URI.
+func (d Database) ResourceURI() mcp.ResourceURI {
+ return mcp.NewDatabaseResourceURI(d.ClusterName, d.DB.GetName())
+}
+
+// DatabaseResource MCP resource representation of a Teleport database.
+type DatabaseResource struct {
+ types.Metadata
+ // URI is the MCP URI resource.
+ URI string `json:"uri"`
+ // Protocol is the database protocol.
+ Protocol string `json:"protocol"`
+ // ClusterName is the cluster the database is.
+ ClusterName string `json:"cluster_name"`
+}
+
+// ToolName generates a database access tool name.
+func ToolName(protocol, name string) string {
+ return ToolPrefix + protocol + "_" + name
+}
+
+// ToolPrefix is the default tool prefix for every MCP tool.
+const ToolPrefix = "teleport_"
diff --git a/lib/client/db/mcp/server.go b/lib/client/db/mcp/server.go
new file mode 100644
index 0000000000000..b26e8bd0eaabe
--- /dev/null
+++ b/lib/client/db/mcp/server.go
@@ -0,0 +1,154 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package mcp
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "sync"
+
+ "github.com/ghodss/yaml"
+ "github.com/gravitational/trace"
+ "github.com/mark3labs/mcp-go/mcp"
+ mcpserver "github.com/mark3labs/mcp-go/server"
+
+ "github.com/gravitational/teleport"
+)
+
+// listDatabasesTool is the MCP tool that list all databases being served
+// (from all protocols).
+var listDatabasesTool = mcp.NewTool(listDatabasesToolName,
+ mcp.WithDescription("List database resources available to be used with Teleport tools."),
+)
+
+// RootServer database access root MCP server. It includes common MCP tools and
+// resources across different databases and serves as a root server where
+// database-specific MCP servers register their tools.
+type RootServer struct {
+ *mcpserver.MCPServer
+
+ mu sync.RWMutex
+ logger *slog.Logger
+ availableDatabases map[string]*Database
+}
+
+// NewRootServer initializes a new root MCP server.
+func NewRootServer(logger *slog.Logger) *RootServer {
+ server := &RootServer{
+ MCPServer: mcpserver.NewMCPServer(serverName, teleport.Version),
+ logger: logger,
+ availableDatabases: make(map[string]*Database),
+ }
+ server.AddTool(listDatabasesTool, server.ListDatabases)
+
+ return server
+}
+
+// ListDatabases tool function used to list all available/served databases.
+func (s *RootServer) ListDatabases(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if len(s.availableDatabases) == 0 {
+ return mcp.NewToolResultError(EmptyDatabasesListErrorMessage), nil
+ }
+
+ var res []mcp.Content
+ for _, db := range s.availableDatabases {
+ contents, err := encodeDatabaseResource(db)
+ if err != nil {
+ s.logger.ErrorContext(ctx, "error while list databases", "error", err)
+ return mcp.NewToolResultError(FormatErrorMessage(nil, err).Error()), nil
+ }
+ res = append(res, mcp.EmbeddedResource{Type: "resource", Resource: contents})
+ }
+
+ return &mcp.CallToolResult{
+ Content: res,
+ }, nil
+}
+
+// GetDatabaseResource resource handler for databases.
+func (s *RootServer) GetDatabaseResource(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ db, ok := s.availableDatabases[request.Params.URI]
+ if !ok {
+ return nil, trace.NotFound("Database is %q not available as MCP resource", request.Params.URI)
+ }
+
+ encodedDb, err := encodeDatabaseResource(db)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return []mcp.ResourceContents{encodedDb}, nil
+}
+
+// RegisterDatabase register a database on the root server. This make it
+// available as a MCP resource.
+//
+// TODO(gabrielcorado): support dynamically registering/deregistering databases
+// after the server starts.
+func (s *RootServer) RegisterDatabase(db *Database) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ uri := db.ResourceURI().String()
+ s.availableDatabases[uri] = db
+ s.AddResource(mcp.NewResource(uri, fmt.Sprintf("%s Datatabase", db.DB.GetName()), mcp.WithMIMEType(databaseResourceMIMEType)), s.GetDatabaseResource)
+}
+
+// ServeStdio starts serving the root MCP using STDIO transport.
+func (s *RootServer) ServeStdio(ctx context.Context, in io.Reader, out io.Writer) error {
+ return trace.Wrap(mcpserver.NewStdioServer(s.MCPServer).Listen(ctx, in, out))
+}
+
+func buildDatabaseResource(db *Database) DatabaseResource {
+ return DatabaseResource{
+ Metadata: db.DB.GetMetadata(),
+ URI: db.ResourceURI().String(),
+ Protocol: db.DB.GetProtocol(),
+ ClusterName: db.ClusterName,
+ }
+}
+
+func encodeDatabaseResource(db *Database) (mcp.ResourceContents, error) {
+ resource := buildDatabaseResource(db)
+ out, err := yaml.Marshal(resource)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return mcp.TextResourceContents{
+ URI: resource.URI,
+ MIMEType: databaseResourceMIMEType,
+ Text: string(out),
+ }, nil
+}
+
+const (
+ // serverName is the database MCP server name.
+ serverName = "teleport_databases"
+ // listDatabasesTool is the list databases tool name.
+ listDatabasesToolName = ToolPrefix + "list_databases"
+ // databaseResourceMIMEType is the MIME type of the database resources.
+ databaseResourceMIMEType = "application/yaml"
+)
diff --git a/lib/client/db/mcp/server_test.go b/lib/client/db/mcp/server_test.go
new file mode 100644
index 0000000000000..007067d3507ea
--- /dev/null
+++ b/lib/client/db/mcp/server_test.go
@@ -0,0 +1,176 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package mcp
+
+import (
+ "log/slog"
+ "slices"
+ "strings"
+ "testing"
+
+ "github.com/ghodss/yaml"
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ mcpclient "github.com/mark3labs/mcp-go/client"
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/defaults"
+)
+
+func TestRegisterDatabase(t *testing.T) {
+ server := NewRootServer(slog.New(slog.DiscardHandler))
+ databases := []*Database{
+ buildDatabase(t, "first"),
+ buildDatabase(t, "second"),
+ buildDatabase(t, "third"),
+ }
+ // sort databases by name to ensure the same order every test.
+ slices.SortFunc(databases, func(a, b *Database) int {
+ return strings.Compare(a.ResourceURI().String(), b.ResourceURI().String())
+ })
+
+ for _, db := range databases {
+ server.RegisterDatabase(db)
+ }
+
+ clt := buildClient(t, server)
+ t.Run("Resources", func(t *testing.T) {
+ listResult, err := clt.ListResources(t.Context(), mcp.ListResourcesRequest{})
+ require.NoError(t, err)
+ require.Len(t, listResult.Resources, len(databases))
+
+ // sort the result using the same field as databases to avoid flaky
+ // test.
+ slices.SortFunc(listResult.Resources, func(a, b mcp.Resource) int {
+ return strings.Compare(a.URI, b.URI)
+ })
+
+ for i, r := range listResult.Resources {
+ req := mcp.ReadResourceRequest{}
+ req.Params.URI = r.URI
+ readResult, err := clt.ReadResource(t.Context(), req)
+ require.NoError(t, err)
+ require.Len(t, readResult.Contents, 1)
+ assertDatabaseResource(t, databases[i], readResult.Contents[0])
+ }
+ })
+
+ t.Run("Tool", func(t *testing.T) {
+ req := mcp.CallToolRequest{}
+ req.Params.Name = listDatabasesToolName
+ res, err := clt.CallTool(t.Context(), req)
+ require.NoError(t, err)
+ require.False(t, res.IsError)
+ require.Len(t, res.Content, len(databases))
+
+ for _, c := range res.Content {
+ require.IsType(t, mcp.EmbeddedResource{}, c)
+ require.IsType(t, mcp.TextResourceContents{}, c.(mcp.EmbeddedResource).Resource)
+ }
+
+ // Although we're not sorting by the URI directly, the only field that
+ // is different across the databases is their name and URI (which would
+ // cause them to have the same order). So here we sort by the YAML
+ // contents to avoid having to decode.
+ slices.SortFunc(res.Content, func(a, b mcp.Content) int {
+ resourceA := a.(mcp.EmbeddedResource).Resource.(mcp.TextResourceContents)
+ resourceB := b.(mcp.EmbeddedResource).Resource.(mcp.TextResourceContents)
+ return strings.Compare(resourceA.Text, resourceB.Text)
+ })
+
+ for i, c := range res.Content {
+ content := c.(mcp.EmbeddedResource)
+ assertDatabaseResource(t, databases[i], content.Resource)
+ }
+ })
+}
+
+func TestEmptyDatabasesServer(t *testing.T) {
+ server := NewRootServer(slog.New(slog.DiscardHandler))
+
+ clt := buildClient(t, server)
+ t.Run("Resources", func(t *testing.T) {
+ _, err := clt.ListResources(t.Context(), mcp.ListResourcesRequest{})
+ require.Error(t, err)
+ })
+
+ t.Run("Tool", func(t *testing.T) {
+ req := mcp.CallToolRequest{}
+ req.Params.Name = listDatabasesToolName
+ res, err := clt.CallTool(t.Context(), req)
+ require.NoError(t, err)
+ require.True(t, res.IsError)
+
+ require.Len(t, res.Content, 1)
+ content := res.Content[0]
+ require.IsType(t, mcp.TextContent{}, content)
+ textError := content.(mcp.TextContent).Text
+ require.Contains(t, textError, EmptyDatabasesListErrorMessage, "expected empty databases error but got: %s", textError)
+ })
+}
+
+func assertDatabaseResource(t *testing.T, db *Database, resource mcp.ResourceContents) {
+ t.Helper()
+ require.IsType(t, mcp.TextResourceContents{}, resource)
+ contents := resource.(mcp.TextResourceContents)
+ var database DatabaseResource
+ require.Equal(t, databaseResourceMIMEType, contents.MIMEType)
+ require.NoError(t, yaml.Unmarshal([]byte(contents.Text), &database))
+ require.Empty(t, cmp.Diff(buildDatabaseResource(db), database, cmpopts.IgnoreFields(types.Metadata{}, "Namespace")))
+}
+
+func buildDatabase(t *testing.T, name string) *Database {
+ t.Helper()
+
+ db, err := types.NewDatabaseV3(types.Metadata{
+ Name: name,
+ Labels: map[string]string{"env": "test"},
+ }, types.DatabaseSpecV3{
+ Protocol: defaults.ProtocolPostgres,
+ URI: "localhost:5432",
+ })
+ require.NoError(t, err)
+
+ return &Database{
+ DB: db,
+ ClusterName: "root",
+ Addr: "localhost:5555",
+ }
+}
+
+func buildClient(t *testing.T, server *RootServer) *mcpclient.Client {
+ t.Helper()
+
+ clt, err := mcpclient.NewInProcessClient(server.MCPServer)
+ require.NoError(t, err)
+ t.Cleanup(func() { clt.Close() })
+ require.NoError(t, clt.Start(t.Context()))
+
+ initRequest := mcp.InitializeRequest{}
+ initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
+ initRequest.Params.ClientInfo = mcp.Implementation{
+ Name: "test-client",
+ Version: "1.0.0",
+ }
+
+ _, err = clt.Initialize(t.Context(), initRequest)
+ require.NoError(t, err)
+ require.NoError(t, clt.Ping(t.Context()))
+ return clt
+}
diff --git a/lib/client/db/postgres/mcp/mcp.go b/lib/client/db/postgres/mcp/mcp.go
new file mode 100644
index 0000000000000..6a6b8a19dc10b
--- /dev/null
+++ b/lib/client/db/postgres/mcp/mcp.go
@@ -0,0 +1,259 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package mcp
+
+import (
+ "context"
+ "encoding/json"
+ "log/slog"
+ "net"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgxpool"
+ "github.com/mark3labs/mcp-go/mcp"
+
+ dbmcp "github.com/gravitational/teleport/lib/client/db/mcp"
+ clientmcp "github.com/gravitational/teleport/lib/client/mcp"
+ "github.com/gravitational/teleport/lib/defaults"
+)
+
+// queryTool is the run query MCP tool definition.
+var queryTool = mcp.NewTool(dbmcp.ToolName(defaults.ProtocolPostgres, "query"),
+ mcp.WithDescription("Execute SQL query against PostgreSQL database connected using Teleport"),
+ mcp.WithString(queryToolDatabaseParam,
+ mcp.Required(),
+ mcp.Description("Teleport database resource URI where the query will be executed"),
+ ),
+ mcp.WithString(queryToolQueryParam,
+ mcp.Required(),
+ mcp.Description("PostgresSQL SQL query to execute"),
+ ),
+)
+
+type database struct {
+ *dbmcp.Database
+ pool *pgxpool.Pool
+}
+
+// Server handles PostgreSQL-specific MCP tools requests.
+type Server struct {
+ logger *slog.Logger
+ databases map[string]*database
+}
+
+// NewServer initializes a PostgreSQL MCP server, creating the database
+// configurations and registering Server tools into the root server.
+func NewServer(ctx context.Context, cfg *dbmcp.NewServerConfig) (dbmcp.Server, error) {
+ s := &Server{logger: cfg.Logger, databases: make(map[string]*database)}
+
+ for _, db := range cfg.Databases {
+ if db.DatabaseUser == "" || db.DatabaseName == "" {
+ return nil, trace.BadParameter("database %q is missing the username and database name", db.DB.GetName())
+ }
+
+ connCfg, err := buildConnConfig(db)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ pool, err := pgxpool.NewWithConfig(ctx, connCfg)
+ if err != nil {
+ return nil, trace.BadParameter("failed to parse database %q connection config: %s", db.DB.GetName(), err)
+ }
+
+ s.databases[db.ResourceURI().String()] = &database{
+ Database: db,
+ pool: pool,
+ }
+ }
+
+ cfg.RootServer.AddTool(queryTool, s.RunQuery)
+ return s, nil
+}
+
+// Close implements dbmcp.Server.
+func (s *Server) Close(context.Context) error {
+ for _, db := range s.databases {
+ db.pool.Close()
+ }
+
+ return nil
+}
+
+// RunQueryResult is the run query tool result.
+type RunQueryResult struct {
+ // Data contains the data returned from the query. It can be empty in case
+ // the query doesn't return any data.
+ Data []map[string]any `json:"data"`
+ // RowsCount number of rows affected by the query or returned as data.
+ RowsCount int `json:"rowsCount"`
+ // ErrorMessage if the query wasn't successful, this field contains the
+ // error message.
+ ErrorMessage string `json:"error,omitempty"`
+}
+
+// RunQuery tool function used to execute queries on databases.
+func (s *Server) RunQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ uri, err := request.RequireString(queryToolDatabaseParam)
+ if err != nil {
+ return s.wrapErrorResult(ctx, nil, trace.Wrap(err))
+ }
+
+ sql, err := request.RequireString(queryToolQueryParam)
+ if err != nil {
+ return s.wrapErrorResult(ctx, nil, trace.Wrap(err))
+ }
+
+ db, err := s.getDatabase(uri)
+ if err != nil {
+ return s.wrapErrorResult(ctx, nil, err)
+ }
+
+ // TODO(gabrielcorado): ensure the connection used is consistent for the
+ // session, making most of its queries to be present in a single audit
+ // session/recording.
+ rows, err := db.pool.Query(ctx, sql)
+ if err != nil {
+ return s.wrapErrorResult(ctx, db.ExternalErrorRetriever, err)
+ }
+
+ // Returned rows are being closed by this function.
+ result, err := buildQueryResult(rows)
+ if err != nil {
+ return s.wrapErrorResult(ctx, db.ExternalErrorRetriever, err)
+ }
+
+ return mcp.NewToolResultText(result), nil
+}
+
+func (s *Server) wrapErrorResult(ctx context.Context, externalRetriever dbmcp.ExternalErrorRetriever, toolErr error) (*mcp.CallToolResult, error) {
+ s.logger.ErrorContext(ctx, "error while querying database", "error", toolErr)
+ out, err := json.Marshal(RunQueryResult{ErrorMessage: dbmcp.FormatErrorMessage(externalRetriever, toolErr).Error()})
+ return mcp.NewToolResultError(string(out)), trace.Wrap(err)
+}
+
+// buildQueryResult takes a the response from pgx and converts into a JSON
+// format (which will be returned to LLMs).
+func buildQueryResult(rows pgx.Rows) (string, error) {
+ // Just ensure the rows is always closed. It is safe if this is called
+ // multiple times.
+ defer rows.Close()
+
+ var data []map[string]any
+ columns := rows.FieldDescriptions()
+
+ for rows.Next() {
+ values, err := rows.Values()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ item := make(map[string]any, len(values))
+ for i, v := range values {
+ item[columns[i].Name] = v
+ }
+
+ data = append(data, item)
+ }
+
+ // Close the rows to finish consuming it. Depending on the its type
+ // we can only collect the command tag after rows is closed.
+ rows.Close()
+ commandTag := rows.CommandTag()
+
+ // Initialize the slice so the resulting JSON will have an empty array
+ // instead of null.
+ if len(data) == 0 && commandTag.Select() {
+ data = []map[string]any{}
+ }
+
+ out, err := json.Marshal(RunQueryResult{
+ Data: data,
+ RowsCount: int(commandTag.RowsAffected()),
+ })
+ return string(out), trace.Wrap(err)
+}
+
+func (s *Server) getDatabase(uri string) (*database, error) {
+ if !clientmcp.IsDatabaseResourceURI(uri) {
+ return nil, dbmcp.WrongDatabaseURIFormatError
+ }
+
+ db, ok := s.databases[uri]
+ if !ok {
+ return nil, dbmcp.DatabaseNotFoundError
+ }
+
+ return db, nil
+}
+
+func buildConnConfig(db *dbmcp.Database) (*pgxpool.Config, error) {
+ config, err := pgxpool.ParseConfig("postgres://" + db.Addr)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ config.MaxConnIdleTime = connectionIdleTime
+ config.MaxConns = int32(maxConnections)
+
+ config.ConnConfig.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
+ return db.LookupFunc(ctx, host)
+ }
+ config.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return db.DialContextFunc(ctx, network, addr)
+ }
+
+ config.ConnConfig.User = db.DatabaseUser
+ config.ConnConfig.Database = db.DatabaseName
+ config.ConnConfig.ConnectTimeout = defaults.DatabaseConnectTimeout
+ config.ConnConfig.RuntimeParams = map[string]string{
+ applicationNameParamName: applicationNameParamValue,
+ }
+ config.ConnConfig.TLSConfig = nil
+ // Use simple protocol to have a closer behavior to DB REPL and psql.
+ //
+ // This also avoids each query being prepared, binded and executed, reducing
+ // the amount of audit events per query executed.
+ config.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
+ return config, nil
+}
+
+const (
+ // queryToolDatabaseParam is the name of the database URI param name from
+ // query tool.
+ queryToolDatabaseParam = "database"
+ // queryToolQueryParam is the name of the query param name from query tool.
+ queryToolQueryParam = "query"
+
+ // applicationNameParamName defines the application name parameter name.
+ //
+ // https://www.postgresql.org/docs/17/libpq-connect.html#LIBPQ-CONNECT-APPLICATION-NAME
+ applicationNameParamName = "application_name"
+ // applicationNameParamValue defines the application name parameter value.
+ applicationNameParamValue = "teleport-mcp"
+ // connectionIdleTime is the max connection idle time before it gets closed
+ // automatically.
+ connectionIdleTime = 1 * time.Minute
+ // maxConnections defines the max number of concurrent connections the pool
+ // can have.
+ //
+ // Given the current MCP usage, the clients will most likely do one query at
+ // time, even on multiple sessions.
+ maxConnections = 1
+)
diff --git a/lib/client/db/postgres/mcp/mcp_test.go b/lib/client/db/postgres/mcp/mcp_test.go
new file mode 100644
index 0000000000000..efdf69f5804e6
--- /dev/null
+++ b/lib/client/db/postgres/mcp/mcp_test.go
@@ -0,0 +1,234 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package mcp
+
+import (
+ "context"
+ "encoding/json"
+ "log/slog"
+ "testing"
+
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgconn"
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/types"
+ dbmcp "github.com/gravitational/teleport/lib/client/db/mcp"
+ clientmcp "github.com/gravitational/teleport/lib/client/mcp"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/utils/listener"
+)
+
+func TestFormatResult(t *testing.T) {
+ for name, tc := range map[string]struct {
+ rows pgx.Rows
+ expectedResult string
+ }{
+ "query results": {
+ rows: newMockRows("SELECT 2", []string{"name", "age"}, [][]any{{"Alice", 30}, {"Bob", 31}}),
+ expectedResult: `{"data":[{"age":30,"name":"Alice"},{"age":31,"name":"Bob"}],"rowsCount":2}`,
+ },
+ "empty query results": {
+ rows: newMockRows("SELECT 0", []string{}, [][]any{}),
+ expectedResult: `{"data":[],"rowsCount":0}`,
+ },
+ "non-data results": {
+ rows: newMockRows("INSERT 1", []string{}, [][]any{}),
+ expectedResult: `{"data":null,"rowsCount":1}`,
+ },
+ } {
+ t.Run(name, func(t *testing.T) {
+ res, err := buildQueryResult(tc.rows)
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedResult, res)
+ })
+ }
+}
+
+func TestFormatErrors(t *testing.T) {
+ // Dummy listener that always drop connections.
+ listener := listener.NewInMemoryListener()
+ t.Cleanup(func() { listener.Close() })
+
+ go func() {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ return
+ }
+ _ = conn.Close()
+ }
+ }()
+
+ dbName := "local"
+ db, err := types.NewDatabaseV3(types.Metadata{
+ Name: dbName,
+ Labels: map[string]string{"env": "test"},
+ }, types.DatabaseSpecV3{
+ Protocol: defaults.ProtocolPostgres,
+ URI: "localhost:5432",
+ })
+ require.NoError(t, err)
+ dbURI := clientmcp.NewDatabaseResourceURI("root", dbName).String()
+
+ for name, tc := range map[string]struct {
+ databaseURI string
+ databases []*dbmcp.Database
+ externalErrorRetriever dbmcp.ExternalErrorRetriever
+ expectErrorMessage require.ValueAssertionFunc
+ }{
+ "database not found": {
+ databaseURI: "teleport://clusters/root/databases/not-found",
+ expectErrorMessage: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) {
+ require.Equal(t, dbmcp.DatabaseNotFoundError.Error(), i1)
+ },
+ },
+ "malformed database uri": {
+ databaseURI: "not-found",
+ expectErrorMessage: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) {
+ require.Equal(t, dbmcp.WrongDatabaseURIFormatError.Error(), i1)
+ },
+ },
+ "local proxy rejects connection": {
+ databaseURI: dbURI,
+ databases: []*dbmcp.Database{
+ &dbmcp.Database{
+ DB: db,
+ ClusterName: "root",
+ DatabaseUser: "postgres",
+ DatabaseName: "postgres",
+ Addr: listener.Addr().String(),
+ LookupFunc: func(_ context.Context, _ string) (addrs []string, err error) {
+ return []string{"memory"}, nil
+ },
+ DialContextFunc: listener.DialContext,
+ },
+ },
+ expectErrorMessage: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) {
+ require.Equal(t, dbmcp.LocalProxyConnectionErrorMessage, i1)
+ },
+ },
+ "relogin error": {
+ databaseURI: dbURI,
+ databases: []*dbmcp.Database{
+ &dbmcp.Database{
+ DB: db,
+ ClusterName: "root",
+ DatabaseUser: "postgres",
+ DatabaseName: "postgres",
+ Addr: listener.Addr().String(),
+ ExternalErrorRetriever: &mockErrorRetriever{err: client.ErrClientCredentialsHaveExpired},
+ LookupFunc: func(_ context.Context, _ string) (addrs []string, err error) {
+ return []string{"memory"}, nil
+ },
+ DialContextFunc: listener.DialContext,
+ },
+ },
+ expectErrorMessage: func(tt require.TestingT, i1 interface{}, i2 ...interface{}) {
+ require.Equal(t, dbmcp.ReloginRequiredErrorMessage, i1)
+ },
+ },
+ } {
+ t.Run(name, func(t *testing.T) {
+ logger := slog.New(slog.DiscardHandler)
+ rootServer := dbmcp.NewRootServer(logger)
+ srv, err := NewServer(t.Context(), &dbmcp.NewServerConfig{
+ Logger: logger,
+ RootServer: rootServer,
+ Databases: tc.databases,
+ })
+ require.NoError(t, err)
+ t.Cleanup(func() { srv.Close(t.Context()) })
+
+ pgSrv := srv.(*Server)
+ req := mcp.CallToolRequest{}
+ req.Params.Arguments = map[string]any{
+ queryToolDatabaseParam: tc.databaseURI,
+ queryToolQueryParam: "SELECT 1",
+ }
+ runResult, err := pgSrv.RunQuery(t.Context(), req)
+ require.NoError(t, err)
+
+ require.True(t, runResult.IsError)
+ require.Len(t, runResult.Content, 1)
+ require.IsType(t, mcp.TextContent{}, runResult.Content[0])
+
+ content := runResult.Content[0].(mcp.TextContent)
+ var res RunQueryResult
+ require.NoError(t, json.Unmarshal([]byte(content.Text), &res), "expected result to be in JSON format")
+ require.Empty(t, res.Data)
+ tc.expectErrorMessage(t, res.ErrorMessage)
+ })
+ }
+}
+
+func newMockRows(commandTag string, fields []string, rows [][]any) pgx.Rows {
+ var fds []pgconn.FieldDescription
+ for _, fieldName := range fields {
+ fds = append(fds, pgconn.FieldDescription{Name: fieldName})
+ }
+ return &mockRows{
+ commandTag: commandTag,
+ descriptions: fds,
+ rows: rows,
+ }
+}
+
+type mockRows struct {
+ pgx.Rows
+
+ started bool
+ cursor int
+
+ commandTag string
+ descriptions []pgconn.FieldDescription
+ rows [][]any
+}
+
+func (mr *mockRows) FieldDescriptions() []pgconn.FieldDescription {
+ return mr.descriptions
+}
+
+func (mr *mockRows) Next() bool {
+ if !mr.started {
+ mr.started = true
+ return len(mr.rows) > 0
+ }
+
+ mr.cursor += 1
+ return len(mr.rows) > mr.cursor
+}
+
+func (mr *mockRows) Values() ([]any, error) {
+ return mr.rows[mr.cursor], nil
+}
+
+func (mr *mockRows) CommandTag() pgconn.CommandTag {
+ return pgconn.NewCommandTag(mr.commandTag)
+}
+
+func (mr *mockRows) Close() {}
+
+type mockErrorRetriever struct {
+ err error
+}
+
+func (mr *mockErrorRetriever) RetrieveError() error {
+ return mr.err
+}
diff --git a/lib/client/local_proxy_middleware.go b/lib/client/local_proxy_middleware.go
index 1f6f7bf9622c1..9f4fd19670d67 100644
--- a/lib/client/local_proxy_middleware.go
+++ b/lib/client/local_proxy_middleware.go
@@ -52,6 +52,9 @@ type CertChecker struct {
cert tls.Certificate
certMu sync.Mutex
+
+ err error
+ errMu sync.Mutex
}
var _ alpnproxy.LocalProxyMiddleware = (*CertChecker)(nil)
@@ -142,15 +145,19 @@ func (c *CertChecker) SetCert(cert tls.Certificate) {
// GetOrIssueCert gets the CertChecker's certificate, or issues a new
// certificate if the it is invalid (e.g. expired) or missing.
-func (c *CertChecker) GetOrIssueCert(ctx context.Context) (tls.Certificate, error) {
+func (c *CertChecker) GetOrIssueCert(ctx context.Context) (cert tls.Certificate, err error) {
c.certMu.Lock()
defer c.certMu.Unlock()
+ defer func() {
+ c.setError(err)
+ }()
+
if err := c.checkCert(); err == nil {
return c.cert, nil
}
- cert, err := c.certIssuer.IssueCert(ctx)
+ cert, err = c.certIssuer.IssueCert(ctx)
if err != nil {
return tls.Certificate{}, trace.Wrap(err)
}
@@ -170,6 +177,13 @@ func (c *CertChecker) GetOrIssueCert(ctx context.Context) (tls.Certificate, erro
return c.cert, nil
}
+// RetrieveError retrieves the happened on while retrieving certificates.
+func (c *CertChecker) RetrieveError() error {
+ c.errMu.Lock()
+ defer c.errMu.Unlock()
+ return c.err
+}
+
func (c *CertChecker) checkCert() error {
leaf, err := utils.TLSCertLeaf(c.cert)
if err != nil {
@@ -184,6 +198,12 @@ func (c *CertChecker) checkCert() error {
return trace.Wrap(c.certIssuer.CheckCert(leaf))
}
+func (c *CertChecker) setError(err error) {
+ c.errMu.Lock()
+ defer c.errMu.Unlock()
+ c.err = err
+}
+
// CertIssuer checks and issues certs.
type CertIssuer interface {
// CheckCert checks that an existing certificate is valid.
diff --git a/lib/client/local_proxy_middleware_test.go b/lib/client/local_proxy_middleware_test.go
index cfb79e335601e..48d99c859c39d 100644
--- a/lib/client/local_proxy_middleware_test.go
+++ b/lib/client/local_proxy_middleware_test.go
@@ -44,10 +44,12 @@ func TestCertChecker(t *testing.T) {
// certChecker should issue a new cert on first request.
cert, err := certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
+ require.NoError(t, certChecker.RetrieveError())
// subsequent calls should return the same cert.
sameCert, err := certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
+ require.NoError(t, certChecker.RetrieveError())
require.Equal(t, cert, sameCert)
// If the current cert expires it should be reissued.
@@ -56,6 +58,7 @@ func TestCertChecker(t *testing.T) {
cert, err = certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
+ require.NoError(t, certChecker.RetrieveError())
require.NotEqual(t, cert, expiredCert)
// If the current cert fails certIssuer checks, a new one should be issued.
@@ -64,12 +67,20 @@ func TestCertChecker(t *testing.T) {
cert, err = certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
+ require.NoError(t, certChecker.RetrieveError())
require.NotEqual(t, cert, badCert)
// If issuing a new cert fails, an error is returned.
certIssuer.issueErr = trace.BadParameter("failed to issue cert")
_, err = certChecker.GetOrIssueCert(ctx)
require.ErrorIs(t, err, certIssuer.issueErr, "expected error %v but got %v", certIssuer.issueErr, err)
+ require.ErrorIs(t, certChecker.RetrieveError(), err, "expected retrieve error to be the same get error but got: %v", certChecker.RetrieveError())
+
+ // If the problem is solved, the error is clean up.
+ certIssuer.issueErr = nil
+ _, err = certChecker.GetOrIssueCert(ctx)
+ require.NoError(t, err)
+ require.NoError(t, certChecker.RetrieveError())
}
func TestLocalCertGenerator(t *testing.T) {
diff --git a/lib/client/mcp/uri.go b/lib/client/mcp/uri.go
new file mode 100644
index 0000000000000..a932a67c776ec
--- /dev/null
+++ b/lib/client/mcp/uri.go
@@ -0,0 +1,153 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package mcp
+
+import (
+ "net/url"
+ "strings"
+
+ "github.com/gravitational/trace"
+ "github.com/ucarion/urlpath"
+)
+
+var (
+ // clusterURITemplate is the base cluster template.
+ clusterURITemplate = urlpath.New("/clusters/:cluster/*")
+ // databaseURITemplate template used to parse database resource URIs.
+ databaseURITemplate = urlpath.New("/clusters/:cluster/databases/:dbName")
+)
+
+const (
+ // resourceScheme scheme used by Teleport MCP resources.
+ resourceScheme = "teleport"
+
+ // databaseNameQueryParamName is the query param name used for database
+ // name.
+ databaseNameQueryParamName = "dbName"
+ // databaseUserQueryParamName is the query param name used for database
+ // user.
+ databaseUserQueryParamName = "dbUser"
+)
+
+// ResourceURI is a Teleport MCP resource URI.
+//
+// Query parameters are not covered on the MCP spec but we use them to provide
+// additional information about the resource connection. For example, if the
+// resource requires a "username" value, this value is provided using the query
+// params.
+//
+// https://modelcontextprotocol.io/docs/concepts/resources#resource-uris
+type ResourceURI struct {
+ url url.URL
+}
+
+// ParseResourceURI parses a raw resource URI into a Teleport URI.
+func ParseResourceURI(uri string) (*ResourceURI, error) {
+ parsedURL, err := url.Parse(uri)
+ if err != nil {
+ return nil, trace.BadParameter("invalid resource URI format: %s", err)
+ }
+
+ if parsedURL.Scheme != resourceScheme {
+ return nil, trace.BadParameter("invalid URI scheme, must be %q", resourceScheme)
+ }
+
+ return &ResourceURI{url: *parsedURL}, nil
+}
+
+// NewDatabaseResourceURI creates a new database resource URI.
+func NewDatabaseResourceURI(cluster, databaseName string) ResourceURI {
+ pathWithHost, _ := databaseURITemplate.Build(urlpath.Match{
+ Params: map[string]string{
+ "cluster": cluster,
+ "dbName": databaseName,
+ },
+ })
+
+ return ResourceURI{
+ url: url.URL{
+ Scheme: resourceScheme,
+ Path: strings.TrimPrefix(pathWithHost, "/"),
+ },
+ }
+}
+
+// GetDatabaseServiceName returns the Teleport cluster name.
+func (u ResourceURI) GetClusterName() string {
+ if match, ok := clusterURITemplate.Match(u.path()); ok {
+ return match.Params["cluster"]
+ }
+
+ return ""
+}
+
+// GetDatabaseServiceName returns the database service name of the resource.
+// Returns empty if the resource is not a database.
+func (u ResourceURI) GetDatabaseServiceName() string {
+ if match, ok := databaseURITemplate.Match(u.path()); ok {
+ return match.Params["dbName"]
+ }
+
+ return ""
+}
+
+// GetDatabaseUser returns the database username param of the resource.
+// Returns empty if the resource is not a database.
+func (u ResourceURI) GetDatabaseUser() string {
+ return u.url.Query().Get(databaseUserQueryParamName)
+}
+
+// GetDatabaseName returns the database name param of the resource.
+// Returns empty if the resource is not a database.
+func (u ResourceURI) GetDatabaseName() string {
+ return u.url.Query().Get(databaseNameQueryParamName)
+}
+
+// IsDatabase returns true if the resource is a database.
+func (u ResourceURI) IsDatabase() bool {
+ return u.GetDatabaseServiceName() != ""
+}
+
+// String returns the string representation of the resource URI (excluding the
+// query params).
+func (u ResourceURI) String() string {
+ c := u.url
+ c.RawQuery = ""
+ return c.String()
+}
+
+// path returns the resource URI full path. We must include the hostname as the
+// templates will also include them on the matching.
+func (u ResourceURI) path() string {
+ return "/" + u.url.Hostname() + u.url.Path
+}
+
+// IsDatabase returns true if the URI is a database resource.
+func IsDatabaseResourceURI(uri string) bool {
+ parsed, err := ParseResourceURI(uri)
+ if err != nil {
+ return false
+ }
+
+ return parsed.IsDatabase()
+}
+
+var (
+ // SampleDatabaseResource contains a sample full resource URI. This can be
+ // used on descriptions to show how a database resource URI looks like.
+ SampleDatabaseResource = NewDatabaseResourceURI("example-cluster", "myDatabase")
+)
diff --git a/lib/client/mcp/uri_test.go b/lib/client/mcp/uri_test.go
new file mode 100644
index 0000000000000..30cdc64a2a357
--- /dev/null
+++ b/lib/client/mcp/uri_test.go
@@ -0,0 +1,94 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package mcp
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestDatabaseResourceURI(t *testing.T) {
+ for name, tc := range map[string]struct {
+ uri string
+ expectError bool
+ expectedDatabase bool
+ expectedServiceName string
+ expectedDatabaseName string
+ expectedDatabaseUser string
+ expectedClusterName string
+ }{
+ "valid database": {
+ uri: "teleport://clusters/default/databases/pg?dbName=database&dbUser=user",
+ expectedDatabase: true,
+ expectedServiceName: "pg",
+ expectedDatabaseName: "database",
+ expectedDatabaseUser: "user",
+ expectedClusterName: "default",
+ },
+ "valid database without params": {
+ uri: "teleport://clusters/default/databases/pg",
+ expectedDatabase: true,
+ expectedServiceName: "pg",
+ expectedDatabaseName: "",
+ expectedDatabaseUser: "",
+ expectedClusterName: "default",
+ },
+ "random resource": {
+ uri: "teleport://clusters/default/random/random-resource",
+ expectedDatabase: false,
+ expectedServiceName: "",
+ expectedDatabaseName: "",
+ expectedDatabaseUser: "",
+ expectedClusterName: "default",
+ },
+ "generated uri": {
+ uri: NewDatabaseResourceURI("default", "db").String(),
+ expectedDatabase: true,
+ expectedServiceName: "db",
+ expectedDatabaseName: "",
+ expectedDatabaseUser: "",
+ expectedClusterName: "default",
+ },
+ "invalid schema": {
+ uri: "http://databases/database",
+ expectError: true,
+ },
+ "invalid uri": {
+ uri: "random-value",
+ expectError: true,
+ },
+ } {
+ t.Run(name, func(t *testing.T) {
+ uri, err := ParseResourceURI(tc.uri)
+ if tc.expectError {
+ require.Error(t, err)
+ return
+ }
+
+ require.NotNil(t, uri)
+ fmt.Println(tc.uri)
+ require.Equal(t, tc.expectedDatabase, IsDatabaseResourceURI(tc.uri))
+ require.Equal(t, tc.expectedDatabase, uri.IsDatabase())
+ require.Equal(t, tc.expectedServiceName, uri.GetDatabaseServiceName())
+ require.Equal(t, tc.expectedDatabaseName, uri.GetDatabaseName())
+ require.Equal(t, tc.expectedDatabaseUser, uri.GetDatabaseUser())
+ require.Equal(t, tc.expectedClusterName, uri.GetClusterName())
+ })
+ }
+}
diff --git a/lib/utils/cli.go b/lib/utils/cli.go
index a025a301a350b..6a8ca6067ead6 100644
--- a/lib/utils/cli.go
+++ b/lib/utils/cli.go
@@ -54,6 +54,8 @@ const (
LoggingForDaemon LoggingPurpose = iota
// LoggingForCLI configures logging for user face utilities (tctl, tsh).
LoggingForCLI
+ // LoggingForMCP configures logging for MCP servers.
+ LoggingForMCP
)
// LoggingFormat defines the possible logging output formats.
@@ -100,7 +102,7 @@ func IsTerminal(w io.Writer) bool {
}
// InitLogger configures the global logger for a given purpose / verbosity level
-func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption) error {
+func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption) (*slog.Logger, error) {
var o logOpts
for _, opt := range opts {
@@ -110,23 +112,28 @@ func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption)
// If debug or trace logging is not enabled for CLIs,
// then discard all log output.
if purpose == LoggingForCLI && level > slog.LevelDebug {
- slog.SetDefault(slog.New(slog.DiscardHandler))
- return nil
+ logger := slog.New(slog.DiscardHandler)
+ slog.SetDefault(logger)
+ return logger, nil
}
var output string
- if o.osLogSubsystem != "" {
+ switch {
+ case o.osLogSubsystem != "":
output = logutils.LogOutputOSLog
+ case purpose == LoggingForMCP:
+ output = logutils.LogOutputMCP
+ o.format = LogFormatJSON
}
- _, _, err := logutils.Initialize(logutils.Config{
+ logger, _, err := logutils.Initialize(logutils.Config{
Severity: level.String(),
Format: o.format,
EnableColors: IsTerminal(os.Stderr),
Output: output,
OSLogSubsystem: o.osLogSubsystem,
})
- return trace.Wrap(err)
+ return logger, trace.Wrap(err)
}
var initTestLoggerOnce = sync.Once{}
diff --git a/lib/utils/listener/memory.go b/lib/utils/listener/memory.go
new file mode 100644
index 0000000000000..5f7795a69952d
--- /dev/null
+++ b/lib/utils/listener/memory.go
@@ -0,0 +1,109 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package listener
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "sync"
+)
+
+// InMemoryListener is a in-memory implementation of a net.Listener.
+type InMemoryListener struct {
+ connCh chan net.Conn
+ closeCh chan struct{}
+ closeOnce sync.Once
+}
+
+// Accept implements net.Listener.
+func (m *InMemoryListener) Accept() (net.Conn, error) {
+ select {
+ case <-m.closeCh:
+ return nil, io.EOF
+ default:
+ }
+
+ for {
+ select {
+ case conn := <-m.connCh:
+ return conn, nil
+ case <-m.closeCh:
+ return nil, io.EOF
+ }
+ }
+}
+
+// Addr implements net.Listener.
+func (m *InMemoryListener) Addr() net.Addr {
+ return defaultMemoryAddr
+}
+
+// Close implements net.Listener.
+func (m *InMemoryListener) Close() error {
+ m.closeOnce.Do(func() { close(m.closeCh) })
+ return nil
+}
+
+// DialContext dials the memory listener, creating a new net.Conn.
+//
+// This function satisfies net.Dialer.DialContext signature.
+func (m *InMemoryListener) DialContext(ctx context.Context, _ string, _ string) (net.Conn, error) {
+ select {
+ case <-m.closeCh:
+ return nil, ErrListenerClosed
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
+ serverConn, clientConn := net.Pipe()
+
+ select {
+ case m.connCh <- serverConn:
+ case <-ctx.Done():
+ // In this case the connection was not accepted in time by the server
+ // and the dial context is done. To avoid having the server using an
+ // orphned connection we should close it.
+ _ = serverConn.Close()
+ return nil, ctx.Err()
+ }
+ return clientConn, nil
+}
+
+// ErrListenerClosed is the error returned by dial when the listener is closed.
+var ErrListenerClosed = errors.New("in-memory listener closed")
+
+// NewInMemoryListener initializes a new in-memory listener.
+func NewInMemoryListener() *InMemoryListener {
+ return &InMemoryListener{
+ connCh: make(chan net.Conn),
+ closeCh: make(chan struct{}),
+ }
+}
+
+var _ net.Listener = (*InMemoryListener)(nil)
+
+type memoryAddr string
+
+func (m memoryAddr) Network() string { return string(m) }
+func (m memoryAddr) String() string { return string(m) }
+
+var defaultMemoryAddr = memoryAddr("memory")
+
+var _ net.Addr = (*memoryAddr)(nil)
diff --git a/lib/utils/listener/memory_test.go b/lib/utils/listener/memory_test.go
new file mode 100644
index 0000000000000..024a96031de16
--- /dev/null
+++ b/lib/utils/listener/memory_test.go
@@ -0,0 +1,132 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package listener
+
+import (
+ "context"
+ "io"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMemoryListenerClient(t *testing.T) {
+ var wg sync.WaitGroup
+ expectedMessage := "hello from server"
+
+ listener := NewInMemoryListener()
+ t.Cleanup(func() { listener.Close() })
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ conn, err := listener.Accept()
+ if err != nil {
+ return
+ }
+ t.Cleanup(func() { conn.Close() })
+
+ _, _ = conn.Write([]byte(expectedMessage))
+ }()
+
+ // To avoid blocking in case the server is not working correctly, wrap
+ // the client connection into a eventually loop.
+ require.EventuallyWithT(t, func(collect *assert.CollectT) {
+ conn, err := listener.DialContext(t.Context(), "", "")
+ require.NoError(collect, err)
+
+ buf := make([]byte, len(expectedMessage))
+ n, err := conn.Read(buf[0:])
+ require.NoError(collect, err)
+ require.Equal(collect, len(expectedMessage), n)
+ require.Equal(collect, expectedMessage, string(buf[:n]))
+ }, 50*time.Millisecond, 10*time.Millisecond)
+
+ require.Eventually(t, func() bool {
+ wg.Wait()
+ return true
+ }, 50*time.Millisecond, 10*time.Millisecond)
+}
+
+func TestMemoryListenerServer(t *testing.T) {
+ var wg sync.WaitGroup
+ expectedMessage := "hello from client"
+
+ listener := NewInMemoryListener()
+ t.Cleanup(func() { listener.Close() })
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ conn, err := listener.DialContext(t.Context(), "", "")
+ if err != nil {
+ return
+ }
+
+ _, _ = conn.Write([]byte(expectedMessage))
+ }()
+
+ // To avoid blocking in case the client is not working correctly, wrap
+ // the server accept connection into a eventually loop.
+ require.EventuallyWithT(t, func(collect *assert.CollectT) {
+ conn, err := listener.Accept()
+ require.NoError(collect, err)
+
+ buf := make([]byte, len(expectedMessage))
+ n, err := conn.Read(buf[0:])
+ require.NoError(collect, err)
+ require.Equal(collect, len(expectedMessage), n)
+ require.Equal(collect, expectedMessage, string(buf[:n]))
+ }, 50*time.Millisecond, 10*time.Millisecond)
+
+ // Close the listener and expect subsequent accept calls to return error.
+ listener.Close()
+ require.EventuallyWithT(t, func(collect *assert.CollectT) {
+ _, err := listener.Accept()
+ require.Error(collect, err)
+ require.ErrorIs(collect, err, io.EOF)
+ }, 50*time.Millisecond, 10*time.Millisecond)
+
+ require.Eventually(t, func() bool {
+ wg.Wait()
+ return true
+ }, 50*time.Millisecond, 10*time.Millisecond)
+}
+
+func TestMemoryListenerDialTimeout(t *testing.T) {
+ listener := NewInMemoryListener()
+ t.Cleanup(func() { listener.Close() })
+
+ ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond)
+ defer cancel()
+
+ require.EventuallyWithT(t, func(collect *assert.CollectT) {
+ _, err := listener.DialContext(ctx, "", "")
+ require.Error(collect, err)
+ require.ErrorIs(collect, err, context.DeadlineExceeded)
+ }, 100*time.Millisecond, 10*time.Millisecond)
+
+ require.Never(t, func() bool {
+ _, _ = listener.Accept()
+ return true
+ }, 50*time.Millisecond, 10*time.Millisecond, "expected server to not have received connections")
+}
diff --git a/lib/utils/log/log.go b/lib/utils/log/log.go
index 3e68086500e29..48f945c12162c 100644
--- a/lib/utils/log/log.go
+++ b/lib/utils/log/log.go
@@ -57,6 +57,10 @@ const (
LogOutputSyslog = "syslog"
// LogOutputOSLog represents os_log, the unified logging system on macOS, as the destination for logs.
LogOutputOSLog = "os_log"
+ // LogOutputMCP defines to where the MCP command logs will be directed to.
+ // The stdout is exclusively used as the MCP server transport, leaving only
+ // stderr available.
+ LogOutputMCP = "stderr"
)
// Initialize configures the default global logger based on the
diff --git a/tool/tsh/common/logger.go b/tool/tsh/common/logger.go
index df3e97595d259..c3450a5acc51f 100644
--- a/tool/tsh/common/logger.go
+++ b/tool/tsh/common/logger.go
@@ -38,7 +38,7 @@ const (
// It is called twice, first soon after launching tsh before argv is parsed and then again after
// kingpin parses argv. This makes it possible to debug early startup functionality, particularly
// command aliases.
-func initLogger(cf *CLIConf, opts loggingOpts) error {
+func initLogger(cf *CLIConf, purpose utils.LoggingPurpose, opts loggingOpts) (*slog.Logger, error) {
cf.OSLog = opts.osLog
cf.Debug = opts.debug || opts.osLog
@@ -49,7 +49,8 @@ func initLogger(cf *CLIConf, opts loggingOpts) error {
level = slog.LevelDebug
}
- return trace.Wrap(utils.InitLogger(utils.LoggingForCLI, level, initLoggerOpts...))
+ logger, err := utils.InitLogger(purpose, level, initLoggerOpts...)
+ return logger, trace.Wrap(err)
}
type loggingOpts struct {
diff --git a/tool/tsh/common/mcp.go b/tool/tsh/common/mcp.go
new file mode 100644
index 0000000000000..60f08d445233d
--- /dev/null
+++ b/tool/tsh/common/mcp.go
@@ -0,0 +1,31 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package common
+
+import "github.com/alecthomas/kingpin/v2"
+
+type mcpCommands struct {
+ dbStart *mcpDBStartCommand
+}
+
+func newMCPCommands(app *kingpin.Application) *mcpCommands {
+ mcp := app.Command("mcp", "View and control proxied MCP servers.")
+ db := mcp.Command("db", "Database access for MCP servers.")
+ return &mcpCommands{
+ dbStart: newMCPDBCommand(db),
+ }
+}
diff --git a/tool/tsh/common/mcp_db.go b/tool/tsh/common/mcp_db.go
new file mode 100644
index 0000000000000..699ae979c3ca5
--- /dev/null
+++ b/tool/tsh/common/mcp_db.go
@@ -0,0 +1,234 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package common
+
+import (
+ "context"
+ "log/slog"
+
+ "github.com/alecthomas/kingpin/v2"
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/lib/client"
+ dbmcp "github.com/gravitational/teleport/lib/client/db/mcp"
+ pgmcp "github.com/gravitational/teleport/lib/client/db/postgres/mcp"
+ "github.com/gravitational/teleport/lib/client/mcp"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/srv/alpnproxy"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+ "github.com/gravitational/teleport/lib/utils/listener"
+)
+
+// mcpDBStartCommand implements `tsh mcp db start` command.
+type mcpDBStartCommand struct {
+ *kingpin.CmdClause
+
+ databaseURIs []string
+}
+
+func newMCPDBCommand(parent *kingpin.CmdClause) *mcpDBStartCommand {
+ cmd := &mcpDBStartCommand{
+ CmdClause: parent.Command("start", "Start a local MCP server for database access").Hidden(),
+ }
+
+ cmd.Arg("uris", "List of database MCP resource URIs that will be served by the server").Required().StringsVar(&cmd.databaseURIs)
+ return cmd
+}
+
+func (c *mcpDBStartCommand) run(cf *CLIConf) error {
+ logger, err := initLogger(cf, utils.LoggingForMCP, parseLoggingOptsFromEnvAndArgv(cf))
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ registry := defaultDBMCPRegistry
+ if cf.databaseMCPRegistryOverride != nil {
+ registry = cf.databaseMCPRegistryOverride
+ }
+
+ tc, err := makeClient(cf)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Avoid any input request on the command execution. This is required,
+ // otherwise the MCP clients will be stuck waiting for a response.
+ tc.NonInteractive = false
+
+ configuredDatabases := map[string]struct{}{}
+ uris := make([]*mcp.ResourceURI, len(c.databaseURIs))
+ for i, rawURI := range c.databaseURIs {
+ uri, err := mcp.ParseResourceURI(rawURI)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ if !uri.IsDatabase() {
+ return trace.BadParameter("%q resource must be a database", rawURI)
+ }
+
+ // TODO(gabrielcorado): support databases from different clusters.
+ if uri.GetClusterName() != tc.SiteName {
+ return trace.BadParameter("Databases must be from the same cluster (%q). %q is from a different cluster.", tc.SiteName, rawURI)
+ }
+
+ if _, ok := configuredDatabases[uri.String()]; ok {
+ return trace.BadParameter("Database %q was configured twice. MCP servers only support serving a database service only once.", uri.String())
+ }
+
+ configuredDatabases[uri.String()] = struct{}{}
+ uris[i] = uri
+ }
+
+ server := dbmcp.NewRootServer(logger)
+ allDatabases, closeLocalProxies, err := c.prepareDatabases(cf, tc, registry, uris, logger, server)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer closeLocalProxies()
+
+ for protocol, newServerFunc := range registry {
+ databases := allDatabases[protocol]
+ if len(databases) == 0 {
+ continue
+ }
+
+ srv, err := newServerFunc(cf.Context, &dbmcp.NewServerConfig{
+ Logger: logger,
+ RootServer: server,
+ Databases: databases,
+ })
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer srv.Close(cf.Context)
+ }
+
+ return trace.Wrap(server.ServeStdio(cf.Context, cf.Stdin(), cf.Stdout()))
+}
+
+// closeLocalProxyFunc function used to close local proxy listeners.
+type closeLocalProxyFunc func() error
+
+// prepareDatabases based on the available MCP servers, initialize the database
+// local proxy and generate the MCP database.
+func (c *mcpDBStartCommand) prepareDatabases(
+ cf *CLIConf,
+ tc *client.TeleportClient,
+ registry dbmcp.Registry,
+ uris []*mcp.ResourceURI,
+ logger *slog.Logger,
+ server *dbmcp.RootServer,
+) (map[string][]*dbmcp.Database, closeLocalProxyFunc, error) {
+ var (
+ ctx = cf.Context
+ dbsPerProtocol = make(map[string][]*dbmcp.Database)
+ closeFuncs []closeLocalProxyFunc
+ )
+
+ for _, uri := range uris {
+ serviceName := uri.GetDatabaseServiceName()
+ dbUser := uri.GetDatabaseUser()
+ dbName := uri.GetDatabaseName()
+
+ route := tlsca.RouteToDatabase{
+ ServiceName: serviceName,
+ Username: dbUser,
+ Database: dbName,
+ }
+
+ info, err := getDatabaseInfo(cf, tc, []tlsca.RouteToDatabase{route})
+ if err != nil {
+ logger.InfoContext(ctx, "failed to retrieve database information", "database", serviceName, "error", err)
+ continue
+ }
+
+ db, err := info.GetDatabase(ctx, tc)
+ if err != nil {
+ logger.InfoContext(ctx, "failed to load database information", "database", serviceName, "error", err)
+ continue
+ }
+
+ if !registry.IsSupported(db.GetProtocol()) {
+ logger.InfoContext(ctx, "database protocol unsupported, skipping it", "database", serviceName, "protocol", db.GetProtocol())
+ continue
+ }
+
+ route.Protocol = db.GetProtocol()
+ cc := client.NewDBCertChecker(tc, route, nil, client.WithTTL(tc.KeyTTL))
+ // This avoids having the middleware to refresh the certificate if there
+ // is a certificate available on disk.
+ cert, err := loadDBCertificate(tc, route.ServiceName)
+ if err == nil {
+ cc.SetCert(cert)
+ }
+
+ listener := listener.NewInMemoryListener()
+ lp, err := alpnproxy.NewLocalProxy(
+ makeBasicLocalProxyConfig(ctx, tc, listener, tc.InsecureSkipVerify),
+ alpnproxy.WithDatabaseProtocol(route.Protocol),
+ alpnproxy.WithMiddleware(cc),
+ alpnproxy.WithClusterCAsIfConnUpgrade(ctx, tc.RootClusterCACertPool),
+ )
+ if err != nil {
+ _ = listener.Close()
+ logger.ErrorContext(ctx, "failed to start local proxy for database, skipping it", "database", db.GetName(), "error", err)
+ continue
+ }
+ go func() {
+ defer lp.Close()
+ if err = lp.Start(ctx); err != nil {
+ logger.WarnContext(ctx, "failed to start local ALPN proxy", "error", err)
+ }
+ }()
+
+ mcpDB := &dbmcp.Database{
+ DB: db,
+ ClusterName: uri.GetClusterName(),
+ DatabaseUser: dbUser,
+ DatabaseName: dbName,
+ Addr: listener.Addr().String(),
+ ExternalErrorRetriever: cc,
+ // Since we're using in-memory listener we don't need to resolve the
+ // address.
+ LookupFunc: func(ctx context.Context, host string) (addrs []string, err error) {
+ return []string{listener.Addr().String()}, nil
+ },
+ DialContextFunc: listener.DialContext,
+ }
+ dbsPerProtocol[db.GetProtocol()] = append(dbsPerProtocol[db.GetProtocol()], mcpDB)
+ server.RegisterDatabase(mcpDB)
+ closeFuncs = append(closeFuncs, listener.Close)
+ }
+
+ return dbsPerProtocol, func() error {
+ var errs []error
+ for _, closeFunc := range closeFuncs {
+ errs = append(errs, closeFunc())
+ }
+
+ return trace.NewAggregate(errs...)
+ }, nil
+}
+
+var (
+ // defaultDBMCPRegistry is the default database access MCP servers registry.
+ defaultDBMCPRegistry = map[string]dbmcp.NewServerFunc{
+ defaults.ProtocolPostgres: pgmcp.NewServer,
+ }
+)
diff --git a/tool/tsh/common/mcp_db_test.go b/tool/tsh/common/mcp_db_test.go
new file mode 100644
index 0000000000000..123657d212841
--- /dev/null
+++ b/tool/tsh/common/mcp_db_test.go
@@ -0,0 +1,223 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package common
+
+import (
+ "context"
+ "io"
+ "testing"
+ "time"
+
+ mcpclient "github.com/mark3labs/mcp-go/client"
+ mcptransport "github.com/mark3labs/mcp-go/client/transport"
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/types"
+ dbmcp "github.com/gravitational/teleport/lib/client/db/mcp"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/service/servicecfg"
+ testserver "github.com/gravitational/teleport/tool/teleport/testenv"
+)
+
+func TestMCPDBCommand(t *testing.T) {
+ tmpHomePath := t.TempDir()
+ connector := mockConnector(t)
+ alice, err := types.NewUser("alice@example.com")
+ require.NoError(t, err)
+ alice.SetDatabaseUsers([]string{"postgres"})
+ alice.SetDatabaseNames([]string{"postgres"})
+ alice.SetRoles([]string{"access"})
+
+ authProcess := testserver.MakeTestServer(
+ t,
+ testserver.WithClusterName(t, "root"),
+ testserver.WithBootstrap(connector, alice),
+ testserver.WithConfig(func(cfg *servicecfg.Config) {
+ cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex)
+ cfg.Databases.Enabled = true
+ cfg.Databases.Databases = []servicecfg.Database{
+ {
+ Name: "postgres1",
+ Protocol: defaults.ProtocolPostgres,
+ URI: "external-pg:5432",
+ },
+ {
+ Name: "postgres2",
+ Protocol: defaults.ProtocolPostgres,
+ URI: "external-pg:5432",
+ },
+ {
+ Name: "mysql-local",
+ Protocol: defaults.ProtocolMySQL,
+ URI: "external-mysql:3306",
+ },
+ }
+ }),
+ )
+
+ authServer := authProcess.GetAuthServer()
+ require.NotNil(t, authServer)
+
+ proxyAddr, err := authProcess.ProxyWebAddr()
+ require.NoError(t, err)
+
+ err = Run(t.Context(), []string{
+ "login", "--insecure", "--debug", "--proxy", proxyAddr.String(),
+ }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, alice, connector.GetName()))
+ require.NoError(t, err)
+
+ stdin, writer := io.Pipe()
+ reader, stdout := io.Pipe()
+
+ ctx, cancel := context.WithCancel(t.Context())
+ defer cancel()
+
+ executionCh := make(chan error)
+ go func() {
+ executionCh <- Run(ctx, []string{
+ "mcp",
+ "db",
+ "start",
+ "teleport://clusters/root/databases/postgres1?dbUser=postgres&dbName=postgres",
+ "teleport://clusters/root/databases/postgres2?dbUser=postgres&dbName=postgres",
+ }, setHomePath(tmpHomePath), func(c *CLIConf) error {
+ c.overrideStdin = stdin
+ c.OverrideStdout = stdout
+ // MCP server logs are going to be discarded.
+ c.overrideStderr = io.Discard
+ c.databaseMCPRegistryOverride = map[string]dbmcp.NewServerFunc{
+ defaults.ProtocolPostgres: func(ctx context.Context, nsc *dbmcp.NewServerConfig) (dbmcp.Server, error) {
+ return &testDatabaseMCP{}, nil
+ },
+ }
+ return nil
+ })
+ }()
+
+ clt := mcpclient.NewClient(mcptransport.NewIO(reader, writer, nil /* logging */))
+ require.NoError(t, clt.Start(t.Context()))
+
+ req := mcp.InitializeRequest{}
+ req.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
+ req.Params.ClientInfo = mcp.Implementation{
+ Name: "test-client",
+ Version: "1.0.0",
+ }
+
+ require.EventuallyWithT(t, func(collect *assert.CollectT) {
+ _, err = clt.Initialize(t.Context(), req)
+ require.NoError(collect, err)
+ require.NoError(collect, clt.Ping(t.Context()))
+ }, time.Second, 100*time.Millisecond)
+
+ // Stop the MCP server command and wait until it is finshed.
+ cancel()
+ select {
+ case err := <-executionCh:
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.Canceled)
+ case <-time.After(10 * time.Second):
+ require.Fail(t, "expected the execution to be completed")
+ }
+}
+
+func TestMCPDBCommandFailures(t *testing.T) {
+ tmpHomePath := t.TempDir()
+ connector := mockConnector(t)
+ alice, err := types.NewUser("alice@example.com")
+ require.NoError(t, err)
+ alice.SetDatabaseUsers([]string{"postgres"})
+ alice.SetDatabaseNames([]string{"postgres"})
+ alice.SetRoles([]string{"access"})
+ clusterName := "root"
+
+ authProcess := testserver.MakeTestServer(
+ t,
+ testserver.WithClusterName(t, clusterName),
+ testserver.WithBootstrap(connector, alice),
+ testserver.WithConfig(func(cfg *servicecfg.Config) {
+ cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex)
+ cfg.Databases.Enabled = true
+ cfg.Databases.Databases = []servicecfg.Database{
+ {
+ Name: "postgres1",
+ Protocol: defaults.ProtocolPostgres,
+ URI: "external-pg:5432",
+ },
+ {
+ Name: "postgres2",
+ Protocol: defaults.ProtocolPostgres,
+ URI: "external-pg:5432",
+ },
+ {
+ Name: "mysql-local",
+ Protocol: defaults.ProtocolMySQL,
+ URI: "external-mysql:3306",
+ },
+ }
+ }),
+ )
+
+ authServer := authProcess.GetAuthServer()
+ require.NotNil(t, authServer)
+
+ proxyAddr, err := authProcess.ProxyWebAddr()
+ require.NoError(t, err)
+
+ withMockedMCPServers := func(c *CLIConf) error {
+ c.databaseMCPRegistryOverride = map[string]dbmcp.NewServerFunc{
+ defaults.ProtocolPostgres: func(ctx context.Context, nsc *dbmcp.NewServerConfig) (dbmcp.Server, error) {
+ return &testDatabaseMCP{}, nil
+ },
+ }
+ return nil
+ }
+
+ err = Run(t.Context(), []string{
+ "login", "--insecure", "--debug", "--proxy", proxyAddr.String(),
+ }, setHomePath(tmpHomePath), setMockSSOLogin(authServer, alice, connector.GetName()))
+ require.NoError(t, err)
+
+ t.Run("different clusters", func(t *testing.T) {
+ err := Run(t.Context(), []string{
+ "mcp",
+ "db",
+ "start",
+ "teleport://clusters/root/databases/postgres1?dbUser=postgres&dbName=postgres",
+ "teleport://clusters/other/databases/postgres2?dbUser=postgres&dbName=postgres",
+ }, setHomePath(tmpHomePath), withMockedMCPServers)
+ require.Error(t, err)
+ })
+
+ t.Run("duplicated databases", func(t *testing.T) {
+ err := Run(t.Context(), []string{
+ "mcp",
+ "db",
+ "start",
+ "teleport://clusters/root/databases/postgres1?dbUser=postgres&dbName=postgres",
+ "teleport://clusters/root/databases/postgres1?dbUser=readonly&dbName=postgres",
+ }, setHomePath(tmpHomePath), withMockedMCPServers)
+ require.Error(t, err)
+ })
+}
+
+// testDatabaseMCP is a noop database MCP server.
+type testDatabaseMCP struct{}
+
+func (s *testDatabaseMCP) Close(_ context.Context) error { return nil }
diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go
index 550811255519b..c6612cd7f1ba6 100644
--- a/tool/tsh/common/tsh.go
+++ b/tool/tsh/common/tsh.go
@@ -79,6 +79,7 @@ import (
benchmarkdb "github.com/gravitational/teleport/lib/benchmark/db"
"github.com/gravitational/teleport/lib/client"
dbprofile "github.com/gravitational/teleport/lib/client/db"
+ dbmcp "github.com/gravitational/teleport/lib/client/db/mcp"
"github.com/gravitational/teleport/lib/client/identityfile"
"github.com/gravitational/teleport/lib/defaults"
dtauthn "github.com/gravitational/teleport/lib/devicetrust/authn"
@@ -618,6 +619,10 @@ type CLIConf struct {
// atomic here is overkill as the CLIConf is generally consumed sequentially. However, occasionally
// we need concurrency safety, such as for [forEachProfileParallel].
clientStoreSet int32
+
+ // databaseMCPRegistryOverride overrides database access MCP servers
+ // registry. used in tests.
+ databaseMCPRegistryOverride dbmcp.Registry
}
// Stdout returns the stdout writer.
@@ -785,7 +790,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
// run early to enable debug logging if env var is set.
// this makes it possible to debug early startup functionality, particularly command aliases.
- if err := initLogger(&cf, parseLoggingOptsFromEnv()); err != nil {
+ if _, err := initLogger(&cf, utils.LoggingForCLI, parseLoggingOptsFromEnv()); err != nil {
printInitLoggerError(err)
}
@@ -1340,6 +1345,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
gitCmd := newGitCommands(app)
pivCmd := newPIVCommands(app)
+ mcpCmd := newMCPCommands(app)
+
if runtime.GOOS == constants.WindowsOS {
bench.Hidden()
}
@@ -1430,7 +1437,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
// Enable debug logging if requested by --debug.
// If TELEPORT_DEBUG was set and --debug/--no-debug was not passed, debug logs were already
// enabled by a prior call to initLogger.
- if err := initLogger(&cf, parseLoggingOptsFromEnvAndArgv(&cf)); err != nil {
+ if _, err := initLogger(&cf, utils.LoggingForCLI, parseLoggingOptsFromEnvAndArgv(&cf)); err != nil {
printInitLoggerError(err)
}
@@ -1745,6 +1752,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
err = gitCmd.clone.run(&cf)
case pivCmd.agent.FullCommand():
err = pivCmd.agent.run(&cf)
+ case mcpCmd.dbStart.FullCommand():
+ err = mcpCmd.dbStart.run(&cf)
default:
// Handle commands that might not be available.
switch {