diff --git a/lib/client/api.go b/lib/client/api.go index b542015a79152..4d2059d27b433 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -5585,3 +5585,42 @@ func (tc *TeleportClient) issueMCPCertWithMFA(ctx context.Context, mcpServer typ cert, err := keyRing.AppTLSCert(mcpServer.GetName()) return cert, trace.Wrap(err) } + +// DialDatabase makes a remote connection to the database. +// +// TODO(gabrielcorado): support acccess requests connections. +func (tc *TeleportClient) DialDatabase(ctx context.Context, route proto.RouteToDatabase) (net.Conn, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/DialDatabase", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.String("db", route.GetServiceName()), + attribute.String("protocol", route.GetProtocol()), + ), + ) + defer span.End() + + dbCertParams := ReissueParams{ + RouteToCluster: tc.SiteName, + RouteToDatabase: route, + TTL: tc.KeyTTL, + } + + alpnProtocol, err := alpncommon.ToALPNProtocol(route.GetProtocol()) + if err != nil { + return nil, trace.Wrap(err) + } + + keyRing, err := tc.IssueUserCertsWithMFA(ctx, dbCertParams) + if err != nil { + return nil, trace.Wrap(err) + } + + cert, err := keyRing.DBTLSCert(route.GetServiceName()) + if err != nil { + return nil, trace.Wrap(err) + } + + return tc.DialALPN(ctx, cert, alpnProtocol) +} diff --git a/lib/client/db/mcp/errors.go b/lib/client/db/mcp/errors.go index 8ab0112e9332a..19fdc90dbb2ac 100644 --- a/lib/client/db/mcp/errors.go +++ b/lib/client/db/mcp/errors.go @@ -25,27 +25,14 @@ import ( apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/lib/client/mcp" + "github.com/gravitational/teleport/lib/utils" ) -// ExtenralErrorRetriever returns an external error that might have happened. -// -// MCP servers don't have knowledge of other processes that might fail during -// their execution, such as authentication failures. This provider can be used -// to give them the necessary context to provide more accurate user messages. -type ExternalErrorRetriever interface { - // RetrieveError retrieves the error if any. - RetrieveError() error -} - // FormatErrorMessage formats the database MCP error messages. // format. -func FormatErrorMessage(retreiver ExternalErrorRetriever, err error) error { - if retreiver != nil { - err = trace.NewAggregate(retreiver.RetrieveError(), err) - } - +func FormatErrorMessage(err error) error { switch { - case errors.Is(err, apiclient.ErrClientCredentialsHaveExpired): + case errors.Is(err, apiclient.ErrClientCredentialsHaveExpired) || utils.IsCertExpiredError(err): return trace.BadParameter(ReloginRequiredErrorMessage) case strings.Contains(err.Error(), "connection reset by peer") || errors.Is(err, io.ErrClosedPipe): return trace.BadParameter(LocalProxyConnectionErrorMessage) diff --git a/lib/client/db/mcp/mcp.go b/lib/client/db/mcp/mcp.go index 43baf36144747..fdb514860bf5a 100644 --- a/lib/client/db/mcp/mcp.go +++ b/lib/client/db/mcp/mcp.go @@ -66,16 +66,10 @@ type Database struct { DB types.Database // ClusterName is the cluster name where the database is located. ClusterName string - // Addr is the address the MCP server used to create a new database - // connection. - Addr string // DatabaseUser is the database username used on the connections. DatabaseUser string // DatabaseName is the database name used on the connections. DatabaseName string - // ExternalErrorRetriever used to retrieve any external error that might - // have happened while connecting/communicating with the database. - ExternalErrorRetriever ExternalErrorRetriever // LookupFunc is the lookup function to resolve database address. LookupFunc LookupFunc // DialContextFunc is the dial function used to connect to the database. diff --git a/lib/client/db/mcp/server.go b/lib/client/db/mcp/server.go index 19924b9a289c1..e839848a94b9e 100644 --- a/lib/client/db/mcp/server.go +++ b/lib/client/db/mcp/server.go @@ -74,7 +74,7 @@ func (s *RootServer) ListDatabases(ctx context.Context, request mcp.CallToolRequ contents, err := encodeDatabaseResource(db) if err != nil { s.logger.ErrorContext(ctx, "error while list databases", "error", err) - return mcp.NewToolResultError(FormatErrorMessage(nil, err).Error()), nil + return mcp.NewToolResultError(FormatErrorMessage(err).Error()), nil } res = append(res, mcp.EmbeddedResource{Type: "resource", Resource: contents}) } diff --git a/lib/client/db/mcp/server_test.go b/lib/client/db/mcp/server_test.go index f581f7462b1d1..0a207bf24037f 100644 --- a/lib/client/db/mcp/server_test.go +++ b/lib/client/db/mcp/server_test.go @@ -150,7 +150,6 @@ func buildDatabase(t *testing.T, name string) *Database { return &Database{ DB: db, ClusterName: "root", - Addr: "localhost:5555", } } diff --git a/lib/client/db/postgres/mcp/mcp.go b/lib/client/db/postgres/mcp/mcp.go index 1624d11ba4a76..4e2004df9eb7f 100644 --- a/lib/client/db/postgres/mcp/mcp.go +++ b/lib/client/db/postgres/mcp/mcp.go @@ -112,17 +112,17 @@ type RunQueryResult struct { func (s *Server) RunQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { uri, err := request.RequireString(queryToolDatabaseParam) if err != nil { - return s.wrapErrorResult(ctx, nil, trace.Wrap(err)) + return s.wrapErrorResult(ctx, trace.Wrap(err)) } sql, err := request.RequireString(queryToolQueryParam) if err != nil { - return s.wrapErrorResult(ctx, nil, trace.Wrap(err)) + return s.wrapErrorResult(ctx, trace.Wrap(err)) } db, err := s.getDatabase(uri) if err != nil { - return s.wrapErrorResult(ctx, nil, err) + return s.wrapErrorResult(ctx, err) } // TODO(gabrielcorado): ensure the connection used is consistent for the @@ -130,21 +130,21 @@ func (s *Server) RunQuery(ctx context.Context, request mcp.CallToolRequest) (*mc // session/recording. rows, err := db.pool.Query(ctx, sql) if err != nil { - return s.wrapErrorResult(ctx, db.ExternalErrorRetriever, err) + return s.wrapErrorResult(ctx, err) } // Returned rows are being closed by this function. result, err := buildQueryResult(rows) if err != nil { - return s.wrapErrorResult(ctx, db.ExternalErrorRetriever, err) + return s.wrapErrorResult(ctx, err) } return mcp.NewToolResultText(result), nil } -func (s *Server) wrapErrorResult(ctx context.Context, externalRetriever dbmcp.ExternalErrorRetriever, toolErr error) (*mcp.CallToolResult, error) { +func (s *Server) wrapErrorResult(ctx context.Context, toolErr error) (*mcp.CallToolResult, error) { s.logger.ErrorContext(ctx, "error while querying database", "error", toolErr) - out, err := json.Marshal(RunQueryResult{ErrorMessage: dbmcp.FormatErrorMessage(externalRetriever, toolErr).Error()}) + out, err := json.Marshal(RunQueryResult{ErrorMessage: dbmcp.FormatErrorMessage(toolErr).Error()}) return mcp.NewToolResultError(string(out)), trace.Wrap(err) } @@ -204,7 +204,9 @@ func (s *Server) getDatabase(uri string) (*database, error) { } func buildConnConfig(db *dbmcp.Database) (*pgxpool.Config, error) { - config, err := pgxpool.ParseConfig("postgres://" + db.Addr) + // No need to provide a valid address here as the Lookup and DialContext + // will handle the connection. + config, err := pgxpool.ParseConfig("postgres://") if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/client/db/postgres/mcp/mcp_test.go b/lib/client/db/postgres/mcp/mcp_test.go index db959c2ace1fb..7e018d6ef5917 100644 --- a/lib/client/db/postgres/mcp/mcp_test.go +++ b/lib/client/db/postgres/mcp/mcp_test.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "log/slog" + "net" "testing" "github.com/jackc/pgx/v5" @@ -88,10 +89,9 @@ func TestFormatErrors(t *testing.T) { dbURI := clientmcp.NewDatabaseResourceURI("root", dbName).WithoutParams().String() for name, tc := range map[string]struct { - databaseURI string - databases []*dbmcp.Database - externalErrorRetriever dbmcp.ExternalErrorRetriever - expectErrorMessage require.ValueAssertionFunc + databaseURI string + databases []*dbmcp.Database + expectErrorMessage require.ValueAssertionFunc }{ "database not found": { databaseURI: "teleport://clusters/root/databases/not-found", @@ -113,7 +113,6 @@ func TestFormatErrors(t *testing.T) { ClusterName: "root", DatabaseUser: "postgres", DatabaseName: "postgres", - Addr: listener.Addr().String(), LookupFunc: func(_ context.Context, _ string) (addrs []string, err error) { return []string{"memory"}, nil }, @@ -128,16 +127,16 @@ func TestFormatErrors(t *testing.T) { databaseURI: dbURI, databases: []*dbmcp.Database{ &dbmcp.Database{ - DB: db, - ClusterName: "root", - DatabaseUser: "postgres", - DatabaseName: "postgres", - Addr: listener.Addr().String(), - ExternalErrorRetriever: &mockErrorRetriever{err: client.ErrClientCredentialsHaveExpired}, + DB: db, + ClusterName: "root", + DatabaseUser: "postgres", + DatabaseName: "postgres", LookupFunc: func(_ context.Context, _ string) (addrs []string, err error) { return []string{"memory"}, nil }, - DialContextFunc: listener.DialContext, + DialContextFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, client.ErrClientCredentialsHaveExpired + }, }, }, expectErrorMessage: func(tt require.TestingT, i1 any, i2 ...any) { @@ -224,11 +223,3 @@ func (mr *mockRows) CommandTag() pgconn.CommandTag { } func (mr *mockRows) Close() {} - -type mockErrorRetriever struct { - err error -} - -func (mr *mockErrorRetriever) RetrieveError() error { - return mr.err -} diff --git a/lib/client/local_proxy_middleware.go b/lib/client/local_proxy_middleware.go index 9f4fd19670d67..82341b50cf49c 100644 --- a/lib/client/local_proxy_middleware.go +++ b/lib/client/local_proxy_middleware.go @@ -52,9 +52,6 @@ type CertChecker struct { cert tls.Certificate certMu sync.Mutex - - err error - errMu sync.Mutex } var _ alpnproxy.LocalProxyMiddleware = (*CertChecker)(nil) @@ -149,10 +146,6 @@ func (c *CertChecker) GetOrIssueCert(ctx context.Context) (cert tls.Certificate, c.certMu.Lock() defer c.certMu.Unlock() - defer func() { - c.setError(err) - }() - if err := c.checkCert(); err == nil { return c.cert, nil } @@ -177,13 +170,6 @@ func (c *CertChecker) GetOrIssueCert(ctx context.Context) (cert tls.Certificate, return c.cert, nil } -// RetrieveError retrieves the happened on while retrieving certificates. -func (c *CertChecker) RetrieveError() error { - c.errMu.Lock() - defer c.errMu.Unlock() - return c.err -} - func (c *CertChecker) checkCert() error { leaf, err := utils.TLSCertLeaf(c.cert) if err != nil { @@ -198,12 +184,6 @@ func (c *CertChecker) checkCert() error { return trace.Wrap(c.certIssuer.CheckCert(leaf)) } -func (c *CertChecker) setError(err error) { - c.errMu.Lock() - defer c.errMu.Unlock() - c.err = err -} - // CertIssuer checks and issues certs. type CertIssuer interface { // CheckCert checks that an existing certificate is valid. diff --git a/lib/client/local_proxy_middleware_test.go b/lib/client/local_proxy_middleware_test.go index 48d99c859c39d..2a629609f7040 100644 --- a/lib/client/local_proxy_middleware_test.go +++ b/lib/client/local_proxy_middleware_test.go @@ -44,12 +44,10 @@ func TestCertChecker(t *testing.T) { // certChecker should issue a new cert on first request. cert, err := certChecker.GetOrIssueCert(ctx) require.NoError(t, err) - require.NoError(t, certChecker.RetrieveError()) // subsequent calls should return the same cert. sameCert, err := certChecker.GetOrIssueCert(ctx) require.NoError(t, err) - require.NoError(t, certChecker.RetrieveError()) require.Equal(t, cert, sameCert) // If the current cert expires it should be reissued. @@ -58,7 +56,6 @@ func TestCertChecker(t *testing.T) { cert, err = certChecker.GetOrIssueCert(ctx) require.NoError(t, err) - require.NoError(t, certChecker.RetrieveError()) require.NotEqual(t, cert, expiredCert) // If the current cert fails certIssuer checks, a new one should be issued. @@ -67,20 +64,17 @@ func TestCertChecker(t *testing.T) { cert, err = certChecker.GetOrIssueCert(ctx) require.NoError(t, err) - require.NoError(t, certChecker.RetrieveError()) require.NotEqual(t, cert, badCert) // If issuing a new cert fails, an error is returned. certIssuer.issueErr = trace.BadParameter("failed to issue cert") _, err = certChecker.GetOrIssueCert(ctx) require.ErrorIs(t, err, certIssuer.issueErr, "expected error %v but got %v", certIssuer.issueErr, err) - require.ErrorIs(t, certChecker.RetrieveError(), err, "expected retrieve error to be the same get error but got: %v", certChecker.RetrieveError()) // If the problem is solved, the error is clean up. certIssuer.issueErr = nil _, err = certChecker.GetOrIssueCert(ctx) require.NoError(t, err) - require.NoError(t, certChecker.RetrieveError()) } func TestLocalCertGenerator(t *testing.T) { diff --git a/tool/tsh/common/mcp_db.go b/tool/tsh/common/mcp_db.go index a47bf0eeaaad3..aed7ba469d2cf 100644 --- a/tool/tsh/common/mcp_db.go +++ b/tool/tsh/common/mcp_db.go @@ -21,6 +21,7 @@ import ( "fmt" "log/slog" "maps" + "net" "text/template" "github.com/alecthomas/kingpin/v2" @@ -35,10 +36,8 @@ import ( "github.com/gravitational/teleport/lib/client/mcp" "github.com/gravitational/teleport/lib/client/mcp/claude" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/teleport/lib/utils/listener" ) // mcpDBStartCommand implements `tsh mcp db start` command. @@ -105,11 +104,10 @@ func (c *mcpDBStartCommand) run() error { } server := dbmcp.NewRootServer(logger) - allDatabases, closeLocalProxies, err := c.prepareDatabases(c.cf, tc, registry, uris, logger, server) + allDatabases, err := c.prepareDatabases(c.cf, tc, registry, uris, logger, server) if err != nil { return trace.Wrap(err) } - defer closeLocalProxies() for protocol, newServerFunc := range registry { databases := allDatabases[protocol] @@ -131,9 +129,6 @@ func (c *mcpDBStartCommand) run() error { return trace.Wrap(server.ServeStdio(c.cf.Context, c.cf.Stdin(), c.cf.Stdout())) } -// closeLocalProxyFunc function used to close local proxy listeners. -type closeLocalProxyFunc func() error - // prepareDatabases based on the available MCP servers, initialize the database // local proxy and generate the MCP database. func (c *mcpDBStartCommand) prepareDatabases( @@ -143,11 +138,10 @@ func (c *mcpDBStartCommand) prepareDatabases( uris []*mcp.ResourceURI, logger *slog.Logger, server *dbmcp.RootServer, -) (map[string][]*dbmcp.Database, closeLocalProxyFunc, error) { +) (map[string][]*dbmcp.Database, error) { var ( ctx = cf.Context dbsPerProtocol = make(map[string][]*dbmcp.Database) - closeFuncs []closeLocalProxyFunc ) for _, uri := range uris { @@ -178,61 +172,31 @@ func (c *mcpDBStartCommand) prepareDatabases( continue } - route.Protocol = db.GetProtocol() - cc := client.NewDBCertChecker(tc, route, nil, client.WithTTL(tc.KeyTTL)) - // This avoids having the middleware to refresh the certificate if there - // is a certificate available on disk. - cert, err := loadDBCertificate(tc, route.ServiceName) - if err == nil { - cc.SetCert(cert) - } - - listener := listener.NewInMemoryListener() - lp, err := alpnproxy.NewLocalProxy( - makeBasicLocalProxyConfig(ctx, tc, listener, tc.InsecureSkipVerify), - alpnproxy.WithDatabaseProtocol(route.Protocol), - alpnproxy.WithMiddleware(cc), - alpnproxy.WithClusterCAsIfConnUpgrade(ctx, tc.RootClusterCACertPool), - ) - if err != nil { - _ = listener.Close() - logger.ErrorContext(ctx, "failed to start local proxy for database, skipping it", "database", db.GetName(), "error", err) - continue - } - go func() { - defer lp.Close() - if err = lp.Start(ctx); err != nil { - logger.WarnContext(ctx, "failed to start local ALPN proxy", "error", err) - } - }() - mcpDB := &dbmcp.Database{ - DB: db, - ClusterName: uri.GetClusterName(), - DatabaseUser: dbUser, - DatabaseName: dbName, - Addr: listener.Addr().String(), - ExternalErrorRetriever: cc, - // Since we're using in-memory listener we don't need to resolve the - // address. + DB: db, + ClusterName: uri.GetClusterName(), + DatabaseUser: dbUser, + DatabaseName: dbName, + // Connections are always handled by the TeleportClient, so here we + // just need to return a placeholder. LookupFunc: func(ctx context.Context, host string) (addrs []string, err error) { - return []string{listener.Addr().String()}, nil + return []string{host}, nil + }, + DialContextFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := tc.DialDatabase(ctx, proto.RouteToDatabase{ + ServiceName: db.GetName(), + Protocol: db.GetProtocol(), + Username: dbUser, + Database: dbName, + }) + return conn, trace.Wrap(err) }, - DialContextFunc: listener.DialContext, } dbsPerProtocol[db.GetProtocol()] = append(dbsPerProtocol[db.GetProtocol()], mcpDB) server.RegisterDatabase(mcpDB) - closeFuncs = append(closeFuncs, listener.Close) } - return dbsPerProtocol, func() error { - var errs []error - for _, closeFunc := range closeFuncs { - errs = append(errs, closeFunc()) - } - - return trace.NewAggregate(errs...) - }, nil + return dbsPerProtocol, nil } // databasesGetter is the interface used to retrieve available