diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index b50810b5451ca..37ae3c4ad098e 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -1977,7 +1977,7 @@ func withDeniedDBLabels(labels types.Labels) roleOptFn { // createUserAndRole creates Teleport user and role with specified names // and allowed database users/names properties. -func (c *testContext) createUserAndRole(ctx context.Context, t *testing.T, userName, roleName string, dbUsers, dbNames []string, roleOpts ...roleOptFn) (types.User, types.Role) { +func (c *testContext) createUserAndRole(ctx context.Context, t testing.TB, userName, roleName string, dbUsers, dbNames []string, roleOpts ...roleOptFn) (types.User, types.Role) { user, role, err := auth.CreateUserAndRole(c.tlsServer.Auth(), userName, []string{roleName}, nil) require.NoError(t, err) role.SetDatabaseUsers(types.Allow, dbUsers) @@ -1991,7 +1991,7 @@ func (c *testContext) createUserAndRole(ctx context.Context, t *testing.T, userN } // makeTLSConfig returns tls configuration for the test's tls listener. -func (c *testContext) makeTLSConfig(t *testing.T) *tls.Config { +func (c *testContext) makeTLSConfig(t testing.TB) *tls.Config { creds, err := cert.GenerateSelfSignedCert([]string{"localhost"}, nil) require.NoError(t, err) cert, err := tls.X509KeyPair(creds.Cert, creds.PrivateKey) @@ -2028,7 +2028,7 @@ func init() { SetShuffleFunc(ShuffleSort) } -func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDatabaseOption) *testContext { +func setupTestContext(ctx context.Context, t testing.TB, withDatabases ...withDatabaseOption) *testContext { testCtx := &testContext{ clusterName: "root.example.com", hostID: uuid.New().String(), @@ -2252,7 +2252,7 @@ func (p *agentParams) setDefaults(c *testContext) { } } -func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p agentParams) *Server { +func (c *testContext) setupDatabaseServer(ctx context.Context, t testing.TB, p agentParams) *Server { p.setDefaults(c) // Database service credentials. @@ -2432,12 +2432,12 @@ func TestAccessClickHouse(t *testing.T) { } } -type withDatabaseOption func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database +type withDatabaseOption func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database type databaseOption func(*types.DatabaseV3) func withSelfHostedPostgres(name string, dbOpts ...databaseOption) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { postgresServer, err := postgres.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2466,7 +2466,7 @@ func withSelfHostedPostgres(name string, dbOpts ...databaseOption) withDatabaseO } func withRDSPostgres(name, authToken string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { postgresServer, err := postgres.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2499,7 +2499,7 @@ func withRDSPostgres(name, authToken string) withDatabaseOption { } func withRedshiftPostgres(name, authToken string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { postgresServer, err := postgres.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2533,7 +2533,7 @@ func withRedshiftPostgres(name, authToken string) withDatabaseOption { } func withCloudSQLPostgres(name, authToken string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { postgresServer, err := postgres.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2570,7 +2570,7 @@ func withCloudSQLPostgres(name, authToken string) withDatabaseOption { } func withAzurePostgres(name, authToken string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { postgresServer, err := postgres.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2603,7 +2603,7 @@ func withAzurePostgres(name, authToken string) withDatabaseOption { } func withSelfHostedMySQL(name string, opts ...mysql.TestServerOption) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2631,7 +2631,7 @@ func withSelfHostedMySQL(name string, opts ...mysql.TestServerOption) withDataba } func withRDSMySQL(name, authUser, authToken string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2663,7 +2663,7 @@ func withRDSMySQL(name, authUser, authToken string) withDatabaseOption { } func withCloudSQLMySQL(name, authUser, authToken string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2701,7 +2701,7 @@ func withCloudSQLMySQL(name, authUser, authToken string) withDatabaseOption { // withCloudSQLMySQLTLS creates a test MySQL server that simulates GCP Cloud SQL // and requires client authentication using an ephemeral client certificate. func withCloudSQLMySQLTLS(name, authUser, authToken string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2739,7 +2739,7 @@ func withCloudSQLMySQLTLS(name, authUser, authToken string) withDatabaseOption { } func withAzureMySQL(name, authUser, authToken string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2771,7 +2771,7 @@ func withAzureMySQL(name, authUser, authToken string) withDatabaseOption { } func withAtlasMongo(name, authUser, authSession string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { mongoServer, err := mongodb.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2805,7 +2805,7 @@ func withAtlasMongo(name, authUser, authSession string) withDatabaseOption { } func withSelfHostedMongo(name string, opts ...mongodb.TestServerOption) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { mongoServer, err := mongodb.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2831,7 +2831,7 @@ func withSelfHostedMongo(name string, opts ...mongodb.TestServerOption) withData } func withSelfHostedRedis(name string, opts ...redis.TestServerOption) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { redisServer, err := redis.NewTestServer(t, common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2856,7 +2856,7 @@ func withSelfHostedRedis(name string, opts ...redis.TestServerOption) withDataba } func withSQLServer(name string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { sqlServer, err := sqlserver.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2880,7 +2880,7 @@ func withSQLServer(name string) withDatabaseOption { } func withClickhouseNative(name string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { server, err := clickhouse.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2906,7 +2906,7 @@ func withClickhouseNative(name string) withDatabaseOption { } func withClickhouseHTTP(name string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { server, err := clickhouse.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2930,7 +2930,7 @@ func withClickhouseHTTP(name string) withDatabaseOption { } func withElastiCacheRedis(name string, token, engineVersion string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { redisServer, err := redis.NewTestServer(t, common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, @@ -2967,7 +2967,7 @@ func withElastiCacheRedis(name string, token, engineVersion string) withDatabase } func withAzureRedis(name string, token string) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { redisServer, err := redis.NewTestServer(t, common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, diff --git a/lib/srv/db/benchmark_test.go b/lib/srv/db/benchmark_test.go new file mode 100644 index 0000000000000..aed2c0a650f0b --- /dev/null +++ b/lib/srv/db/benchmark_test.go @@ -0,0 +1,94 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" +) + +/* +$ go test ./lib/srv/db -bench=. -run=^$ -benchtime=3x +goos: darwin +goarch: arm64 +pkg: github.com/gravitational/teleport/lib/srv/db +BenchmarkPostgresReadLargeTable/size=11-10 3 286618514 ns/op +BenchmarkPostgresReadLargeTable/size=20-10 3 253457917 ns/op +BenchmarkPostgresReadLargeTable/size=100-10 3 222804292 ns/op +BenchmarkPostgresReadLargeTable/size=1000-10 3 216612764 ns/op +BenchmarkPostgresReadLargeTable/size=2000-10 3 214121861 ns/op +BenchmarkPostgresReadLargeTable/size=8000-10 3 215046472 ns/op +*/ +// BenchmarkPostgresReadLargeTable is a benchmark for read-heavy usage of Postgres. +// Depending on the message size we may get different performance, due to the way the respective engine is written. +func BenchmarkPostgresReadLargeTable(b *testing.B) { + b.StopTimer() + ctx := context.Background() + testCtx := setupTestContext(ctx, b, withSelfHostedPostgres("postgres", func(db *types.DatabaseV3) { + db.SetStaticLabels(map[string]string{"foo": "bar"}) + })) + go testCtx.startHandlingConnections() + + user := "alice" + role := "admin" + allow := []string{types.Wildcard} + + // Create user/role with the requested permissions. + testCtx.createUserAndRole(ctx, b, user, role, allow, allow) + for _, messageSize := range []int{11, 20, 100, 1000, 2000, 8000} { + + // connect to the database + pgConn, err := testCtx.postgresClient(ctx, user, "postgres", "postgres", "postgres") + require.NoError(b, err) + + // total bytes to be transmitted, approximate. + const totalBytes = 100 * 1024 * 1024 + // calculate the number of messages required to reach totalBytes of transferred data. + rowCount := totalBytes / messageSize + + // run first query without timer. the server will cache the message. + _, err = pgConn.Exec(ctx, fmt.Sprintf("SELECT * FROM bench_%v LIMIT %v", messageSize, rowCount)).ReadAll() + require.NoError(b, err) + + b.Run(fmt.Sprintf("size=%v", messageSize), func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Execute a query, count results. + q := pgConn.Exec(ctx, fmt.Sprintf("SELECT * FROM bench_%v LIMIT %v", messageSize, rowCount)) + + // pgConn.Exec can potentially execute multiple SQL queries. + // the outer loop is a query loop, the inner loop is for individual results. + rows := 0 + for q.NextResult() { + rr := q.ResultReader() + for rr.NextRow() { + rows++ + } + } + + require.NoError(b, q.Close()) + require.Equal(b, rowCount, rows) + } + }) + + // Disconnect. + err = pgConn.Close(ctx) + require.NoError(b, err) + } +} diff --git a/lib/srv/db/cassandra_test.go b/lib/srv/db/cassandra_test.go index da68add4a8dff..1c75b3be61125 100644 --- a/lib/srv/db/cassandra_test.go +++ b/lib/srv/db/cassandra_test.go @@ -288,7 +288,7 @@ func TestEventCassandra(t *testing.T) { } func withCassandra(name string, opts ...cassandra.TestServerOption) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { cassandraServer, err := cassandra.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, diff --git a/lib/srv/db/dynamodb_test.go b/lib/srv/db/dynamodb_test.go index 369817f511895..e5fe75a793116 100644 --- a/lib/srv/db/dynamodb_test.go +++ b/lib/srv/db/dynamodb_test.go @@ -202,7 +202,7 @@ func TestAuditDynamoDB(t *testing.T) { } func withDynamoDB(name string, opts ...dynamodb.TestServerOption) withDatabaseOption { - return func(t *testing.T, _ context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, _ context.Context, testCtx *testContext) types.Database { config := common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, diff --git a/lib/srv/db/elasticsearch_test.go b/lib/srv/db/elasticsearch_test.go index f8330b6eb3944..274482f358c47 100644 --- a/lib/srv/db/elasticsearch_test.go +++ b/lib/srv/db/elasticsearch_test.go @@ -190,7 +190,7 @@ func TestAuditElasticsearch(t *testing.T) { } func withElasticsearch(name string, opts ...elasticsearch.TestServerOption) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { ElasticsearchServer, err := elasticsearch.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, diff --git a/lib/srv/db/opensearch_test.go b/lib/srv/db/opensearch_test.go index 6f08ea99599fe..0cf355cdb963a 100644 --- a/lib/srv/db/opensearch_test.go +++ b/lib/srv/db/opensearch_test.go @@ -200,7 +200,7 @@ func TestAuditOpenSearch(t *testing.T) { } func withOpenSearch(name string, opts ...opensearch.TestServerOption) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { OpenSearchServer, err := opensearch.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient, diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index 88faa1294a2d2..f76c63c7dafd2 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -58,10 +58,16 @@ type Engine struct { // cancelReq is a cancel request saved when a cancel request is received // instead of a startup message. cancelReq *pgproto3.CancelRequest + + // rawClientConn is raw, unwrapped network connection to the client + rawClientConn net.Conn + // rawServerConn is raw, unwrapped network connection to the server + rawServerConn net.Conn } // InitializeConnection initializes the client connection. func (e *Engine) InitializeConnection(clientConn net.Conn, sessionCtx *common.Session) error { + e.rawClientConn = clientConn e.client = pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn) // The proxy is supposed to pass a startup message it received from @@ -138,6 +144,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio cancelAutoUserLease() return trace.Wrap(err) } + e.rawServerConn = hijackedConn.Conn // Release the auto-users semaphore now that we've successfully connected. cancelAutoUserLease() // Upon successful connect, indicate to the Postgres client that startup @@ -170,7 +177,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio clientErrCh := make(chan error, 1) serverErrCh := make(chan error, 1) go e.receiveFromClient(e.client, server, clientErrCh, sessionCtx) - go e.receiveFromServer(server, e.client, serverConn, serverErrCh, sessionCtx) + go e.receiveFromServer(serverConn, serverErrCh, sessionCtx) select { case err := <-clientErrCh: e.Log.WithError(err).Debug("Client done.") @@ -402,38 +409,54 @@ func (e *Engine) auditFuncCallMessage(session *common.Session, msg *pgproto3.Fun // receiveFromServer receives messages from the provided frontend (which // is connected to the database instance) and relays them back to the psql // or other client via the provided backend. -func (e *Engine) receiveFromServer(server *pgproto3.Frontend, client *pgproto3.Backend, serverConn *pgconn.PgConn, serverErrCh chan<- error, sessionCtx *common.Session) { +func (e *Engine) receiveFromServer(serverConn *pgconn.PgConn, serverErrCh chan<- error, sessionCtx *common.Session) { log := e.Log.WithField("from", "server") - ctr := common.GetMessagesFromServerMetric(sessionCtx.Database) - defer log.Debug("Stop receiving from server.") - for { - message, err := server.Receive() - if err != nil { - if serverConn.IsClosed() { - log.Debug("Server connection closed.") - serverErrCh <- nil + // parse and count the messages from the server in a separate goroutine, + // operating on a copy of the server message stream. the copy is arranged below. + copyReader, copyWriter := io.Pipe() + defer copyWriter.Close() + + go func() { + defer copyReader.Close() + + // server will never be used to write to server, + // which is why we pass io.Discard instead of e.rawServerConn + server := pgproto3.NewFrontend(pgproto3.NewChunkReader(copyReader), io.Discard) + + var count int64 + defer func() { + log.WithField("parsed_total", count).Debug("Stopped parsing messages from server.") + }() + + for { + message, err := server.Receive() + if err != nil { + if serverConn.IsClosed() { + log.Debug("Server connection closed.") + return + } + log.WithError(err).Error("Failed to receive message from server.") return } - log.WithError(err).Errorf("Failed to receive message from server.") - serverErrCh <- err - return - } - log.Tracef("Received server message: %#v.", message) - ctr.Inc() - // This is where we would plug in custom logic for particular - // messages received from the Postgres server (i.e. emitting - // an audit event), but for now just pass them along back to - // the client. - err = client.Send(message) - if err != nil { - log.WithError(err).Error("Failed to send message to client.") - serverErrCh <- err - return + count += 1 + ctr.Inc() + log.Tracef("Received server message: %#v.", message) } + }() + + // the messages are ultimately copied from e.rawServerConn to e.rawClientConn, + // but a copy of that message stream is written to a synchronous pipe, + // which is read by the analysis goroutine above. + total, err := io.Copy(e.rawClientConn, io.TeeReader(e.rawServerConn, copyWriter)) + if err != nil && !trace.IsConnectionProblem(trace.ConvertSystemError(err)) { + log.WithError(err).Warn("Server -> Client copy finished with unexpected error.") } + + serverErrCh <- trace.Wrap(err) + log.Debugf("Stopped receiving from server. Transferred %v bytes.", total) } // getConnectConfig returns config that can be used to connect to the diff --git a/lib/srv/db/postgres/test.go b/lib/srv/db/postgres/test.go index 58063afb0b761..be3ad75e5efd4 100644 --- a/lib/srv/db/postgres/test.go +++ b/lib/srv/db/postgres/test.go @@ -17,11 +17,14 @@ limitations under the License. package postgres import ( + "bytes" "context" + "crypto/rand" "crypto/tls" "fmt" "net" "regexp" + "strconv" "strings" "sync" "sync/atomic" @@ -90,6 +93,9 @@ type TestServer struct { pids map[uint32]*pidHandle // pidMu is a lock protecting nextPid and pids. pidMu sync.Mutex + + // mmCache caches multiMessage for reuse in benchmark + mmCache map[string]*multiMessage } // pidHandle represents a fake pid handle that can cancel operations in progress. @@ -142,6 +148,7 @@ func NewTestServer(config common.TestServerConfig) (svr *TestServer, err error) pids: make(map[uint32]*pidHandle), storedProcedures: make(map[string]string), userEventsCh: make(chan UserEvent, 100), + mmCache: make(map[string]*multiMessage), }, nil } @@ -331,6 +338,10 @@ func (s *TestServer) handleQuery(client *pgproto3.Backend, query string, pid uin return trace.Wrap(err) } } + if selectBenchmarkRe.MatchString(query) { + return trace.Wrap(s.handleBenchmarkQuery(query, client)) + } + messages := []pgproto3.BackendMessage{ &pgproto3.RowDescription{Fields: TestQueryResponse.FieldDescriptions}, &pgproto3.DataRow{Values: TestQueryResponse.Rows[0]}, @@ -364,6 +375,107 @@ func (s *TestServer) handleCreateStoredProcedure(query string) error { return nil } +// multiMessage wraps *pgproto3.DataRow and implements pgproto3.BackendMessage by writing multiple copies of this message in Encode. +type multiMessage struct { + singleMessage *pgproto3.DataRow + payload []byte +} + +func newMultiMessage(rowSize, repeats int) (*multiMessage, error) { + buf := make([]byte, rowSize) + _, err := rand.Read(buf) + if err != nil { + return nil, trace.Wrap(err) + } + message := &pgproto3.DataRow{Values: [][]byte{buf}} + encoded := message.Encode(nil) + payload := bytes.Repeat(encoded, repeats) + return &multiMessage{ + singleMessage: message, + payload: payload, + }, nil +} + +func (m *multiMessage) Decode(_ []byte) error { + return trace.NotImplemented("Decode is not implemented for multiMessage") +} + +func (m *multiMessage) Encode(dst []byte) []byte { + return append(dst, m.payload...) +} + +func (m *multiMessage) Backend() { +} + +var _ pgproto3.BackendMessage = (*multiMessage)(nil) + +func (s *TestServer) getMultiMessage(rowSize, repeats int) (*multiMessage, error) { + key := fmt.Sprintf("%v/%v", rowSize, repeats) + if mm, ok := s.mmCache[key]; ok { + return mm, nil + } + mm, err := newMultiMessage(rowSize, repeats) + if err != nil { + return nil, trace.Wrap(err) + } + s.mmCache[key] = mm + return mm, nil +} + +// handleBenchmarkQuery handles the query used for read benchmark. It will send a stream of messages of requested size and number. +func (s *TestServer) handleBenchmarkQuery(query string, client *pgproto3.Backend) error { + // parse benchmark parameters + matches := selectBenchmarkRe.FindStringSubmatch(query) + + messageSize, err := strconv.Atoi(matches[1]) + if err != nil { + return trace.Wrap(err) + } + // minimum message size is 11, corresponding to empty buffer transferred in a DataRow + if messageSize < 11 { + return trace.BadParameter("bad message size, must be at least 11, got %v", messageSize) + } + + repeats, err := strconv.Atoi(matches[2]) + if err != nil { + return trace.Wrap(err) + } + + mm, err := s.getMultiMessage(messageSize-11, repeats) + if err != nil { + return trace.Wrap(err) + } + + s.log.Debugf("Responding to query %q, will send %v messages of length %v, total length %v", query, repeats, len(mm.singleMessage.Encode(nil)), len(mm.payload)) + + // preamble + err = client.Send(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("dummy")}}}) + if err != nil { + return trace.Wrap(err) + } + + // send messages in bulk, which is fast. + err = client.Send(mm) + if err != nil { + return trace.Wrap(err) + } + + // epilogue + err = client.Send(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 100")}) + if err != nil { + return trace.Wrap(err) + } + + err = client.Send(&pgproto3.ReadyForQuery{}) + if err != nil { + return trace.Wrap(err) + } + + s.log.Debugf("Finished handling query %q", query) + + return nil +} + func (s *TestServer) handleActivateUser(client *pgproto3.Backend) error { // Expect Describe message. _, err := s.receiveDescribeMessage(client) @@ -709,3 +821,6 @@ const testSecretKey = 1234 // storedProcedureRe is the regex for capturing stored procedure name from its // creation query. var storedProcedureRe = regexp.MustCompile(`create or replace procedure (.+)\(`) + +// selectBenchmarkRe is the regex for capturing the parameters from the select query used for read benchmark. +var selectBenchmarkRe = regexp.MustCompile(`SELECT \* FROM bench\_(\d+) LIMIT (\d+)`) diff --git a/lib/srv/db/redis/test.go b/lib/srv/db/redis/test.go index 0272a295c1eb4..f458fca88991d 100644 --- a/lib/srv/db/redis/test.go +++ b/lib/srv/db/redis/test.go @@ -122,7 +122,7 @@ func TestServerPassword(password string) TestServerOption { } // NewTestServer returns a new instance of a test Redis server. -func NewTestServer(t *testing.T, config common.TestServerConfig, opts ...TestServerOption) (*TestServer, error) { +func NewTestServer(t testing.TB, config common.TestServerConfig, opts ...TestServerOption) (*TestServer, error) { tlsConfig, err := common.MakeTestServerTLSConfig(config) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/db/snowflake_test.go b/lib/srv/db/snowflake_test.go index 94d87976b1e9d..7984fa6cd95b6 100644 --- a/lib/srv/db/snowflake_test.go +++ b/lib/srv/db/snowflake_test.go @@ -244,7 +244,7 @@ func TestAuditSnowflake(t *testing.T) { }) t.Run("session ends event", func(t *testing.T) { - t.Skip() //TODO(jakule): Driver for some reason doesn't terminate the session. + t.Skip() // TODO(jakule): Driver for some reason doesn't terminate the session. // Closing connection should trigger session end event. err := dbConn.Close() require.NoError(t, err) @@ -354,7 +354,7 @@ func TestTokenSession(t *testing.T) { } func withSnowflake(name string, opts ...snowflake.TestServerOption) withDatabaseOption { - return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database { snowflakeServer, err := snowflake.NewTestServer(common.TestServerConfig{ Name: name, AuthClient: testCtx.authClient,