Skip to content
46 changes: 23 additions & 23 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1977,7 +1977,7 @@ func withDeniedDBLabels(labels types.Labels) roleOptFn {

// createUserAndRole creates Teleport user and role with specified names
// and allowed database users/names properties.
func (c *testContext) createUserAndRole(ctx context.Context, t *testing.T, userName, roleName string, dbUsers, dbNames []string, roleOpts ...roleOptFn) (types.User, types.Role) {
func (c *testContext) createUserAndRole(ctx context.Context, t testing.TB, userName, roleName string, dbUsers, dbNames []string, roleOpts ...roleOptFn) (types.User, types.Role) {
user, role, err := auth.CreateUserAndRole(c.tlsServer.Auth(), userName, []string{roleName}, nil)
require.NoError(t, err)
role.SetDatabaseUsers(types.Allow, dbUsers)
Expand All @@ -1991,7 +1991,7 @@ func (c *testContext) createUserAndRole(ctx context.Context, t *testing.T, userN
}

// makeTLSConfig returns tls configuration for the test's tls listener.
func (c *testContext) makeTLSConfig(t *testing.T) *tls.Config {
func (c *testContext) makeTLSConfig(t testing.TB) *tls.Config {
creds, err := cert.GenerateSelfSignedCert([]string{"localhost"}, nil)
require.NoError(t, err)
cert, err := tls.X509KeyPair(creds.Cert, creds.PrivateKey)
Expand Down Expand Up @@ -2028,7 +2028,7 @@ func init() {
SetShuffleFunc(ShuffleSort)
}

func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDatabaseOption) *testContext {
func setupTestContext(ctx context.Context, t testing.TB, withDatabases ...withDatabaseOption) *testContext {
testCtx := &testContext{
clusterName: "root.example.com",
hostID: uuid.New().String(),
Expand Down Expand Up @@ -2252,7 +2252,7 @@ func (p *agentParams) setDefaults(c *testContext) {
}
}

func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p agentParams) *Server {
func (c *testContext) setupDatabaseServer(ctx context.Context, t testing.TB, p agentParams) *Server {
p.setDefaults(c)

// Database service credentials.
Expand Down Expand Up @@ -2432,12 +2432,12 @@ func TestAccessClickHouse(t *testing.T) {
}
}

type withDatabaseOption func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database
type withDatabaseOption func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database

type databaseOption func(*types.DatabaseV3)

func withSelfHostedPostgres(name string, dbOpts ...databaseOption) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
postgresServer, err := postgres.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2466,7 +2466,7 @@ func withSelfHostedPostgres(name string, dbOpts ...databaseOption) withDatabaseO
}

func withRDSPostgres(name, authToken string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
postgresServer, err := postgres.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2499,7 +2499,7 @@ func withRDSPostgres(name, authToken string) withDatabaseOption {
}

func withRedshiftPostgres(name, authToken string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
postgresServer, err := postgres.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2533,7 +2533,7 @@ func withRedshiftPostgres(name, authToken string) withDatabaseOption {
}

func withCloudSQLPostgres(name, authToken string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
postgresServer, err := postgres.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2570,7 +2570,7 @@ func withCloudSQLPostgres(name, authToken string) withDatabaseOption {
}

func withAzurePostgres(name, authToken string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
postgresServer, err := postgres.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2603,7 +2603,7 @@ func withAzurePostgres(name, authToken string) withDatabaseOption {
}

func withSelfHostedMySQL(name string, opts ...mysql.TestServerOption) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2631,7 +2631,7 @@ func withSelfHostedMySQL(name string, opts ...mysql.TestServerOption) withDataba
}

func withRDSMySQL(name, authUser, authToken string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2663,7 +2663,7 @@ func withRDSMySQL(name, authUser, authToken string) withDatabaseOption {
}

func withCloudSQLMySQL(name, authUser, authToken string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2701,7 +2701,7 @@ func withCloudSQLMySQL(name, authUser, authToken string) withDatabaseOption {
// withCloudSQLMySQLTLS creates a test MySQL server that simulates GCP Cloud SQL
// and requires client authentication using an ephemeral client certificate.
func withCloudSQLMySQLTLS(name, authUser, authToken string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2739,7 +2739,7 @@ func withCloudSQLMySQLTLS(name, authUser, authToken string) withDatabaseOption {
}

func withAzureMySQL(name, authUser, authToken string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2771,7 +2771,7 @@ func withAzureMySQL(name, authUser, authToken string) withDatabaseOption {
}

func withAtlasMongo(name, authUser, authSession string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
mongoServer, err := mongodb.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2805,7 +2805,7 @@ func withAtlasMongo(name, authUser, authSession string) withDatabaseOption {
}

func withSelfHostedMongo(name string, opts ...mongodb.TestServerOption) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
mongoServer, err := mongodb.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand All @@ -2831,7 +2831,7 @@ func withSelfHostedMongo(name string, opts ...mongodb.TestServerOption) withData
}

func withSelfHostedRedis(name string, opts ...redis.TestServerOption) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
redisServer, err := redis.NewTestServer(t, common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand All @@ -2856,7 +2856,7 @@ func withSelfHostedRedis(name string, opts ...redis.TestServerOption) withDataba
}

func withSQLServer(name string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
sqlServer, err := sqlserver.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand All @@ -2880,7 +2880,7 @@ func withSQLServer(name string) withDatabaseOption {
}

func withClickhouseNative(name string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
server, err := clickhouse.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand All @@ -2906,7 +2906,7 @@ func withClickhouseNative(name string) withDatabaseOption {
}

func withClickhouseHTTP(name string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
server, err := clickhouse.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand All @@ -2930,7 +2930,7 @@ func withClickhouseHTTP(name string) withDatabaseOption {
}

func withElastiCacheRedis(name string, token, engineVersion string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
redisServer, err := redis.NewTestServer(t, common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down Expand Up @@ -2967,7 +2967,7 @@ func withElastiCacheRedis(name string, token, engineVersion string) withDatabase
}

func withAzureRedis(name string, token string) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
redisServer, err := redis.NewTestServer(t, common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down
94 changes: 94 additions & 0 deletions lib/srv/db/benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright 2023 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package db

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
)

/*
$ go test ./lib/srv/db -bench=. -run=^$ -benchtime=3x
goos: darwin
goarch: arm64
pkg: github.com/gravitational/teleport/lib/srv/db
BenchmarkPostgresReadLargeTable/size=11-10 3 286618514 ns/op
BenchmarkPostgresReadLargeTable/size=20-10 3 253457917 ns/op
BenchmarkPostgresReadLargeTable/size=100-10 3 222804292 ns/op
BenchmarkPostgresReadLargeTable/size=1000-10 3 216612764 ns/op
BenchmarkPostgresReadLargeTable/size=2000-10 3 214121861 ns/op
BenchmarkPostgresReadLargeTable/size=8000-10 3 215046472 ns/op
*/
// BenchmarkPostgresReadLargeTable is a benchmark for read-heavy usage of Postgres.
// Depending on the message size we may get different performance, due to the way the respective engine is written.
func BenchmarkPostgresReadLargeTable(b *testing.B) {
b.StopTimer()
ctx := context.Background()
testCtx := setupTestContext(ctx, b, withSelfHostedPostgres("postgres", func(db *types.DatabaseV3) {
db.SetStaticLabels(map[string]string{"foo": "bar"})
}))
go testCtx.startHandlingConnections()

user := "alice"
role := "admin"
allow := []string{types.Wildcard}

// Create user/role with the requested permissions.
testCtx.createUserAndRole(ctx, b, user, role, allow, allow)
Comment thread
Tener marked this conversation as resolved.
for _, messageSize := range []int{11, 20, 100, 1000, 2000, 8000} {

// connect to the database
pgConn, err := testCtx.postgresClient(ctx, user, "postgres", "postgres", "postgres")
require.NoError(b, err)

// total bytes to be transmitted, approximate.
const totalBytes = 100 * 1024 * 1024
// calculate the number of messages required to reach totalBytes of transferred data.
rowCount := totalBytes / messageSize

// run first query without timer. the server will cache the message.
_, err = pgConn.Exec(ctx, fmt.Sprintf("SELECT * FROM bench_%v LIMIT %v", messageSize, rowCount)).ReadAll()
require.NoError(b, err)

b.Run(fmt.Sprintf("size=%v", messageSize), func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Execute a query, count results.
q := pgConn.Exec(ctx, fmt.Sprintf("SELECT * FROM bench_%v LIMIT %v", messageSize, rowCount))

// pgConn.Exec can potentially execute multiple SQL queries.
// the outer loop is a query loop, the inner loop is for individual results.
rows := 0
for q.NextResult() {
rr := q.ResultReader()
for rr.NextRow() {
rows++
}
}

require.NoError(b, q.Close())
require.Equal(b, rowCount, rows)
}
})

// Disconnect.
err = pgConn.Close(ctx)
require.NoError(b, err)
}
}
2 changes: 1 addition & 1 deletion lib/srv/db/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func TestEventCassandra(t *testing.T) {
}

func withCassandra(name string, opts ...cassandra.TestServerOption) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
cassandraServer, err := cassandra.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/dynamodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func TestAuditDynamoDB(t *testing.T) {
}

func withDynamoDB(name string, opts ...dynamodb.TestServerOption) withDatabaseOption {
return func(t *testing.T, _ context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, _ context.Context, testCtx *testContext) types.Database {
config := common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/elasticsearch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func TestAuditElasticsearch(t *testing.T) {
}

func withElasticsearch(name string, opts ...elasticsearch.TestServerOption) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
ElasticsearchServer, err := elasticsearch.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/opensearch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func TestAuditOpenSearch(t *testing.T) {
}

func withOpenSearch(name string, opts ...opensearch.TestServerOption) withDatabaseOption {
return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database {
return func(t testing.TB, ctx context.Context, testCtx *testContext) types.Database {
OpenSearchServer, err := opensearch.NewTestServer(common.TestServerConfig{
Name: name,
AuthClient: testCtx.authClient,
Expand Down
Loading