From 7ff3609390b9da7fbe2e0294c108267e98839783 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Mon, 27 Oct 2025 17:06:26 -0700 Subject: [PATCH] Fix MongoDB connection leak This fixes a topology monitoring connection leak. The leak was triggered by an error during server selection or the initial auth handshake with the MongoDB server. --- integration/db/db_integration_test.go | 53 +++++++++++++++++++-------- integration/db/fixture.go | 13 ++++++- lib/srv/db/mongodb/connect.go | 31 ++++++++++++---- lib/srv/db/mongodb/protocol/errors.go | 25 ++++++++----- lib/srv/db/mongodb/test.go | 37 ++++++++++++++----- 5 files changed, 114 insertions(+), 45 deletions(-) diff --git a/integration/db/db_integration_test.go b/integration/db/db_integration_test.go index 2802c6d4b3615..bee9bed7ddc29 100644 --- a/integration/db/db_integration_test.go +++ b/integration/db/db_integration_test.go @@ -31,6 +31,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/entitlements" @@ -357,9 +358,8 @@ func (p *DatabasePack) testMongoRootCluster(t *testing.T) { // testMongoConnectionCount tests if mongo service releases // resource after a mongo client disconnect. func (p *DatabasePack) testMongoConnectionCount(t *testing.T) { - connectMongoClient := func(t *testing.T) (serverConnectionCount int32) { - // Connect to the database service in root cluster. - client, err := mongodb.MakeTestClient(context.Background(), common.TestClientConfig{ + makeClient := func(t *testing.T, dbUser string) (*mongo.Client, error) { + return mongodb.MakeTestClient(t.Context(), common.TestClientConfig{ AuthClient: p.Root.Cluster.GetSiteAPI(p.Root.Cluster.Secrets.SiteName), AuthServer: p.Root.Cluster.Process.GetAuthServer(), Address: p.Root.Cluster.Web, @@ -368,27 +368,31 @@ func (p *DatabasePack) testMongoConnectionCount(t *testing.T) { RouteToDatabase: tlsca.RouteToDatabase{ ServiceName: p.Root.MongoService.Name, Protocol: p.Root.MongoService.Protocol, - Username: "admin", + Username: dbUser, }, }) + } + + connectMongoClient := func(t *testing.T, dbUser string) (serverConnectionCount int32) { + client, err := makeClient(t, dbUser) require.NoError(t, err) // Execute a query. - _, err = client.Database("test").Collection("test").Find(context.Background(), bson.M{}) + _, err = client.Database("test").Collection("test").Find(t.Context(), bson.M{}) require.NoError(t, err) // Get a server connection count before disconnect. serverConnectionCount = p.Root.mongo.GetActiveConnectionsCount() // Disconnect. - err = client.Disconnect(context.Background()) + err = client.Disconnect(t.Context()) require.NoError(t, err) return serverConnectionCount } // Get connection count while the first client is connected. - initialConnectionCount := connectMongoClient(t) + initialConnectionCount := connectMongoClient(t, "admin") // Check if active connections count is not growing over time when new // clients connect to the mongo server. @@ -396,22 +400,41 @@ func (p *DatabasePack) testMongoConnectionCount(t *testing.T) { for i := 0; i < clientCount; i++ { // Note that connection count per client fluctuates between 6 and 9. // Use InDelta to avoid flaky test. - require.InDelta(t, initialConnectionCount, connectMongoClient(t), 3) + got := connectMongoClient(t, "admin") + require.InDelta(t, initialConnectionCount, got, 3) + } + + client, err := makeClient(t, "nonexistent") + if !assert.Error(t, err) { + require.NoError(t, client.Disconnect(t.Context())) + return } + require.ErrorContains(t, err, "does not exist") // Wait until the server reports no more connections. This usually happens // really quick but wait a little longer just in case. - waitUntilNoConnections := func() bool { - return p.Root.mongo.GetActiveConnectionsCount() == 0 - } - require.Eventually(t, waitUntilNoConnections, 5*time.Second, 100*time.Millisecond) + var activeConns int32 + require.Eventually(t, func() bool { + conns := p.Root.mongo.GetActiveConnectionsCount() + if conns != activeConns { + t.Logf("active connections to MongoDB changed from %d to %d", activeConns, conns) + activeConns = conns + } + return activeConns == 0 + }, 5*time.Second, 100*time.Millisecond) + + start := time.Now() + require.Never(t, func() bool { + return p.Root.mongo.GetActiveConnectionsCount() > 0 + }, time.Second*5, time.Millisecond*100, "no connections should be left open. Found after %v", time.Since(start)) } // testMongoLeafCluster tests a scenario where a user connects // to a Mongo database running in a leaf cluster. func (p *DatabasePack) testMongoLeafCluster(t *testing.T) { + ctx := t.Context() // Connect to the database service in root cluster. - client, err := mongodb.MakeTestClient(context.Background(), common.TestClientConfig{ + client, err := mongodb.MakeTestClient(ctx, common.TestClientConfig{ AuthClient: p.Root.Cluster.GetSiteAPI(p.Root.Cluster.Secrets.SiteName), AuthServer: p.Root.Cluster.Process.GetAuthServer(), Address: p.Root.Cluster.Web, // Connecting via root cluster. @@ -426,11 +449,11 @@ func (p *DatabasePack) testMongoLeafCluster(t *testing.T) { require.NoError(t, err) // Execute a query. - _, err = client.Database("test").Collection("test").Find(context.Background(), bson.M{}) + _, err = client.Database("test").Collection("test").Find(ctx, bson.M{}) require.NoError(t, err) // Disconnect. - err = client.Disconnect(context.Background()) + err = client.Disconnect(ctx) require.NoError(t, err) } diff --git a/integration/db/fixture.go b/integration/db/fixture.go index df2f65a5d2a49..6bcd37ba6cf70 100644 --- a/integration/db/fixture.go +++ b/integration/db/fixture.go @@ -28,6 +28,7 @@ import ( "time" "github.com/google/uuid" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -116,7 +117,7 @@ func (pack *databaseClusterPack) StartDatabaseServices(t *testing.T, clock clock pack.MongoService = servicecfg.Database{ Name: fmt.Sprintf("%s-mongo", pack.name), Protocol: defaults.ProtocolMongoDB, - URI: pack.mongoAddr, + URI: fmt.Sprintf("mongodb://%s/?heartbeatintervalms=500", pack.mongoAddr), } cassandaListener, pack.cassandraAddr = mustListen(t) @@ -174,7 +175,7 @@ func (pack *databaseClusterPack) StartDatabaseServices(t *testing.T, clock clock AuthClient: pack.dbAuthClient, Name: pack.MongoService.Name, Listener: mongoListener, - }) + }, mongodb.TestServerSetFakeUserAuthError("nonexistent", trace.NotFound("user does not exist"))) require.NoError(t, err) go pack.mongo.Serve() t.Cleanup(func() { pack.mongo.Close() }) @@ -339,6 +340,14 @@ func SetupDatabaseTest(t *testing.T, options ...TestOptionFunc) *DatabasePack { // Setup users and roles on both clusters. p.setupUsersAndRoles(t) + // disable health checks to reduce log noise during tests + defaultHCC := services.VirtualDefaultHealthCheckConfigDB() + defaultHCC.GetSpec().GetMatch().Disabled = true + _, err = p.Root.Cluster.Process.GetAuthServer().UpsertHealthCheckConfig(ctx, defaultHCC) + require.NoError(t, err) + _, err = p.Leaf.Cluster.Process.GetAuthServer().UpsertHealthCheckConfig(ctx, defaultHCC) + require.NoError(t, err) + // Update root's certificate authority on leaf to configure role mapping. ca, err := p.Leaf.Cluster.Process.GetAuthServer().GetCertAuthority(ctx, types.CertAuthID{ Type: types.UserCA, diff --git a/lib/srv/db/mongodb/connect.go b/lib/srv/db/mongodb/connect.go index 15e0f81f6b6df..9146c63df60c8 100644 --- a/lib/srv/db/mongodb/connect.go +++ b/lib/srv/db/mongodb/connect.go @@ -69,17 +69,15 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (drive if err != nil { return nil, nil, trace.Wrap(err) } - err = top.Connect() - if err != nil { - return nil, nil, trace.Wrap(err) - } - server, err := top.SelectServer(ctx, selector) - if err != nil { + if err := top.Connect(); err != nil { + e.Log.DebugContext(e.Context, "Failed to connect topology", "error", err) return nil, nil, trace.Wrap(err) } - e.Log.DebugContext(e.Context, "Connecting to cluster.", "topology", top, "server", server) - conn, err := server.Connection(ctx) + conn, err := e.selectServerConn(ctx, top, selector) if err != nil { + if err := top.Disconnect(ctx); err != nil { + e.Log.WarnContext(e.Context, "Failed to close topology", "error", err) + } return nil, nil, trace.Wrap(err) } @@ -94,6 +92,23 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (drive return conn, closeFn, nil } +func (e *Engine) selectServerConn(ctx context.Context, top *topology.Topology, selector description.ServerSelector) (driver.Connection, error) { + e.Log.DebugContext(e.Context, "Selecting server from topology", "topology", top) + server, err := top.SelectServer(ctx, selector) + if err != nil { + e.Log.DebugContext(e.Context, "failed to select server", "error", err) + return nil, trace.Wrap(err) + } + + e.Log.DebugContext(e.Context, "Connecting to server", "server", server) + conn, err := server.Connection(ctx) + if err != nil { + e.Log.DebugContext(e.Context, "failed to connect to server", "error", err) + return nil, trace.Wrap(err) + } + return conn, nil +} + // getTopologyOptions constructs topology options for connecting to a MongoDB server. func (e *Engine) getTopologyOptions(ctx context.Context, sessionCtx *common.Session) (*topology.Config, description.ServerSelector, error) { clientCfg, err := makeClientOptionsFromDatabaseURI(sessionCtx.Database.GetURI()) diff --git a/lib/srv/db/mongodb/protocol/errors.go b/lib/srv/db/mongodb/protocol/errors.go index e7cd94b6ca88e..887f90a33dc23 100644 --- a/lib/srv/db/mongodb/protocol/errors.go +++ b/lib/srv/db/mongodb/protocol/errors.go @@ -28,16 +28,7 @@ import ( // ReplyError sends error wire message to the client. func ReplyError(clientConn net.Conn, replyTo Message, clientErr error) (err error) { - if msgCompressed, ok := replyTo.(*MessageOpCompressed); ok { - replyTo = msgCompressed.GetOriginal() - } - var errMessage Message - switch replyTo.(type) { - case *MessageOpMsg: // When client request is OP_MSG, reply should be OP_MSG as well. - errMessage, err = makeOpMsgError(clientErr) - default: // Send OP_REPLY otherwise. - errMessage, err = makeOpReplyError(clientErr) - } + errMessage, err := MakeErrorMessage(replyTo, clientErr) if err != nil { return trace.Wrap(err) } @@ -54,6 +45,20 @@ func ReplyError(clientConn net.Conn, replyTo Message, clientErr error) (err erro return nil } +// MakeErrorMessage builds an error message as either an OP_REPLY or OP_MSG +// depending on the message being replied to. +func MakeErrorMessage(replyTo Message, clientErr error) (Message, error) { + if msgCompressed, ok := replyTo.(*MessageOpCompressed); ok { + replyTo = msgCompressed.GetOriginal() + } + switch replyTo.(type) { + case *MessageOpMsg: // When client request is OP_MSG, reply should be OP_MSG as well. + return makeOpMsgError(clientErr) + default: // Send OP_REPLY otherwise. + return makeOpReplyError(clientErr) + } +} + // makeOpReplyError builds a OP_REPLY error wire message. func makeOpReplyError(err error) (Message, error) { document, err := bson.Marshal(bson.M{ diff --git a/lib/srv/db/mongodb/test.go b/lib/srv/db/mongodb/test.go index e216dea4f1929..244bc179985d5 100644 --- a/lib/srv/db/mongodb/test.go +++ b/lib/srv/db/mongodb/test.go @@ -98,6 +98,9 @@ type TestServer struct { // saslConversionTracker map to track which SASL mechanism is being used by // the conversion ID. saslConversationTracker sync.Map + + // fakeUserAuthError maps fake errors to specific users during auth handshake + fakeUserAuthError map[string]error } // TestServerOption allows to set test server options. @@ -117,6 +120,15 @@ func TestServerMaxMessageSize(maxMessageSize uint32) TestServerOption { } } +func TestServerSetFakeUserAuthError(user string, fakeErr error) TestServerOption { + return func(ts *TestServer) { + if !strings.HasPrefix(user, "CN=") { + user = "CN=" + user + } + ts.fakeUserAuthError[user] = fakeErr + } +} + // NewTestServer returns a new instance of a test MongoDB server. func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (svr *TestServer, err error) { err = config.CheckAndSetDefaults() @@ -133,6 +145,7 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (sv if err != nil { return nil, trace.Wrap(err) } + tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven log := logtest.With(teleport.ComponentKey, defaults.ProtocolMongoDB, "name", config.Name, ) @@ -147,6 +160,7 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (sv userEventsCh: make(chan UserEvent, 100), users: make(map[string]userWithTracking), }, + fakeUserAuthError: map[string]error{}, } for _, o := range opts { o(server) @@ -186,7 +200,7 @@ func (s *TestServer) Serve() error { // handleConnection receives Mongo wire messages from the client connection // and sends back the response messages. func (s *TestServer) handleConnection(conn net.Conn) error { - release, err := s.trackUserConnection(conn) + username, release, err := s.trackUserConnection(conn) if err != nil { return trace.Wrap(err) } @@ -199,7 +213,7 @@ func (s *TestServer) handleConnection(conn net.Conn) error { if err != nil { return trace.Wrap(err) } - reply, err := s.handleMessage(message) + reply, err := s.handleMessage(username, message) if err != nil { return trace.Wrap(err) } @@ -211,7 +225,7 @@ func (s *TestServer) handleConnection(conn net.Conn) error { } // handleMessage makes response for the provided command received from client. -func (s *TestServer) handleMessage(message protocol.Message) (protocol.Message, error) { +func (s *TestServer) handleMessage(username string, message protocol.Message) (protocol.Message, error) { command, err := message.GetCommand() if err != nil { return nil, trace.Wrap(err) @@ -222,7 +236,7 @@ func (s *TestServer) handleMessage(message protocol.Message) (protocol.Message, case commandHello: return s.handleHello(message) case commandAuth: - return s.handleAuth(message) + return s.handleAuth(username, message) case commandPing: return s.handlePing(message) case commandFind: @@ -252,7 +266,7 @@ func (s *TestServer) handleMessage(message protocol.Message) (protocol.Message, } // handleAuth makes response to the client's "authenticate" command. -func (s *TestServer) handleAuth(message protocol.Message) (protocol.Message, error) { +func (s *TestServer) handleAuth(username string, message protocol.Message) (protocol.Message, error) { // If authentication token is set on the server, it should only use SASL. // This avoid false positives where Teleport uses the wrong authentication // method. @@ -268,6 +282,9 @@ func (s *TestServer) handleAuth(message protocol.Message) (protocol.Message, err if command != commandAuth { return nil, trace.BadParameter("expected authenticate command, got: %s", message) } + if err, ok := s.fakeUserAuthError[username]; ok && err != nil { + return protocol.MakeErrorMessage(message, err) + } authReply, err := makeOKReply() if err != nil { return nil, trace.Wrap(err) @@ -624,16 +641,16 @@ func (t *usersTracker) UserEventsCh() <-chan UserEvent { return t.userEventsCh } -func (t *usersTracker) trackUserConnection(conn net.Conn) (func(), error) { +func (t *usersTracker) trackUserConnection(conn net.Conn) (string, func(), error) { tlsConn, ok := conn.(*tls.Conn) if !ok { - return func() {}, nil + return "", func() {}, nil } if err := tlsConn.Handshake(); err != nil { - return nil, trace.Wrap(err) + return "", nil, trace.Wrap(err) } else if len(tlsConn.ConnectionState().PeerCertificates) == 0 { - return func() {}, nil + return "", func() {}, nil } username := "CN=" + tlsConn.ConnectionState().PeerCertificates[0].Subject.CommonName @@ -644,7 +661,7 @@ func (t *usersTracker) trackUserConnection(conn net.Conn) (func(), error) { user.activeConnections[conn] = struct{}{} } - return func() { + return username, func() { // Untrack per-user active connections. t.usersMu.Lock() defer t.usersMu.Unlock()