Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Capture postgres extended protocol queries in audit log #6303

Merged
merged 1 commit into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions api/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ const (

// DarwinOS is the GOOS constant for Apple macOS/darwin.
DarwinOS = "darwin"

// UseOfClosedNetworkConnection is a special string some parts of
// go standard lib are using that is the only way to identify some errors
//
// TODO(r0mant): See if we can use net.ErrClosed and errors.Is() instead.
UseOfClosedNetworkConnection = "use of closed network connection"
awly marked this conversation as resolved.
Show resolved Hide resolved
)

// SecondFactorType is the type of 2FA authentication.
Expand Down
593 changes: 322 additions & 271 deletions api/types/events/events.pb.go

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions api/types/events/events.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,9 @@ message DatabaseSessionQuery {
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// DatabaseQuery is the executed query string.
string DatabaseQuery = 5 [ (gogoproto.jsontag) = "db_query" ];
// DatabaseQueryParameters are the query parameters for prepared statements.
repeated string DatabaseQueryParameters = 6
[ (gogoproto.jsontag) = "db_query_parameters,omitempty" ];
}

// DatabaseSessionEnd is emitted when a user ends the database session.
Expand Down
13 changes: 12 additions & 1 deletion api/utils/sshutils/chconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"sync"
"time"

"github.com/gravitational/teleport/api/constants"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -111,7 +113,16 @@ func (c *ChConn) RemoteAddr() net.Addr {

// Read reads from the channel.
func (c *ChConn) Read(data []byte) (int, error) {
return c.reader.Read(data)
n, err := c.reader.Read(data)
// A lot of code relies on "use of closed network connection" error to
// gracefully handle terminated connections so convert the closed pipe
// error to it.
if err != nil && err == io.ErrClosedPipe {
return n, trace.ConnectionProblem(err, constants.UseOfClosedNetworkConnection)
}
// Do not wrap the error to avoid masking the underlying error such as
// timeout error which is returned when read deadline is exceeded.
return n, err
}

// SetDeadline sets a connection deadline.
Expand Down
43 changes: 20 additions & 23 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,26 @@ import (
// The following constants have been moved to /api/constants/constants.go, and are now
// imported here for backwards compatibility. DELETE IN 7.0.0
const (
Local = constants.Local
OIDC = constants.OIDC
SAML = constants.SAML
Github = constants.Github
HumanDateFormatSeconds = constants.HumanDateFormatSeconds
DefaultImplicitRole = constants.DefaultImplicitRole
APIDomain = constants.APIDomain
CertificateFormatStandard = constants.CertificateFormatStandard
DurationNever = constants.DurationNever
EnhancedRecordingMinKernel = constants.EnhancedRecordingMinKernel
EnhancedRecordingCommand = constants.EnhancedRecordingCommand
EnhancedRecordingDisk = constants.EnhancedRecordingDisk
EnhancedRecordingNetwork = constants.EnhancedRecordingNetwork
KeepAliveNode = constants.KeepAliveNode
KeepAliveApp = constants.KeepAliveApp
KeepAliveDatabase = constants.KeepAliveDatabase
WindowsOS = constants.WindowsOS
LinuxOS = constants.LinuxOS
DarwinOS = constants.DarwinOS
Local = constants.Local
OIDC = constants.OIDC
SAML = constants.SAML
Github = constants.Github
HumanDateFormatSeconds = constants.HumanDateFormatSeconds
DefaultImplicitRole = constants.DefaultImplicitRole
APIDomain = constants.APIDomain
CertificateFormatStandard = constants.CertificateFormatStandard
DurationNever = constants.DurationNever
EnhancedRecordingMinKernel = constants.EnhancedRecordingMinKernel
EnhancedRecordingCommand = constants.EnhancedRecordingCommand
EnhancedRecordingDisk = constants.EnhancedRecordingDisk
EnhancedRecordingNetwork = constants.EnhancedRecordingNetwork
KeepAliveNode = constants.KeepAliveNode
KeepAliveApp = constants.KeepAliveApp
KeepAliveDatabase = constants.KeepAliveDatabase
WindowsOS = constants.WindowsOS
LinuxOS = constants.LinuxOS
DarwinOS = constants.DarwinOS
UseOfClosedNetworkConnection = constants.UseOfClosedNetworkConnection
)

// WebAPIVersion is a current webapi version
Expand Down Expand Up @@ -647,10 +648,6 @@ const (
)

const (
// UseOfClosedNetworkConnection is a special string some parts of
// go standard lib are using that is the only way to identify some errors
UseOfClosedNetworkConnection = "use of closed network connection"

// NodeIsAmbiguous serves as an identifying error string indicating that
// the proxy subsystem found multiple nodes matching the specified hostname.
NodeIsAmbiguous = "err-node-is-ambiguous"
Expand Down
9 changes: 6 additions & 3 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,14 @@ const (
// AppSessionRequestEvent is an HTTP request and response.
AppSessionRequestEvent = "app.session.request"

// DatabaseSessionStartEvent indicates the start of a database session.
// DatabaseSessionStartEvent is emitted when a database client attempts
// to connect to a database.
DatabaseSessionStartEvent = "db.session.start"
// DatabaseSessionEndEvent indicates the end of a database session.
// DatabaseSessionEndEvent is emitted when a database client disconnects
// from a database.
DatabaseSessionEndEvent = "db.session.end"
// DatabaseSessionQueryEvent indicates a database query execution.
// DatabaseSessionQueryEvent is emitted when a database client executes
// a query.
DatabaseSessionQueryEvent = "db.session.query"

// SessionRejectedReasonMaxConnections indicates that a session.rejected event
Expand Down
6 changes: 3 additions & 3 deletions lib/events/filesessions/filestream.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
return trace.Wrap(err)
}
defer func() {
if err := f.Close(); err != nil {
h.WithError(err).Errorf("Failed to close file %q.", uploadPath)
}
if err := utils.FSUnlock(f); err != nil {
h.WithError(err).Errorf("Failed to unlock filesystem lock.")
}
if err := f.Close(); err != nil {
h.WithError(err).Errorf("Failed to close file %q.", uploadPath)
}
awly marked this conversation as resolved.
Show resolved Hide resolved
}()

files := make([]*os.File, 0, len(parts))
Expand Down
155 changes: 79 additions & 76 deletions lib/srv/db/db_test.go → lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2020 Gravitational, Inc.
Copyright 2020-2021 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,6 +39,7 @@ import (
"github.com/jackc/pgconn"
"github.com/jonboulle/clockwork"
"github.com/pborman/uuid"
"github.com/siddontang/go-mysql/client"
"github.com/stretchr/testify/require"
)

Expand All @@ -53,19 +54,7 @@ func TestPostgresAccess(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t)
t.Cleanup(func() { testCtx.Close() })

// Start multiplexer.
go testCtx.mux.Serve()
// Start fake Postgres server.
go testCtx.postgresServer.Serve()
// Start database proxy server.
go testCtx.proxyServer.Serve(testCtx.mux.DB())
// Start database service server.
go func() {
for conn := range testCtx.proxyConn {
testCtx.server.HandleConnection(conn)
}
}()
go testCtx.startHandlingPostgresConnections()

tests := []struct {
desc string
Expand Down Expand Up @@ -149,19 +138,7 @@ func TestPostgresAccess(t *testing.T) {
require.NoError(t, err)

// Try to connect to the database as this user.
pgConn, err := postgres.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: testCtx.authClient,
AuthServer: testCtx.authServer,
Address: testCtx.mux.DB().Addr().String(),
Cluster: testCtx.clusterName,
Username: test.user,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: "postgres-test",
Protocol: defaults.ProtocolPostgres,
Username: test.dbUser,
Database: test.dbName,
},
})
pgConn, err := testCtx.postgresClient(ctx, test.user, test.dbUser, test.dbName)
if test.err != "" {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
Expand All @@ -188,17 +165,7 @@ func TestMySQLAccess(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t)
t.Cleanup(func() { testCtx.Close() })

// Start test MySQL server.
go testCtx.mysqlServer.Serve()
// Start MySQL proxy server.
go testCtx.proxyServer.ServeMySQL(testCtx.mysqlListener)
// Start database service server.
go func() {
for conn := range testCtx.proxyConn {
testCtx.server.HandleConnection(conn)
}
}()
go testCtx.startHandlingMySQLConnections()

tests := []struct {
// desc is the test case description.
Expand Down Expand Up @@ -257,18 +224,7 @@ func TestMySQLAccess(t *testing.T) {
require.NoError(t, err)

// Try to connect to the database as this user.
mysqlConn, err := mysql.MakeTestClient(common.TestClientConfig{
AuthClient: testCtx.authClient,
AuthServer: testCtx.authServer,
Address: testCtx.mysqlListener.Addr().String(),
Cluster: testCtx.clusterName,
Username: test.user,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: "mysql-test",
Protocol: defaults.ProtocolPostgres,
Username: test.dbUser,
},
})
mysqlConn, err := testCtx.mysqlClient(test.user, test.dbUser)
if test.err != "" {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
Expand Down Expand Up @@ -309,19 +265,7 @@ func TestDatabaseAccessDisabled(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t)
t.Cleanup(func() { testCtx.Close() })

// Start multiplexer.
go testCtx.mux.Serve()
// Start fake Postgres server.
go testCtx.postgresServer.Serve()
// Start database proxy server.
go testCtx.proxyServer.Serve(testCtx.mux.DB())
// Start database service server.
go func() {
for conn := range testCtx.proxyConn {
testCtx.server.HandleConnection(conn)
}
}()
go testCtx.startHandlingPostgresConnections()

userName := "alice"
roleName := "admin"
Expand All @@ -338,19 +282,7 @@ func TestDatabaseAccessDisabled(t *testing.T) {
require.NoError(t, err)

// Try to connect to the database as this user.
_, err = postgres.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: testCtx.authClient,
AuthServer: testCtx.authServer,
Address: testCtx.mux.DB().Addr().String(),
Cluster: testCtx.clusterName,
Username: userName,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: "postgres-test",
Protocol: defaults.ProtocolPostgres,
Username: dbUser,
Database: dbName,
},
})
_, err = testCtx.postgresClient(ctx, userName, dbUser, dbName)
require.Error(t, err)
require.Contains(t, err.Error(), "this Teleport cluster doesn't support database access")
}
Expand All @@ -369,6 +301,66 @@ type testContext struct {
server *Server
postgresDBServer types.DatabaseServer
mysqlDBServer types.DatabaseServer
emitter *testEmitter
}

func (c *testContext) startHandlingPostgresConnections() {
// Start multiplexer.
go c.mux.Serve()
// Start fake Postgres server.
go c.postgresServer.Serve()
// Start database proxy server.
go c.proxyServer.Serve(c.mux.DB())
// Start handling Postgres connection on the database server.
for conn := range c.proxyConn {
c.server.HandleConnection(conn)
}
}

// postgresClient connects to test Postgres through database access as a
// specified Teleport user and database account.
func (c *testContext) postgresClient(ctx context.Context, teleportUser, dbUser, dbName string) (*pgconn.PgConn, error) {
return postgres.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: c.authClient,
AuthServer: c.authServer,
Address: c.mux.DB().Addr().String(),
Cluster: c.clusterName,
Username: teleportUser,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: c.postgresDBServer.GetName(),
Protocol: defaults.ProtocolPostgres,
Username: dbUser,
Database: dbName,
},
})
}

func (c *testContext) startHandlingMySQLConnections() {
// Start test MySQL server.
go c.mysqlServer.Serve()
// Start MySQL proxy server.
go c.proxyServer.ServeMySQL(c.mysqlListener)
// Start handling MySQL connections on the database server.
for conn := range c.proxyConn {
c.server.HandleConnection(conn)
}
}

// mysqlClient connects to test MySQL through database access as a specified
// Teleport user and database account.
func (c *testContext) mysqlClient(teleportUser, dbUser string) (*client.Conn, error) {
return mysql.MakeTestClient(common.TestClientConfig{
AuthClient: c.authClient,
AuthServer: c.authServer,
Address: c.mysqlListener.Addr().String(),
Cluster: c.clusterName,
Username: teleportUser,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: c.mysqlDBServer.GetName(),
Protocol: defaults.ProtocolMySQL,
Username: dbUser,
},
})
}

// Close closes all resources associated with the test context.
Expand Down Expand Up @@ -487,6 +479,9 @@ func setupTestContext(ctx context.Context, t *testing.T) *testContext {
})
require.NoError(t, err)

// Create test audit events emitter.
emitter := newTestEmitter()

// Create database service server.
server, err := New(ctx, Config{
Clock: clockwork.NewFakeClockAt(time.Now()),
Expand All @@ -498,6 +493,13 @@ func setupTestContext(ctx context.Context, t *testing.T) *testContext {
Servers: []types.DatabaseServer{postgresDBServer, mysqlDBServer},
TLSConfig: tlsConfig,
GetRotation: func(teleport.Role) (*types.Rotation, error) { return &types.Rotation{}, nil },
NewAudit: func(common.AuditConfig) (common.Audit, error) {
// Use the same audit logger implementation but substitute the
// underlying emitter so events can be tracked in tests.
return common.NewAudit(common.AuditConfig{
Emitter: emitter,
})
},
})
require.NoError(t, err)

Expand All @@ -515,6 +517,7 @@ func setupTestContext(ctx context.Context, t *testing.T) *testContext {
tlsServer: tlsServer,
authServer: tlsServer.Auth(),
authClient: dbAuthClient,
emitter: emitter,
}
}

Expand Down
Loading