Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5585,3 +5585,42 @@ func (tc *TeleportClient) issueMCPCertWithMFA(ctx context.Context, mcpServer typ
cert, err := keyRing.AppTLSCert(mcpServer.GetName())
return cert, trace.Wrap(err)
}

// DialDatabase makes a remote connection to the database.
//
// TODO(gabrielcorado): support acccess requests connections.
func (tc *TeleportClient) DialDatabase(ctx context.Context, route proto.RouteToDatabase) (net.Conn, error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/DialDatabase",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
attribute.String("db", route.GetServiceName()),
attribute.String("protocol", route.GetProtocol()),
),
)
defer span.End()

dbCertParams := ReissueParams{
RouteToCluster: tc.SiteName,
RouteToDatabase: route,
TTL: tc.KeyTTL,
}

alpnProtocol, err := alpncommon.ToALPNProtocol(route.GetProtocol())
if err != nil {
return nil, trace.Wrap(err)
}

keyRing, err := tc.IssueUserCertsWithMFA(ctx, dbCertParams)
if err != nil {
return nil, trace.Wrap(err)
}

cert, err := keyRing.DBTLSCert(route.GetServiceName())
if err != nil {
return nil, trace.Wrap(err)
}

return tc.DialALPN(ctx, cert, alpnProtocol)
}
19 changes: 3 additions & 16 deletions lib/client/db/mcp/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,14 @@ import (

apiclient "github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/lib/client/mcp"
"github.com/gravitational/teleport/lib/utils"
)

// ExtenralErrorRetriever returns an external error that might have happened.
//
// MCP servers don't have knowledge of other processes that might fail during
// their execution, such as authentication failures. This provider can be used
// to give them the necessary context to provide more accurate user messages.
type ExternalErrorRetriever interface {
// RetrieveError retrieves the error if any.
RetrieveError() error
}

// FormatErrorMessage formats the database MCP error messages.
// format.
func FormatErrorMessage(retreiver ExternalErrorRetriever, err error) error {
if retreiver != nil {
err = trace.NewAggregate(retreiver.RetrieveError(), err)
}

func FormatErrorMessage(err error) error {
switch {
case errors.Is(err, apiclient.ErrClientCredentialsHaveExpired):
case errors.Is(err, apiclient.ErrClientCredentialsHaveExpired) || utils.IsCertExpiredError(err):
return trace.BadParameter(ReloginRequiredErrorMessage)
case strings.Contains(err.Error(), "connection reset by peer") || errors.Is(err, io.ErrClosedPipe):
return trace.BadParameter(LocalProxyConnectionErrorMessage)
Expand Down
6 changes: 0 additions & 6 deletions lib/client/db/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,10 @@ type Database struct {
DB types.Database
// ClusterName is the cluster name where the database is located.
ClusterName string
// Addr is the address the MCP server used to create a new database
// connection.
Addr string
// DatabaseUser is the database username used on the connections.
DatabaseUser string
// DatabaseName is the database name used on the connections.
DatabaseName string
// ExternalErrorRetriever used to retrieve any external error that might
// have happened while connecting/communicating with the database.
ExternalErrorRetriever ExternalErrorRetriever
// LookupFunc is the lookup function to resolve database address.
LookupFunc LookupFunc
// DialContextFunc is the dial function used to connect to the database.
Expand Down
2 changes: 1 addition & 1 deletion lib/client/db/mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (s *RootServer) ListDatabases(ctx context.Context, request mcp.CallToolRequ
contents, err := encodeDatabaseResource(db)
if err != nil {
s.logger.ErrorContext(ctx, "error while list databases", "error", err)
return mcp.NewToolResultError(FormatErrorMessage(nil, err).Error()), nil
return mcp.NewToolResultError(FormatErrorMessage(err).Error()), nil
}
res = append(res, mcp.EmbeddedResource{Type: "resource", Resource: contents})
}
Expand Down
1 change: 0 additions & 1 deletion lib/client/db/mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ func buildDatabase(t *testing.T, name string) *Database {
return &Database{
DB: db,
ClusterName: "root",
Addr: "localhost:5555",
}
}

Expand Down
18 changes: 10 additions & 8 deletions lib/client/db/postgres/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,39 +112,39 @@ type RunQueryResult struct {
func (s *Server) RunQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
uri, err := request.RequireString(queryToolDatabaseParam)
if err != nil {
return s.wrapErrorResult(ctx, nil, trace.Wrap(err))
return s.wrapErrorResult(ctx, trace.Wrap(err))
}

sql, err := request.RequireString(queryToolQueryParam)
if err != nil {
return s.wrapErrorResult(ctx, nil, trace.Wrap(err))
return s.wrapErrorResult(ctx, trace.Wrap(err))
}

db, err := s.getDatabase(uri)
if err != nil {
return s.wrapErrorResult(ctx, nil, err)
return s.wrapErrorResult(ctx, err)
}

// TODO(gabrielcorado): ensure the connection used is consistent for the
// session, making most of its queries to be present in a single audit
// session/recording.
rows, err := db.pool.Query(ctx, sql)
if err != nil {
return s.wrapErrorResult(ctx, db.ExternalErrorRetriever, err)
return s.wrapErrorResult(ctx, err)
}

// Returned rows are being closed by this function.
result, err := buildQueryResult(rows)
if err != nil {
return s.wrapErrorResult(ctx, db.ExternalErrorRetriever, err)
return s.wrapErrorResult(ctx, err)
}

return mcp.NewToolResultText(result), nil
}

func (s *Server) wrapErrorResult(ctx context.Context, externalRetriever dbmcp.ExternalErrorRetriever, toolErr error) (*mcp.CallToolResult, error) {
func (s *Server) wrapErrorResult(ctx context.Context, toolErr error) (*mcp.CallToolResult, error) {
s.logger.ErrorContext(ctx, "error while querying database", "error", toolErr)
out, err := json.Marshal(RunQueryResult{ErrorMessage: dbmcp.FormatErrorMessage(externalRetriever, toolErr).Error()})
out, err := json.Marshal(RunQueryResult{ErrorMessage: dbmcp.FormatErrorMessage(toolErr).Error()})
return mcp.NewToolResultError(string(out)), trace.Wrap(err)
}

Expand Down Expand Up @@ -204,7 +204,9 @@ func (s *Server) getDatabase(uri string) (*database, error) {
}

func buildConnConfig(db *dbmcp.Database) (*pgxpool.Config, error) {
config, err := pgxpool.ParseConfig("postgres://" + db.Addr)
// No need to provide a valid address here as the Lookup and DialContext
// will handle the connection.
config, err := pgxpool.ParseConfig("postgres://")
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
31 changes: 11 additions & 20 deletions lib/client/db/postgres/mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"encoding/json"
"log/slog"
"net"
"testing"

"github.com/jackc/pgx/v5"
Expand Down Expand Up @@ -88,10 +89,9 @@ func TestFormatErrors(t *testing.T) {
dbURI := clientmcp.NewDatabaseResourceURI("root", dbName).WithoutParams().String()

for name, tc := range map[string]struct {
databaseURI string
databases []*dbmcp.Database
externalErrorRetriever dbmcp.ExternalErrorRetriever
expectErrorMessage require.ValueAssertionFunc
databaseURI string
databases []*dbmcp.Database
expectErrorMessage require.ValueAssertionFunc
}{
"database not found": {
databaseURI: "teleport://clusters/root/databases/not-found",
Expand All @@ -113,7 +113,6 @@ func TestFormatErrors(t *testing.T) {
ClusterName: "root",
DatabaseUser: "postgres",
DatabaseName: "postgres",
Addr: listener.Addr().String(),
LookupFunc: func(_ context.Context, _ string) (addrs []string, err error) {
return []string{"memory"}, nil
},
Expand All @@ -128,16 +127,16 @@ func TestFormatErrors(t *testing.T) {
databaseURI: dbURI,
databases: []*dbmcp.Database{
&dbmcp.Database{
DB: db,
ClusterName: "root",
DatabaseUser: "postgres",
DatabaseName: "postgres",
Addr: listener.Addr().String(),
ExternalErrorRetriever: &mockErrorRetriever{err: client.ErrClientCredentialsHaveExpired},
DB: db,
ClusterName: "root",
DatabaseUser: "postgres",
DatabaseName: "postgres",
LookupFunc: func(_ context.Context, _ string) (addrs []string, err error) {
return []string{"memory"}, nil
},
DialContextFunc: listener.DialContext,
DialContextFunc: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, client.ErrClientCredentialsHaveExpired
},
},
},
expectErrorMessage: func(tt require.TestingT, i1 any, i2 ...any) {
Expand Down Expand Up @@ -224,11 +223,3 @@ func (mr *mockRows) CommandTag() pgconn.CommandTag {
}

func (mr *mockRows) Close() {}

type mockErrorRetriever struct {
err error
}

func (mr *mockErrorRetriever) RetrieveError() error {
return mr.err
}
20 changes: 0 additions & 20 deletions lib/client/local_proxy_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ type CertChecker struct {

cert tls.Certificate
certMu sync.Mutex

err error
errMu sync.Mutex
}

var _ alpnproxy.LocalProxyMiddleware = (*CertChecker)(nil)
Expand Down Expand Up @@ -149,10 +146,6 @@ func (c *CertChecker) GetOrIssueCert(ctx context.Context) (cert tls.Certificate,
c.certMu.Lock()
defer c.certMu.Unlock()

defer func() {
c.setError(err)
}()

if err := c.checkCert(); err == nil {
return c.cert, nil
}
Expand All @@ -177,13 +170,6 @@ func (c *CertChecker) GetOrIssueCert(ctx context.Context) (cert tls.Certificate,
return c.cert, nil
}

// RetrieveError retrieves the happened on while retrieving certificates.
func (c *CertChecker) RetrieveError() error {
c.errMu.Lock()
defer c.errMu.Unlock()
return c.err
}

func (c *CertChecker) checkCert() error {
leaf, err := utils.TLSCertLeaf(c.cert)
if err != nil {
Expand All @@ -198,12 +184,6 @@ func (c *CertChecker) checkCert() error {
return trace.Wrap(c.certIssuer.CheckCert(leaf))
}

func (c *CertChecker) setError(err error) {
c.errMu.Lock()
defer c.errMu.Unlock()
c.err = err
}

// CertIssuer checks and issues certs.
type CertIssuer interface {
// CheckCert checks that an existing certificate is valid.
Expand Down
6 changes: 0 additions & 6 deletions lib/client/local_proxy_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ func TestCertChecker(t *testing.T) {
// certChecker should issue a new cert on first request.
cert, err := certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
require.NoError(t, certChecker.RetrieveError())

// subsequent calls should return the same cert.
sameCert, err := certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
require.NoError(t, certChecker.RetrieveError())
require.Equal(t, cert, sameCert)

// If the current cert expires it should be reissued.
Expand All @@ -58,7 +56,6 @@ func TestCertChecker(t *testing.T) {

cert, err = certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
require.NoError(t, certChecker.RetrieveError())
require.NotEqual(t, cert, expiredCert)

// If the current cert fails certIssuer checks, a new one should be issued.
Expand All @@ -67,20 +64,17 @@ func TestCertChecker(t *testing.T) {

cert, err = certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
require.NoError(t, certChecker.RetrieveError())
require.NotEqual(t, cert, badCert)

// If issuing a new cert fails, an error is returned.
certIssuer.issueErr = trace.BadParameter("failed to issue cert")
_, err = certChecker.GetOrIssueCert(ctx)
require.ErrorIs(t, err, certIssuer.issueErr, "expected error %v but got %v", certIssuer.issueErr, err)
require.ErrorIs(t, certChecker.RetrieveError(), err, "expected retrieve error to be the same get error but got: %v", certChecker.RetrieveError())

// If the problem is solved, the error is clean up.
certIssuer.issueErr = nil
_, err = certChecker.GetOrIssueCert(ctx)
require.NoError(t, err)
require.NoError(t, certChecker.RetrieveError())
}

func TestLocalCertGenerator(t *testing.T) {
Expand Down
Loading
Loading