Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions integration/db/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/entitlements"
Expand Down Expand Up @@ -357,9 +358,8 @@ func (p *DatabasePack) testMongoRootCluster(t *testing.T) {
// testMongoConnectionCount tests if mongo service releases
// resource after a mongo client disconnect.
func (p *DatabasePack) testMongoConnectionCount(t *testing.T) {
connectMongoClient := func(t *testing.T) (serverConnectionCount int32) {
// Connect to the database service in root cluster.
client, err := mongodb.MakeTestClient(context.Background(), common.TestClientConfig{
makeClient := func(t *testing.T, dbUser string) (*mongo.Client, error) {
return mongodb.MakeTestClient(t.Context(), common.TestClientConfig{
AuthClient: p.Root.Cluster.GetSiteAPI(p.Root.Cluster.Secrets.SiteName),
AuthServer: p.Root.Cluster.Process.GetAuthServer(),
Address: p.Root.Cluster.Web,
Expand All @@ -368,50 +368,73 @@ func (p *DatabasePack) testMongoConnectionCount(t *testing.T) {
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: p.Root.MongoService.Name,
Protocol: p.Root.MongoService.Protocol,
Username: "admin",
Username: dbUser,
},
})
}

connectMongoClient := func(t *testing.T, dbUser string) (serverConnectionCount int32) {
client, err := makeClient(t, dbUser)
require.NoError(t, err)

// Execute a query.
_, err = client.Database("test").Collection("test").Find(context.Background(), bson.M{})
_, err = client.Database("test").Collection("test").Find(t.Context(), bson.M{})
require.NoError(t, err)

// Get a server connection count before disconnect.
serverConnectionCount = p.Root.mongo.GetActiveConnectionsCount()

// Disconnect.
err = client.Disconnect(context.Background())
err = client.Disconnect(t.Context())
require.NoError(t, err)

return serverConnectionCount
}

// Get connection count while the first client is connected.
initialConnectionCount := connectMongoClient(t)
initialConnectionCount := connectMongoClient(t, "admin")

// Check if active connections count is not growing over time when new
// clients connect to the mongo server.
clientCount := 8
for range clientCount {
// Note that connection count per client fluctuates between 6 and 9.
// Use InDelta to avoid flaky test.
require.InDelta(t, initialConnectionCount, connectMongoClient(t), 3)
got := connectMongoClient(t, "admin")
require.InDelta(t, initialConnectionCount, got, 3)
}

client, err := makeClient(t, "nonexistent")
if !assert.Error(t, err) {
require.NoError(t, client.Disconnect(t.Context()))
return
}
require.ErrorContains(t, err, "does not exist")

// Wait until the server reports no more connections. This usually happens
// really quick but wait a little longer just in case.
waitUntilNoConnections := func() bool {
return p.Root.mongo.GetActiveConnectionsCount() == 0
}
require.Eventually(t, waitUntilNoConnections, 5*time.Second, 100*time.Millisecond)
var activeConns int32
require.Eventually(t, func() bool {
conns := p.Root.mongo.GetActiveConnectionsCount()
if conns != activeConns {
t.Logf("active connections to MongoDB changed from %d to %d", activeConns, conns)
Comment thread
greedy52 marked this conversation as resolved.
activeConns = conns
}
return activeConns == 0
}, 5*time.Second, 100*time.Millisecond)

start := time.Now()
require.Never(t, func() bool {
return p.Root.mongo.GetActiveConnectionsCount() > 0
}, time.Second*5, time.Millisecond*100, "no connections should be left open. Found after %v", time.Since(start))
}

// testMongoLeafCluster tests a scenario where a user connects
// to a Mongo database running in a leaf cluster.
func (p *DatabasePack) testMongoLeafCluster(t *testing.T) {
ctx := t.Context()
// Connect to the database service in root cluster.
client, err := mongodb.MakeTestClient(context.Background(), common.TestClientConfig{
client, err := mongodb.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: p.Root.Cluster.GetSiteAPI(p.Root.Cluster.Secrets.SiteName),
AuthServer: p.Root.Cluster.Process.GetAuthServer(),
Address: p.Root.Cluster.Web, // Connecting via root cluster.
Expand All @@ -426,11 +449,11 @@ func (p *DatabasePack) testMongoLeafCluster(t *testing.T) {
require.NoError(t, err)

// Execute a query.
_, err = client.Database("test").Collection("test").Find(context.Background(), bson.M{})
_, err = client.Database("test").Collection("test").Find(ctx, bson.M{})
require.NoError(t, err)

// Disconnect.
err = client.Disconnect(context.Background())
err = client.Disconnect(ctx)
require.NoError(t, err)
}

Expand Down
13 changes: 11 additions & 2 deletions integration/db/fixture.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -116,7 +117,7 @@ func (pack *databaseClusterPack) StartDatabaseServices(t *testing.T, clock clock
pack.MongoService = servicecfg.Database{
Name: fmt.Sprintf("%s-mongo", pack.name),
Protocol: defaults.ProtocolMongoDB,
URI: pack.mongoAddr,
URI: fmt.Sprintf("mongodb://%s/?heartbeatintervalms=500", pack.mongoAddr),
}

cassandaListener, pack.cassandraAddr = mustListen(t)
Expand Down Expand Up @@ -174,7 +175,7 @@ func (pack *databaseClusterPack) StartDatabaseServices(t *testing.T, clock clock
AuthClient: pack.dbAuthClient,
Name: pack.MongoService.Name,
Listener: mongoListener,
})
}, mongodb.TestServerSetFakeUserAuthError("nonexistent", trace.NotFound("user does not exist")))
require.NoError(t, err)
go pack.mongo.Serve()
t.Cleanup(func() { pack.mongo.Close() })
Expand Down Expand Up @@ -339,6 +340,14 @@ func SetupDatabaseTest(t *testing.T, options ...TestOptionFunc) *DatabasePack {
// Setup users and roles on both clusters.
p.setupUsersAndRoles(t)

// disable health checks to reduce log noise during tests
defaultHCC := services.VirtualDefaultHealthCheckConfigDB()
defaultHCC.GetSpec().GetMatch().Disabled = true
_, err = p.Root.Cluster.Process.GetAuthServer().UpsertHealthCheckConfig(ctx, defaultHCC)
require.NoError(t, err)
_, err = p.Leaf.Cluster.Process.GetAuthServer().UpsertHealthCheckConfig(ctx, defaultHCC)
require.NoError(t, err)

// Update root's certificate authority on leaf to configure role mapping.
ca, err := p.Leaf.Cluster.Process.GetAuthServer().GetCertAuthority(ctx, types.CertAuthID{
Type: types.UserCA,
Expand Down
31 changes: 23 additions & 8 deletions lib/srv/db/mongodb/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,15 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (drive
if err != nil {
return nil, nil, trace.Wrap(err)
}
err = top.Connect()
if err != nil {
return nil, nil, trace.Wrap(err)
}
server, err := top.SelectServer(ctx, selector)
if err != nil {
if err := top.Connect(); err != nil {
e.Log.DebugContext(e.Context, "Failed to connect topology", "error", err)
return nil, nil, trace.Wrap(err)
}
e.Log.DebugContext(e.Context, "Connecting to cluster.", "topology", top, "server", server)
conn, err := server.Connection(ctx)
conn, err := e.selectServerConn(ctx, top, selector)
if err != nil {
if err := top.Disconnect(ctx); err != nil {
e.Log.WarnContext(e.Context, "Failed to close topology", "error", err)
}
Comment on lines +78 to +80
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the line that actually fixes the leak.
We just needed to disconnect from topology monitoring on errors from server selection or auth handshake.

return nil, nil, trace.Wrap(err)
}

Expand All @@ -94,6 +92,23 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (drive
return conn, closeFn, nil
}

func (e *Engine) selectServerConn(ctx context.Context, top *topology.Topology, selector description.ServerSelector) (driver.Connection, error) {
e.Log.DebugContext(e.Context, "Selecting server from topology", "topology", top)
server, err := top.SelectServer(ctx, selector)
if err != nil {
e.Log.DebugContext(e.Context, "failed to select server", "error", err)
return nil, trace.Wrap(err)
}

e.Log.DebugContext(e.Context, "Connecting to server", "server", server)
conn, err := server.Connection(ctx)
if err != nil {
e.Log.DebugContext(e.Context, "failed to connect to server", "error", err)
return nil, trace.Wrap(err)
}
return conn, nil
}

// getTopologyOptions constructs topology options for connecting to a MongoDB server.
func (e *Engine) getTopologyOptions(ctx context.Context, sessionCtx *common.Session) (*topology.Config, description.ServerSelector, error) {
clientCfg, err := makeClientOptionsFromDatabaseURI(sessionCtx.Database.GetURI())
Expand Down
25 changes: 15 additions & 10 deletions lib/srv/db/mongodb/protocol/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,7 @@ import (

// ReplyError sends error wire message to the client.
func ReplyError(clientConn net.Conn, replyTo Message, clientErr error) (err error) {
if msgCompressed, ok := replyTo.(*MessageOpCompressed); ok {
replyTo = msgCompressed.GetOriginal()
}
var errMessage Message
switch replyTo.(type) {
case *MessageOpMsg: // When client request is OP_MSG, reply should be OP_MSG as well.
errMessage, err = makeOpMsgError(clientErr)
default: // Send OP_REPLY otherwise.
errMessage, err = makeOpReplyError(clientErr)
}
errMessage, err := MakeErrorMessage(replyTo, clientErr)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -54,6 +45,20 @@ func ReplyError(clientConn net.Conn, replyTo Message, clientErr error) (err erro
return nil
}

// MakeErrorMessage builds an error message as either an OP_REPLY or OP_MSG
// depending on the message being replied to.
func MakeErrorMessage(replyTo Message, clientErr error) (Message, error) {
if msgCompressed, ok := replyTo.(*MessageOpCompressed); ok {
replyTo = msgCompressed.GetOriginal()
}
switch replyTo.(type) {
case *MessageOpMsg: // When client request is OP_MSG, reply should be OP_MSG as well.
return makeOpMsgError(clientErr)
default: // Send OP_REPLY otherwise.
return makeOpReplyError(clientErr)
}
}

// makeOpReplyError builds a OP_REPLY error wire message.
func makeOpReplyError(err error) (Message, error) {
document, err := bson.Marshal(bson.M{
Expand Down
37 changes: 27 additions & 10 deletions lib/srv/db/mongodb/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ type TestServer struct {
// saslConversionTracker map to track which SASL mechanism is being used by
// the conversion ID.
saslConversationTracker sync.Map

// fakeUserAuthError maps fake errors to specific users during auth handshake
fakeUserAuthError map[string]error
}

// TestServerOption allows to set test server options.
Expand All @@ -117,6 +120,15 @@ func TestServerMaxMessageSize(maxMessageSize uint32) TestServerOption {
}
}

func TestServerSetFakeUserAuthError(user string, fakeErr error) TestServerOption {
return func(ts *TestServer) {
if !strings.HasPrefix(user, "CN=") {
user = "CN=" + user
}
ts.fakeUserAuthError[user] = fakeErr
}
}

// NewTestServer returns a new instance of a test MongoDB server.
func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (svr *TestServer, err error) {
err = config.CheckAndSetDefaults()
Expand All @@ -133,6 +145,7 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (sv
if err != nil {
return nil, trace.Wrap(err)
}
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
log := logtest.With(teleport.ComponentKey, defaults.ProtocolMongoDB,
"name", config.Name,
)
Expand All @@ -147,6 +160,7 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (sv
userEventsCh: make(chan UserEvent, 100),
users: make(map[string]userWithTracking),
},
fakeUserAuthError: map[string]error{},
}
for _, o := range opts {
o(server)
Expand Down Expand Up @@ -186,7 +200,7 @@ func (s *TestServer) Serve() error {
// handleConnection receives Mongo wire messages from the client connection
// and sends back the response messages.
func (s *TestServer) handleConnection(conn net.Conn) error {
release, err := s.trackUserConnection(conn)
username, release, err := s.trackUserConnection(conn)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -199,7 +213,7 @@ func (s *TestServer) handleConnection(conn net.Conn) error {
if err != nil {
return trace.Wrap(err)
}
reply, err := s.handleMessage(message)
reply, err := s.handleMessage(username, message)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -211,7 +225,7 @@ func (s *TestServer) handleConnection(conn net.Conn) error {
}

// handleMessage makes response for the provided command received from client.
func (s *TestServer) handleMessage(message protocol.Message) (protocol.Message, error) {
func (s *TestServer) handleMessage(username string, message protocol.Message) (protocol.Message, error) {
command, err := message.GetCommand()
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -222,7 +236,7 @@ func (s *TestServer) handleMessage(message protocol.Message) (protocol.Message,
case commandHello:
return s.handleHello(message)
case commandAuth:
return s.handleAuth(message)
return s.handleAuth(username, message)
case commandPing:
return s.handlePing(message)
case commandFind:
Expand Down Expand Up @@ -252,7 +266,7 @@ func (s *TestServer) handleMessage(message protocol.Message) (protocol.Message,
}

// handleAuth makes response to the client's "authenticate" command.
func (s *TestServer) handleAuth(message protocol.Message) (protocol.Message, error) {
func (s *TestServer) handleAuth(username string, message protocol.Message) (protocol.Message, error) {
// If authentication token is set on the server, it should only use SASL.
// This avoid false positives where Teleport uses the wrong authentication
// method.
Expand All @@ -268,6 +282,9 @@ func (s *TestServer) handleAuth(message protocol.Message) (protocol.Message, err
if command != commandAuth {
return nil, trace.BadParameter("expected authenticate command, got: %s", message)
}
if err, ok := s.fakeUserAuthError[username]; ok && err != nil {
return protocol.MakeErrorMessage(message, err)
}
authReply, err := makeOKReply()
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -624,16 +641,16 @@ func (t *usersTracker) UserEventsCh() <-chan UserEvent {
return t.userEventsCh
}

func (t *usersTracker) trackUserConnection(conn net.Conn) (func(), error) {
func (t *usersTracker) trackUserConnection(conn net.Conn) (string, func(), error) {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return func() {}, nil
return "", func() {}, nil
}

if err := tlsConn.Handshake(); err != nil {
return nil, trace.Wrap(err)
return "", nil, trace.Wrap(err)
} else if len(tlsConn.ConnectionState().PeerCertificates) == 0 {
return func() {}, nil
return "", func() {}, nil
}

username := "CN=" + tlsConn.ConnectionState().PeerCertificates[0].Subject.CommonName
Expand All @@ -644,7 +661,7 @@ func (t *usersTracker) trackUserConnection(conn net.Conn) (func(), error) {
user.activeConnections[conn] = struct{}{}
}

return func() {
return username, func() {
// Untrack per-user active connections.
t.usersMu.Lock()
defer t.usersMu.Unlock()
Expand Down
Loading