Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,14 @@ func withDeniedDBLabels(labels types.Labels) roleOptFn {
}
}

func withClientIdleTimeout(clientIdleTimeout time.Duration) roleOptFn {
return func(role types.Role) {
opts := role.GetOptions()
opts.ClientIdleTimeout = types.NewDuration(clientIdleTimeout)
role.SetOptions(opts)
}
}

// createUserAndRole creates Teleport user and role with specified names
// and allowed database users/names properties.
func (c *testContext) createUserAndRole(ctx context.Context, t testing.TB, userName, roleName string, dbUsers, dbNames []string, roleOpts ...roleOptFn) (types.User, types.Role) {
Expand Down
15 changes: 8 additions & 7 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,14 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro
}
}()
}()

// Wrap a client connection into monitor that auto-terminates
// idle connection and connection with expired cert.
ctx, clientConn, err = s.cfg.ConnectionMonitor.MonitorConn(cancelCtx, sessionCtx.AuthContext, clientConn)
if err != nil {
return trace.Wrap(err)
}

engine, err := s.dispatch(sessionCtx, rec, clientConn)
if err != nil {
return trace.Wrap(err)
Expand All @@ -995,13 +1003,6 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro
}
}()

// Wrap a client connection into monitor that auto-terminates
// idle connection and connection with expired cert.
ctx, clientConn, err = s.cfg.ConnectionMonitor.MonitorConn(cancelCtx, sessionCtx.AuthContext, clientConn)
if err != nil {
return trace.Wrap(err)
}

// TODO(jakule): LoginIP should be required starting from 10.0.
clientIP := sessionCtx.Identity.LoginIP
if clientIP != "" {
Expand Down
52 changes: 52 additions & 0 deletions lib/srv/db/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/services"
)
Expand Down Expand Up @@ -184,6 +185,57 @@ func TestDatabaseServerLimiting(t *testing.T) {
})
}

func TestDatabaseServerAutoDisconnect(t *testing.T) {
const (
user = "bob"
role = "admin"
dbName = "postgres"
dbUser = user
)

ctx := context.Background()
allowDbUsers := []string{types.Wildcard}
allowDbNames := []string{types.Wildcard}

testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres"))

go testCtx.startHandlingConnections()
t.Cleanup(func() {
require.NoError(t, testCtx.Close())
})

const clientIdleTimeout = time.Second * 30

// create user/role with client idle timeout
testCtx.createUserAndRole(ctx, t, user, role, allowDbUsers, allowDbNames, withClientIdleTimeout(clientIdleTimeout))

// connect
pgConn, err := testCtx.postgresClient(ctx, user, "postgres", dbUser, dbName)
require.NoError(t, err)

// immediate query should work
_, err = pgConn.Exec(ctx, "select 1").ReadAll()
require.NoError(t, err)

// advance clock several times, perform query.
// the activity should update the idle activity timer.
for i := 0; i < 10; i++ {
testCtx.clock.Advance(clientIdleTimeout / 2)
_, err = pgConn.Exec(ctx, "select 1").ReadAll()
require.NoErrorf(t, err, "failed on iteration %v", i+1)
}

// advance clock by full idle timeout, expect the client to be disconnected automatically.
testCtx.clock.Advance(clientIdleTimeout)
waitForEvent(t, testCtx, events.ClientDisconnectCode)

// expect failure after timeout.
_, err = pgConn.Exec(ctx, "select 1").ReadAll()
require.Error(t, err)

require.NoError(t, pgConn.Close(ctx))
}

func TestHeartbeatEvents(t *testing.T) {
ctx := context.Background()

Expand Down