Skip to content

Commit

Permalink
Capture postgres extended protocol messages in audit log (#6303)
Browse files Browse the repository at this point in the history
r0mant authored Apr 14, 2021
1 parent 1880147 commit 8230d6e
Showing 18 changed files with 1,062 additions and 450 deletions.
6 changes: 6 additions & 0 deletions api/constants/constants.go
Original file line number Diff line number Diff line change
@@ -91,6 +91,12 @@ const (

// DarwinOS is the GOOS constant for Apple macOS/darwin.
DarwinOS = "darwin"

// UseOfClosedNetworkConnection is a special string some parts of
// go standard lib are using that is the only way to identify some errors
//
// TODO(r0mant): See if we can use net.ErrClosed and errors.Is() instead.
UseOfClosedNetworkConnection = "use of closed network connection"
)

// SecondFactorType is the type of 2FA authentication.
593 changes: 322 additions & 271 deletions api/types/events/events.pb.go

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions api/types/events/events.proto
Original file line number Diff line number Diff line change
@@ -1171,6 +1171,9 @@ message DatabaseSessionQuery {
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// DatabaseQuery is the executed query string.
string DatabaseQuery = 5 [ (gogoproto.jsontag) = "db_query" ];
// DatabaseQueryParameters are the query parameters for prepared statements.
repeated string DatabaseQueryParameters = 6
[ (gogoproto.jsontag) = "db_query_parameters,omitempty" ];
}

// DatabaseSessionEnd is emitted when a user ends the database session.
13 changes: 12 additions & 1 deletion api/utils/sshutils/chconn.go
Original file line number Diff line number Diff line change
@@ -22,6 +22,8 @@ import (
"sync"
"time"

"github.com/gravitational/teleport/api/constants"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
)
@@ -111,7 +113,16 @@ func (c *ChConn) RemoteAddr() net.Addr {

// Read reads from the channel.
func (c *ChConn) Read(data []byte) (int, error) {
return c.reader.Read(data)
n, err := c.reader.Read(data)
// A lot of code relies on "use of closed network connection" error to
// gracefully handle terminated connections so convert the closed pipe
// error to it.
if err != nil && err == io.ErrClosedPipe {
return n, trace.ConnectionProblem(err, constants.UseOfClosedNetworkConnection)
}
// Do not wrap the error to avoid masking the underlying error such as
// timeout error which is returned when read deadline is exceeded.
return n, err
}

// SetDeadline sets a connection deadline.
43 changes: 20 additions & 23 deletions constants.go
Original file line number Diff line number Diff line change
@@ -26,25 +26,26 @@ import (
// The following constants have been moved to /api/constants/constants.go, and are now
// imported here for backwards compatibility. DELETE IN 7.0.0
const (
Local = constants.Local
OIDC = constants.OIDC
SAML = constants.SAML
Github = constants.Github
HumanDateFormatSeconds = constants.HumanDateFormatSeconds
DefaultImplicitRole = constants.DefaultImplicitRole
APIDomain = constants.APIDomain
CertificateFormatStandard = constants.CertificateFormatStandard
DurationNever = constants.DurationNever
EnhancedRecordingMinKernel = constants.EnhancedRecordingMinKernel
EnhancedRecordingCommand = constants.EnhancedRecordingCommand
EnhancedRecordingDisk = constants.EnhancedRecordingDisk
EnhancedRecordingNetwork = constants.EnhancedRecordingNetwork
KeepAliveNode = constants.KeepAliveNode
KeepAliveApp = constants.KeepAliveApp
KeepAliveDatabase = constants.KeepAliveDatabase
WindowsOS = constants.WindowsOS
LinuxOS = constants.LinuxOS
DarwinOS = constants.DarwinOS
Local = constants.Local
OIDC = constants.OIDC
SAML = constants.SAML
Github = constants.Github
HumanDateFormatSeconds = constants.HumanDateFormatSeconds
DefaultImplicitRole = constants.DefaultImplicitRole
APIDomain = constants.APIDomain
CertificateFormatStandard = constants.CertificateFormatStandard
DurationNever = constants.DurationNever
EnhancedRecordingMinKernel = constants.EnhancedRecordingMinKernel
EnhancedRecordingCommand = constants.EnhancedRecordingCommand
EnhancedRecordingDisk = constants.EnhancedRecordingDisk
EnhancedRecordingNetwork = constants.EnhancedRecordingNetwork
KeepAliveNode = constants.KeepAliveNode
KeepAliveApp = constants.KeepAliveApp
KeepAliveDatabase = constants.KeepAliveDatabase
WindowsOS = constants.WindowsOS
LinuxOS = constants.LinuxOS
DarwinOS = constants.DarwinOS
UseOfClosedNetworkConnection = constants.UseOfClosedNetworkConnection
)

// WebAPIVersion is a current webapi version
@@ -647,10 +648,6 @@ const (
)

const (
// UseOfClosedNetworkConnection is a special string some parts of
// go standard lib are using that is the only way to identify some errors
UseOfClosedNetworkConnection = "use of closed network connection"

// NodeIsAmbiguous serves as an identifying error string indicating that
// the proxy subsystem found multiple nodes matching the specified hostname.
NodeIsAmbiguous = "err-node-is-ambiguous"
9 changes: 6 additions & 3 deletions lib/events/api.go
Original file line number Diff line number Diff line change
@@ -355,11 +355,14 @@ const (
// AppSessionRequestEvent is an HTTP request and response.
AppSessionRequestEvent = "app.session.request"

// DatabaseSessionStartEvent indicates the start of a database session.
// DatabaseSessionStartEvent is emitted when a database client attempts
// to connect to a database.
DatabaseSessionStartEvent = "db.session.start"
// DatabaseSessionEndEvent indicates the end of a database session.
// DatabaseSessionEndEvent is emitted when a database client disconnects
// from a database.
DatabaseSessionEndEvent = "db.session.end"
// DatabaseSessionQueryEvent indicates a database query execution.
// DatabaseSessionQueryEvent is emitted when a database client executes
// a query.
DatabaseSessionQueryEvent = "db.session.query"

// SessionRejectedReasonMaxConnections indicates that a session.rejected event
6 changes: 3 additions & 3 deletions lib/events/filesessions/filestream.go
Original file line number Diff line number Diff line change
@@ -119,12 +119,12 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
return trace.Wrap(err)
}
defer func() {
if err := f.Close(); err != nil {
h.WithError(err).Errorf("Failed to close file %q.", uploadPath)
}
if err := utils.FSUnlock(f); err != nil {
h.WithError(err).Errorf("Failed to unlock filesystem lock.")
}
if err := f.Close(); err != nil {
h.WithError(err).Errorf("Failed to close file %q.", uploadPath)
}
}()

files := make([]*os.File, 0, len(parts))
155 changes: 79 additions & 76 deletions lib/srv/db/db_test.go → lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2020 Gravitational, Inc.
Copyright 2020-2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -39,6 +39,7 @@ import (
"github.com/jackc/pgconn"
"github.com/jonboulle/clockwork"
"github.com/pborman/uuid"
"github.com/siddontang/go-mysql/client"
"github.com/stretchr/testify/require"
)

@@ -53,19 +54,7 @@ func TestPostgresAccess(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t)
t.Cleanup(func() { testCtx.Close() })

// Start multiplexer.
go testCtx.mux.Serve()
// Start fake Postgres server.
go testCtx.postgresServer.Serve()
// Start database proxy server.
go testCtx.proxyServer.Serve(testCtx.mux.DB())
// Start database service server.
go func() {
for conn := range testCtx.proxyConn {
testCtx.server.HandleConnection(conn)
}
}()
go testCtx.startHandlingPostgresConnections()

tests := []struct {
desc string
@@ -149,19 +138,7 @@ func TestPostgresAccess(t *testing.T) {
require.NoError(t, err)

// Try to connect to the database as this user.
pgConn, err := postgres.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: testCtx.authClient,
AuthServer: testCtx.authServer,
Address: testCtx.mux.DB().Addr().String(),
Cluster: testCtx.clusterName,
Username: test.user,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: "postgres-test",
Protocol: defaults.ProtocolPostgres,
Username: test.dbUser,
Database: test.dbName,
},
})
pgConn, err := testCtx.postgresClient(ctx, test.user, test.dbUser, test.dbName)
if test.err != "" {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
@@ -188,17 +165,7 @@ func TestMySQLAccess(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t)
t.Cleanup(func() { testCtx.Close() })

// Start test MySQL server.
go testCtx.mysqlServer.Serve()
// Start MySQL proxy server.
go testCtx.proxyServer.ServeMySQL(testCtx.mysqlListener)
// Start database service server.
go func() {
for conn := range testCtx.proxyConn {
testCtx.server.HandleConnection(conn)
}
}()
go testCtx.startHandlingMySQLConnections()

tests := []struct {
// desc is the test case description.
@@ -257,18 +224,7 @@ func TestMySQLAccess(t *testing.T) {
require.NoError(t, err)

// Try to connect to the database as this user.
mysqlConn, err := mysql.MakeTestClient(common.TestClientConfig{
AuthClient: testCtx.authClient,
AuthServer: testCtx.authServer,
Address: testCtx.mysqlListener.Addr().String(),
Cluster: testCtx.clusterName,
Username: test.user,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: "mysql-test",
Protocol: defaults.ProtocolPostgres,
Username: test.dbUser,
},
})
mysqlConn, err := testCtx.mysqlClient(test.user, test.dbUser)
if test.err != "" {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
@@ -309,19 +265,7 @@ func TestDatabaseAccessDisabled(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t)
t.Cleanup(func() { testCtx.Close() })

// Start multiplexer.
go testCtx.mux.Serve()
// Start fake Postgres server.
go testCtx.postgresServer.Serve()
// Start database proxy server.
go testCtx.proxyServer.Serve(testCtx.mux.DB())
// Start database service server.
go func() {
for conn := range testCtx.proxyConn {
testCtx.server.HandleConnection(conn)
}
}()
go testCtx.startHandlingPostgresConnections()

userName := "alice"
roleName := "admin"
@@ -338,19 +282,7 @@ func TestDatabaseAccessDisabled(t *testing.T) {
require.NoError(t, err)

// Try to connect to the database as this user.
_, err = postgres.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: testCtx.authClient,
AuthServer: testCtx.authServer,
Address: testCtx.mux.DB().Addr().String(),
Cluster: testCtx.clusterName,
Username: userName,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: "postgres-test",
Protocol: defaults.ProtocolPostgres,
Username: dbUser,
Database: dbName,
},
})
_, err = testCtx.postgresClient(ctx, userName, dbUser, dbName)
require.Error(t, err)
require.Contains(t, err.Error(), "this Teleport cluster doesn't support database access")
}
@@ -369,6 +301,66 @@ type testContext struct {
server *Server
postgresDBServer types.DatabaseServer
mysqlDBServer types.DatabaseServer
emitter *testEmitter
}

func (c *testContext) startHandlingPostgresConnections() {
// Start multiplexer.
go c.mux.Serve()
// Start fake Postgres server.
go c.postgresServer.Serve()
// Start database proxy server.
go c.proxyServer.Serve(c.mux.DB())
// Start handling Postgres connection on the database server.
for conn := range c.proxyConn {
c.server.HandleConnection(conn)
}
}

// postgresClient connects to test Postgres through database access as a
// specified Teleport user and database account.
func (c *testContext) postgresClient(ctx context.Context, teleportUser, dbUser, dbName string) (*pgconn.PgConn, error) {
return postgres.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: c.authClient,
AuthServer: c.authServer,
Address: c.mux.DB().Addr().String(),
Cluster: c.clusterName,
Username: teleportUser,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: c.postgresDBServer.GetName(),
Protocol: defaults.ProtocolPostgres,
Username: dbUser,
Database: dbName,
},
})
}

func (c *testContext) startHandlingMySQLConnections() {
// Start test MySQL server.
go c.mysqlServer.Serve()
// Start MySQL proxy server.
go c.proxyServer.ServeMySQL(c.mysqlListener)
// Start handling MySQL connections on the database server.
for conn := range c.proxyConn {
c.server.HandleConnection(conn)
}
}

// mysqlClient connects to test MySQL through database access as a specified
// Teleport user and database account.
func (c *testContext) mysqlClient(teleportUser, dbUser string) (*client.Conn, error) {
return mysql.MakeTestClient(common.TestClientConfig{
AuthClient: c.authClient,
AuthServer: c.authServer,
Address: c.mysqlListener.Addr().String(),
Cluster: c.clusterName,
Username: teleportUser,
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: c.mysqlDBServer.GetName(),
Protocol: defaults.ProtocolMySQL,
Username: dbUser,
},
})
}

// Close closes all resources associated with the test context.
@@ -487,6 +479,9 @@ func setupTestContext(ctx context.Context, t *testing.T) *testContext {
})
require.NoError(t, err)

// Create test audit events emitter.
emitter := newTestEmitter()

// Create database service server.
server, err := New(ctx, Config{
Clock: clockwork.NewFakeClockAt(time.Now()),
@@ -498,6 +493,13 @@ func setupTestContext(ctx context.Context, t *testing.T) *testContext {
Servers: []types.DatabaseServer{postgresDBServer, mysqlDBServer},
TLSConfig: tlsConfig,
GetRotation: func(teleport.Role) (*types.Rotation, error) { return &types.Rotation{}, nil },
NewAudit: func(common.AuditConfig) (common.Audit, error) {
// Use the same audit logger implementation but substitute the
// underlying emitter so events can be tracked in tests.
return common.NewAudit(common.AuditConfig{
Emitter: emitter,
})
},
})
require.NoError(t, err)

@@ -515,6 +517,7 @@ func setupTestContext(ctx context.Context, t *testing.T) *testContext {
tlsServer: tlsServer,
authServer: tlsServer.Auth(),
authClient: dbAuthClient,
emitter: emitter,
}
}

152 changes: 152 additions & 0 deletions lib/srv/db/audit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package db

import (
"context"
"testing"
"time"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/auth"
libevents "github.com/gravitational/teleport/lib/events"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)

// TestAuditPostgres verifies proper audit events are emitted for Postgres
// connections.
func TestAuditPostgres(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t)
t.Cleanup(func() { testCtx.Close() })
go testCtx.startHandlingPostgresConnections()

_, role, err := auth.CreateUserAndRole(testCtx.tlsServer.Auth(), "alice", []string{"admin"})
require.NoError(t, err)

role.SetDatabaseNames(types.Allow, []string{"postgres"})
role.SetDatabaseUsers(types.Allow, []string{"postgres"})
err = testCtx.tlsServer.Auth().UpsertRole(ctx, role)
require.NoError(t, err)

// Access denied should trigger an unsuccessful session start event.
_, err = testCtx.postgresClient(ctx, "alice", "notpostgres", "notpostgres")
require.Error(t, err)
requireEvent(t, testCtx, libevents.DatabaseSessionStartFailureCode)

// Connect should trigger successful session start event.
psql, err := testCtx.postgresClient(ctx, "alice", "postgres", "postgres")
require.NoError(t, err)
requireEvent(t, testCtx, libevents.DatabaseSessionStartCode)

// Simple query should trigger the query event.
_, err = psql.Exec(ctx, "select 1").ReadAll()
require.NoError(t, err)
requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select 1")

// Prepared statement execution should also trigger a query event.
result := psql.ExecParams(ctx, "select now()", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select now()")

// Closing connection should trigger session end event.
err = psql.Close(ctx)
require.NoError(t, err)
requireEvent(t, testCtx, libevents.DatabaseSessionEndCode)
}

// TestAuditMySQL verifies proper audit events are emitted for MySQL
// connections.
func TestAuditMySQL(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t)
t.Cleanup(func() { testCtx.Close() })
go testCtx.startHandlingMySQLConnections()

_, role, err := auth.CreateUserAndRole(testCtx.tlsServer.Auth(), "alice", []string{"admin"})
require.NoError(t, err)

role.SetDatabaseUsers(types.Allow, []string{"root"})
err = testCtx.tlsServer.Auth().UpsertRole(ctx, role)
require.NoError(t, err)

// Access denied should trigger an unsuccessful session start event.
_, err = testCtx.mysqlClient("alice", "notroot")
require.Error(t, err)
requireEvent(t, testCtx, libevents.DatabaseSessionStartFailureCode)

// Connect should trigger successful session start event.
mysql, err := testCtx.mysqlClient("alice", "root")
require.NoError(t, err)
requireEvent(t, testCtx, libevents.DatabaseSessionStartCode)

// Simple query should trigger the query event.
_, err = mysql.Execute("select 1")
require.NoError(t, err)
requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select 1")

// Closing connection should trigger session end event.
err = mysql.Close()
require.NoError(t, err)
requireEvent(t, testCtx, libevents.DatabaseSessionEndCode)
}

func requireEvent(t *testing.T, testCtx *testContext, code string) {
event := waitForEvent(t, testCtx, code)
require.Equal(t, code, event.GetCode())
}

func requireQueryEvent(t *testing.T, testCtx *testContext, code, query string) {
event := waitForEvent(t, testCtx, code)
require.Equal(t, code, event.GetCode())
require.Equal(t, query, event.(*events.DatabaseSessionQuery).DatabaseQuery)
}

func waitForEvent(t *testing.T, testCtx *testContext, code string) events.AuditEvent {
select {
case event := <-testCtx.emitter.eventsCh:
return event
case <-time.After(time.Second):
t.Fatalf("didn't receive %v event after 1 second", code)
}
return nil
}

// testEmitter pushes all received audit events into a channel.
type testEmitter struct {
eventsCh chan events.AuditEvent
log logrus.FieldLogger
}

// newTestEmitter returns a new instance of test emitter.
func newTestEmitter() *testEmitter {
return &testEmitter{
eventsCh: make(chan events.AuditEvent, 100),
log: logrus.WithField(trace.Component, "emitter"),
}
}

// EmitAuditEvent records the provided event in the test emitter.
func (e *testEmitter) EmitAuditEvent(ctx context.Context, event events.AuditEvent) error {
e.log.Infof("EmitAuditEvent(%v)", event)
e.eventsCh <- event
return nil
}
67 changes: 38 additions & 29 deletions lib/srv/db/common/audit.go
Original file line number Diff line number Diff line change
@@ -24,40 +24,54 @@ import (
libevents "github.com/gravitational/teleport/lib/events"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
)

// Audit defines an interface for database access audit events logger.
type Audit interface {
// OnSessionStart is called on successful/unsuccessful database session start.
OnSessionStart(ctx context.Context, session *Session, sessionErr error)
// OnSessionEnd is called when database session terminates.
OnSessionEnd(ctx context.Context, session *Session)
// OnQuery is called when a SQL statement is executed.
OnQuery(ctx context.Context, session *Session, query string, parameters ...string)
}

// AuditConfig is the audit events emitter configuration.
type AuditConfig struct {
// StreamWriter is used to emit audit events.
StreamWriter libevents.StreamWriter
// Emitter is used to emit audit events.
Emitter events.Emitter
}

// Check validates the config.
func (c *AuditConfig) Check() error {
if c.StreamWriter == nil {
return trace.BadParameter("missing StreamWriter")
if c.Emitter == nil {
return trace.BadParameter("missing Emitter")
}
return nil
}

// Audit provides methods for emitting database access audit events.
type Audit struct {
// audit provides methods for emitting database access audit events.
type audit struct {
// cfg is the audit events emitter configuration.
cfg AuditConfig
// log is used for logging
log logrus.FieldLogger
}

// NewAudit returns a new instance of the audit events emitter.
func NewAudit(config AuditConfig) (*Audit, error) {
func NewAudit(config AuditConfig) (Audit, error) {
if err := config.Check(); err != nil {
return nil, trace.Wrap(err)
}
return &Audit{
return &audit{
cfg: config,
log: logrus.WithField(trace.Component, "db:audit"),
}, nil
}

// OnSessionStart emits an audit event when database session starts.
func (a *Audit) OnSessionStart(ctx context.Context, session Session, sessionErr error) error {
func (a *audit) OnSessionStart(ctx context.Context, session *Session, sessionErr error) {
event := &events.DatabaseSessionStart{
Metadata: events.Metadata{
Type: libevents.DatabaseSessionStartEvent,
@@ -76,16 +90,16 @@ func (a *Audit) OnSessionStart(ctx context.Context, session Session, sessionErr
SessionID: session.ID,
WithMFA: session.Identity.MFAVerified,
},
Status: events.Status{
Success: true,
},
DatabaseMetadata: events.DatabaseMetadata{
DatabaseService: session.Server.GetName(),
DatabaseProtocol: session.Server.GetProtocol(),
DatabaseURI: session.Server.GetURI(),
DatabaseName: session.DatabaseName,
DatabaseUser: session.DatabaseUser,
},
Status: events.Status{
Success: true,
},
}
// If the database session wasn't started successfully, emit
// a failure event with error details.
@@ -97,16 +111,12 @@ func (a *Audit) OnSessionStart(ctx context.Context, session Session, sessionErr
UserMessage: sessionErr.Error(),
}
}
err := a.cfg.StreamWriter.EmitAuditEvent(ctx, event)
if err != nil {
return trace.Wrap(err)
}
return nil
a.emitAuditEvent(ctx, event)
}

// OnSessionEnd emits an audit event when database session ends.
func (a *Audit) OnSessionEnd(ctx context.Context, session Session) error {
err := a.cfg.StreamWriter.EmitAuditEvent(ctx, &events.DatabaseSessionEnd{
func (a *audit) OnSessionEnd(ctx context.Context, session *Session) {
a.emitAuditEvent(ctx, &events.DatabaseSessionEnd{
Metadata: events.Metadata{
Type: libevents.DatabaseSessionEndEvent,
Code: libevents.DatabaseSessionEndCode,
@@ -128,15 +138,11 @@ func (a *Audit) OnSessionEnd(ctx context.Context, session Session) error {
DatabaseUser: session.DatabaseUser,
},
})
if err != nil {
return trace.Wrap(err)
}
return nil
}

// OnQuery emits an audit event when a database query is executed.
func (a *Audit) OnQuery(ctx context.Context, session Session, query string) error {
err := a.cfg.StreamWriter.EmitAuditEvent(ctx, &events.DatabaseSessionQuery{
func (a *audit) OnQuery(ctx context.Context, session *Session, query string, parameters ...string) {
a.emitAuditEvent(ctx, &events.DatabaseSessionQuery{
Metadata: events.Metadata{
Type: libevents.DatabaseSessionQueryEvent,
Code: libevents.DatabaseSessionQueryCode,
@@ -157,10 +163,13 @@ func (a *Audit) OnQuery(ctx context.Context, session Session, query string) erro
DatabaseName: session.DatabaseName,
DatabaseUser: session.DatabaseUser,
},
DatabaseQuery: query,
DatabaseQuery: query,
DatabaseQueryParameters: parameters,
})
if err != nil {
return trace.Wrap(err)
}

func (a *audit) emitAuditEvent(ctx context.Context, event events.AuditEvent) {
if err := a.cfg.Emitter.EmitAuditEvent(ctx, event); err != nil {
a.log.WithError(err).Errorf("Failed to emit audit event: %v.", event)
}
return nil
}
2 changes: 2 additions & 0 deletions lib/srv/db/common/session.go
Original file line number Diff line number Diff line change
@@ -46,6 +46,8 @@ type Session struct {
StartupParameters map[string]string
// Log is the logger with session specific fields.
Log logrus.FieldLogger
// Statements is the session's prepared statements cache.
Statements *StatementsCache
}

// String returns string representation of the session parameters.
141 changes: 141 additions & 0 deletions lib/srv/db/common/statements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package common

import (
"sync"

"github.com/gravitational/trace"
)

// StatementsCache contains prepared statements client executes during a
// database session.
//
// Currently only used to support Postgres extended protocol message flow
// to capture prepared statement queries in the audit log:
//
// https://www.postgresql.org/docs/10/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
type StatementsCache struct {
// cache maps prepared statement name to the statement itself.
cache map[string]*Statement
// mu is used to synchronize cache access.
mu sync.RWMutex
}

// Statement represents a prepared statement.
type Statement struct {
// Name is the prepared statement name.
//
// Can be empty (Postgres "unnamed statement").
Name string
// Query is the statement query string.
Query string
// Portals contains "destination portals" that bind prepared statement to
// parameters.
//
// In Postgres extended query protocol, clients execute these "portals"
// and not prepared statements directly.
Portals map[string]*Portal
}

// Portal represents a destination portal that binds a prepared statement
// to parameters.
type Portal struct {
// Name is the portal name.
//
// Can be empty (Postgres "unnamed portal").
Name string
// Query is the prepared statement query string.
Query string
// Parameters are the query parameters.
Parameters []string
}

// NewStatementsCache returns a new instance of prepared statements cache.
func NewStatementsCache() *StatementsCache {
return &StatementsCache{cache: make(map[string]*Statement)}
}

// Save adds the provided prepared statement information to the cache.
func (s *StatementsCache) Save(statementName, query string) {
s.mu.Lock()
defer s.mu.Unlock()
s.cache[statementName] = &Statement{
Name: statementName,
Query: query,
Portals: make(map[string]*Portal),
}
}

// Get returns the specified prepared statement from the cache.
func (s *StatementsCache) Get(statementName string) (*Statement, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if statement, ok := s.cache[statementName]; ok {
return statement, nil
}
return nil, trace.NotFound("prepared statement %q is not in cache", statementName)
}

// GetPortal returns the specified destination portal from the cache.
func (s *StatementsCache) GetPortal(portalName string) (*Portal, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, statement := range s.cache {
if portal, ok := statement.Portals[portalName]; ok {
return portal, nil
}
}
return nil, trace.NotFound("destination portal %q is not in cache", portalName)
}

// Bind adds the provided destination portal to the cache.
func (s *StatementsCache) Bind(statementName, portalName string, parameters ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
for name, statement := range s.cache {
if name == statementName {
s.cache[name].Portals[portalName] = &Portal{
Name: portalName,
Query: statement.Query,
Parameters: parameters,
}
return nil
}
}
return trace.NotFound("prepared statement %q is not in cache", statementName)
}

// Remove removes the specified prepared statement from the cache, along with
// all its destination portals.
func (s *StatementsCache) Remove(statementName string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.cache, statementName)
}

// RemovePortal removes the specified destination portal from the cache.
func (s *StatementsCache) RemovePortal(portalName string) {
s.mu.Lock()
defer s.mu.Unlock()
for statementName, statement := range s.cache {
if _, ok := statement.Portals[portalName]; ok {
delete(s.cache[statementName].Portals, portalName)
return
}
}
}
148 changes: 148 additions & 0 deletions lib/srv/db/common/statements_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package common

import (
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

// TestStatementsCache verifies functionality of the cache that holds per-session
// prepared statements and their parameters.
func TestStatementsCache(t *testing.T) {
cache := NewStatementsCache()

// Full parse/bind/execute flow with an unnamed statement/portal.
cache.Save(testUnnamedStatement1.Name, testUnnamedStatement1.Query)
cache.Bind(testUnnamedStatement1.Name, testUnnamedPortal1.Name, testUnnamedPortal1.Parameters...)

statement, err := cache.Get(unnamedStatement)
require.NoError(t, err)
require.Equal(t, testUnnamedStatement1, statement)

portal, err := cache.GetPortal(unnamedPortal)
require.NoError(t, err)
require.Equal(t, testUnnamedPortal1, portal)

// Make sure another unnamed statement replaces the previous one.
cache.Save(testUnnamedStatement2.Name, testUnnamedStatement2.Query)
cache.Bind(testUnnamedStatement2.Name, testUnnamedPortal2.Name, testUnnamedPortal2.Parameters...)

statement, err = cache.Get(unnamedStatement)
require.NoError(t, err)
require.Equal(t, testUnnamedStatement2, statement)

portal, err = cache.GetPortal(unnamedPortal)
require.NoError(t, err)
require.Equal(t, testUnnamedPortal2, portal)

// Create a named statement and a couple of destination portals.
cache.Save(testStatement.Name, testStatement.Query)
cache.Bind(testStatement.Name, testPortal1.Name, testPortal1.Parameters...)
cache.Bind(testStatement.Name, testPortal2.Name, testPortal2.Parameters...)

statement, err = cache.Get(testStatement.Name)
require.NoError(t, err)
require.Equal(t, testStatement, statement)

portal1, err := cache.GetPortal(testPortal1.Name)
require.NoError(t, err)
require.Equal(t, testPortal1, portal1)

portal2, err := cache.GetPortal(testPortal2.Name)
require.NoError(t, err)
require.Equal(t, testPortal2, portal2)

// Try to get a couple non-existent statements/portals.
_, err = cache.Get("unknown")
require.IsType(t, trace.NotFound(""), err)

_, err = cache.GetPortal("unknown")
require.IsType(t, trace.NotFound(""), err)

// Close a portal.
cache.RemovePortal(testPortal1.Name)

_, err = cache.GetPortal(testPortal1.Name)
require.IsType(t, trace.NotFound(""), err)

// Close a statement and make sure its portal is gone as well.
cache.Remove(testStatement.Name)

_, err = cache.Get(testStatement.Name)
require.IsType(t, trace.NotFound(""), err)

_, err = cache.GetPortal(testPortal2.Name)
require.IsType(t, trace.NotFound(""), err)
}

const (
unnamedStatement = ""
unnamedPortal = ""
)

var (
testQuery1 = "select * from test"
testUnnamedPortal1 = &Portal{
Name: unnamedPortal,
Query: testQuery1,
Parameters: []string{},
}
testUnnamedStatement1 = &Statement{
Name: unnamedStatement,
Query: testQuery1,
Portals: map[string]*Portal{
unnamedPortal: testUnnamedPortal1,
},
}

testQuery2 = "select * from test where id = $1"
testUnnamedPortal2 = &Portal{
Name: unnamedPortal,
Query: testQuery2,
Parameters: []string{"123"},
}
testUnnamedStatement2 = &Statement{
Name: unnamedStatement,
Query: testQuery2,
Portals: map[string]*Portal{
unnamedPortal: testUnnamedPortal2,
},
}

testQuery3 = "update test set value = $1 where id = $2"
testPortal1 = &Portal{
Name: "P_1",
Query: testQuery3,
Parameters: []string{"abc", "123"},
}
testPortal2 = &Portal{
Name: "P_2",
Query: testQuery3,
Parameters: []string{"def", "456"},
}
testStatement = &Statement{
Name: "S_1",
Query: testQuery3,
Portals: map[string]*Portal{
"P_1": testPortal1,
"P_2": testPortal2,
},
}
)
23 changes: 5 additions & 18 deletions lib/srv/db/mysql/engine.go
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ type Engine struct {
// Auth handles database access authentication.
Auth *common.Auth
// Audit emits database access audit events.
Audit *common.Audit
Audit common.Audit
// Context is the database server close context.
Context context.Context
// Clock is the clock interface.
@@ -91,16 +91,8 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
if err != nil {
return trace.Wrap(err)
}
err = e.Audit.OnSessionStart(e.Context, *sessionCtx, nil)
if err != nil {
return trace.Wrap(err)
}
defer func() {
err := e.Audit.OnSessionEnd(e.Context, *sessionCtx)
if err != nil {
e.Log.WithError(err).Error("Failed to emit audit event.")
}
}()
e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
defer e.Audit.OnSessionEnd(e.Context, sessionCtx)
// Copy between the connections.
clientErrCh := make(chan error, 1)
serverErrCh := make(chan error, 1)
@@ -141,9 +133,7 @@ func (e *Engine) checkAccess(sessionCtx *common.Session) error {
&services.DatabaseLabelsMatcher{Labels: sessionCtx.Server.GetAllLabels()},
&services.DatabaseUserMatcher{User: sessionCtx.DatabaseUser})
if err != nil {
if err := e.Audit.OnSessionStart(e.Context, *sessionCtx, err); err != nil {
e.Log.WithError(err).Error("Failed to emit audit event.")
}
e.Audit.OnSessionStart(e.Context, sessionCtx, err)
return trace.Wrap(err)
}
return nil
@@ -194,10 +184,7 @@ func (e *Engine) receiveFromClient(clientConn, serverConn net.Conn, clientErrCh
}
switch pkt := packet.(type) {
case *protocol.Query:
err := e.Audit.OnQuery(e.Context, *sessionCtx, pkt.Query())
if err != nil {
log.WithError(err).Error("Failed to emit audit event.")
}
e.Audit.OnQuery(e.Context, sessionCtx, pkt.Query())
case *protocol.Quit:
clientErrCh <- nil
return
111 changes: 95 additions & 16 deletions lib/srv/db/postgres/engine.go
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ type Engine struct {
// Auth handles database access authentication.
Auth *common.Auth
// Audit emits database access audit events.
Audit *common.Audit
Audit common.Audit
// Context is the database server close context.
Context context.Context
// Clock is the clock interface.
@@ -111,16 +111,8 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
}
// At this point Postgres client should be ready to start sending
// messages: this is where psql prompt appears on the other side.
err = e.Audit.OnSessionStart(e.Context, *sessionCtx, nil)
if err != nil {
return trace.Wrap(err)
}
defer func() {
err := e.Audit.OnSessionEnd(e.Context, *sessionCtx)
if err != nil {
e.Log.WithError(err).Error("Failed to emit audit event.")
}
}()
e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
defer e.Audit.OnSessionEnd(e.Context, sessionCtx)
// Reconstruct pgconn.PgConn from hijacked connection for easier access
// to its utility methods (such as Close).
serverConn, err := pgconn.Construct(hijackedConn)
@@ -192,9 +184,7 @@ func (e *Engine) checkAccess(sessionCtx *common.Session) error {
&services.DatabaseUserMatcher{User: sessionCtx.DatabaseUser},
&services.DatabaseNameMatcher{Name: sessionCtx.DatabaseName})
if err != nil {
if err := e.Audit.OnSessionStart(e.Context, *sessionCtx, err); err != nil {
e.Log.WithError(err).Error("Failed to emit audit event.")
}
e.Audit.OnSessionStart(e.Context, sessionCtx, err)
return trace.Wrap(err)
}
return nil
@@ -274,9 +264,44 @@ func (e *Engine) receiveFromClient(client *pgproto3.Backend, server *pgproto3.Fr
log.Debugf("Received client message: %#v.", message)
switch msg := message.(type) {
case *pgproto3.Query:
err := e.Audit.OnQuery(e.Context, *sessionCtx, msg.String)
// Query message indicates the client is executing a simple query.
e.Audit.OnQuery(e.Context, sessionCtx, msg.String)
case *pgproto3.Parse:
// Parse message is a start of the extended query protocol which
// prepares parameterized query for execution. It is never used
// by psql, mostly by various GUI clients and programs.
// https://www.postgresql.org/docs/10/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
sessionCtx.Statements.Save(msg.Name, msg.Query)
case *pgproto3.Bind:
// Bind message readies existing prepared statement (created when
// Parse message is received) for execution into what Postgres
// calls a "destination portal", optionally binding it with
// parameters (for parameterized queries).
err := sessionCtx.Statements.Bind(
msg.PreparedStatement,
msg.DestinationPortal,
getBindParameters(msg)...)
if err != nil {
log.WithError(err).Error("Failed to emit audit event.")
log.WithError(err).Warnf("Failed to bind prepared statement %#v.", msg)
}
case *pgproto3.Execute:
// Execute message indicates the client is executing the previously
// parsed and bound prepared statement i.e. the "portal". This is
// where we emit the query audit event.
portal, err := sessionCtx.Statements.GetPortal(msg.Portal)
if err != nil {
log.WithError(err).Warnf("Failed to find destination portal %#v.", msg)
} else {
e.Audit.OnQuery(e.Context, sessionCtx, portal.Query, portal.Parameters...)
}
case *pgproto3.Close:
// Close message closes the specified prepared statement or portal.
// Remove respective object from the cache.
switch msg.ObjectType {
case closeTypePreparedStatement:
sessionCtx.Statements.Remove(msg.Name)
case closeTypeDestinationPortal:
sessionCtx.Statements.RemovePortal(msg.Name)
}
case *pgproto3.Terminate:
clientErrCh <- nil
@@ -362,3 +387,57 @@ func (e *Engine) getConnectConfig(ctx context.Context, sessionCtx *common.Sessio
}
return config, nil
}

// getBindParameters converts prepared statement parameters from the Postgres
// wire protocol Bind message into their string representations for including
// in the audit log.
func getBindParameters(msg *pgproto3.Bind) (parameters []string) {
// Each parameter can be either a text or a binary which is determined
// by "parameter format codes" in the Bind message (0 - text, 1 - binary).
//
// Be a bit paranoid and make sure that number of format codes matches the
// number of parameters, or there are no format codes in which case all
// parameters will be text.
if len(msg.ParameterFormatCodes) != 0 && len(msg.ParameterFormatCodes) != len(msg.Parameters) {
logrus.Warnf("Postgres parameter format codes and parameters don't match: %#v.", msg)
return parameters
}
for i, p := range msg.Parameters {
// According to Bind message documentation, if there are no parameter
// format codes, it may mean that either there are no parameters, or
// that all parameters use default text format.
if len(msg.ParameterFormatCodes) == 0 {
parameters = append(parameters, string(p))
continue
}
switch msg.ParameterFormatCodes[i] {
case parameterFormatCodeText:
// Text parameters can just be converted to their string
// representation.
parameters = append(parameters, string(p))
case parameterFormatCodeBinary:
// For binary parameters, just put a placeholder to avoid
// spamming the audit log with unreadable info.
parameters = append(parameters, "<binary>")
default:
// Should never happen but...
logrus.Warnf("Unknown Postgres parameter format code: %#v.", msg)
parameters = append(parameters, "<unknown>")
}
}
return parameters
}

const (
// parameterFormatCodeText indicates that this is a text query parameter.
parameterFormatCodeText = 0
// parameterFormatCodeBinary indicates that this is a binary query parameter.
parameterFormatCodeBinary = 1

// closeTypePreparedStatement indicates that a prepared statement is being
// closed by the Close message.
closeTypePreparedStatement = 'S'
// closeTypeDestinationPortal indicates that a destination portal is being
// closed by the Close message.
closeTypeDestinationPortal = 'P'
)
22 changes: 17 additions & 5 deletions lib/srv/db/postgres/test.go
Original file line number Diff line number Diff line change
@@ -143,16 +143,28 @@ func (s *TestServer) handleConnection(conn net.Conn) error {
return trace.Wrap(err)
}
s.log.Debugf("Received %#v.", message)
switch msg := message.(type) {
switch message.(type) {
case *pgproto3.Query:
err := s.handleQuery(client, msg)
if err != nil {
if err := s.handleQuery(client); err != nil {
s.log.WithError(err).Error("Failed to handle query.")
}
// Following messages are for handling Postgres extended query
// protocol flow used by prepared statements.
case *pgproto3.Parse:
// Parse prepares the statement.
case *pgproto3.Bind:
// Bind binds prepared statement with parameters.
case *pgproto3.Describe:
case *pgproto3.Sync:
case *pgproto3.Execute:
// Execute executes prepared statement.
if err := s.handleQuery(client); err != nil {
s.log.WithError(err).Error("Failed to handle query.")
}
case *pgproto3.Terminate:
return nil
default:
return trace.BadParameter("unsupported message %#v", msg)
return trace.BadParameter("unsupported message %#v", message)
}
}
}
@@ -195,7 +207,7 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend) error {
return nil
}

func (s *TestServer) handleQuery(client *pgproto3.Backend, query *pgproto3.Query) error {
func (s *TestServer) handleQuery(client *pgproto3.Backend) error {
atomic.AddUint32(&s.queryCount, 1)
messages := []pgproto3.BackendMessage{
&pgproto3.RowDescription{Fields: TestQueryResponse.FieldDescriptions},
5 changes: 2 additions & 3 deletions lib/srv/db/proxyserver.go
Original file line number Diff line number Diff line change
@@ -137,8 +137,7 @@ func (s *ProxyServer) Serve(listener net.Listener) error {
defer clientConn.Close()
err := proxy.HandleConnection(s.closeCtx, clientConn)
if err != nil {
s.log.Errorf("Failed to handle client connection: %v.",
trace.DebugReport(err))
s.log.WithError(err).Warn("Failed to handle client connection.")
}
}()
}
@@ -260,7 +259,7 @@ func (s *ProxyServer) Proxy(ctx context.Context, clientConn, serviceConn io.Read
for i := 0; i < 2; i++ {
select {
case err := <-errCh:
if err != nil && err != io.EOF && !strings.Contains(err.Error(), teleport.UseOfClosedNetworkConnection) {
if err != nil && !trace.IsEOF(err) && !strings.Contains(err.Error(), teleport.UseOfClosedNetworkConnection) {
s.log.WithError(err).Warn("Connection problem.")
errs = append(errs, err)
}
13 changes: 11 additions & 2 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
@@ -57,6 +57,8 @@ type Config struct {
AccessPoint auth.AccessPoint
// StreamEmitter is a non-blocking audit events emitter.
StreamEmitter events.StreamEmitter
// NewAudit allows to override audit logger in tests.
NewAudit NewAuditFn
// TLSConfig is the *tls.Config for this server.
TLSConfig *tls.Config
// Authorizer is used to authorize requests coming from proxy.
@@ -73,6 +75,9 @@ type Config struct {
OnHeartbeat func(error)
}

// NewAuditFn defines a function that creates an audit logger.
type NewAuditFn func(common.AuditConfig) (common.Audit, error)

// CheckAndSetDefaults makes sure the configuration has the minimum required
// to function.
func (c *Config) CheckAndSetDefaults(ctx context.Context) error {
@@ -91,6 +96,9 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) error {
if c.StreamEmitter == nil {
return trace.BadParameter("missing StreamEmitter")
}
if c.NewAudit == nil {
c.NewAudit = common.NewAudit
}
if c.TLSConfig == nil {
return trace.BadParameter("missing TLSConfig")
}
@@ -381,8 +389,8 @@ func (s *Server) dispatch(sessionCtx *common.Session, streamWriter events.Stream
if err != nil {
return nil, trace.Wrap(err)
}
audit, err := common.NewAudit(common.AuditConfig{
StreamWriter: streamWriter,
audit, err := s.cfg.NewAudit(common.AuditConfig{
Emitter: streamWriter,
})
if err != nil {
return nil, trace.Wrap(err)
@@ -447,6 +455,7 @@ func (s *Server) authorize(ctx context.Context) (*common.Session, error) {
DatabaseName: identity.RouteToDatabase.Database,
Checker: authContext.Checker,
StartupParameters: make(map[string]string),
Statements: common.NewStatementsCache(),
Log: s.log.WithFields(logrus.Fields{
"id": id,
"db": server.GetName(),

0 comments on commit 8230d6e

Please sign in to comment.