From 6146d6bb43070777d410f50f57b49e53df05d31b Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Mon, 15 Jul 2024 20:09:37 -0300 Subject: [PATCH 1/6] refactor(postgres): use pg_temp to store auto user procedures --- lib/srv/db/postgres/sql/activate-user.sql | 4 +- lib/srv/db/postgres/sql/deactivate-user.sql | 2 +- lib/srv/db/postgres/sql/delete-user.sql | 4 +- lib/srv/db/postgres/test.go | 134 +++++++++++--- lib/srv/db/postgres/users.go | 192 +++++++++++++------- 5 files changed, 240 insertions(+), 96 deletions(-) diff --git a/lib/srv/db/postgres/sql/activate-user.sql b/lib/srv/db/postgres/sql/activate-user.sql index 43ff2c282ebfb..8581af28dc7b4 100644 --- a/lib/srv/db/postgres/sql/activate-user.sql +++ b/lib/srv/db/postgres/sql/activate-user.sql @@ -1,4 +1,4 @@ -CREATE OR REPLACE PROCEDURE teleport_activate_user(username varchar, roles varchar[]) +CREATE OR REPLACE PROCEDURE pg_temp.teleport_activate_user(username varchar, roles varchar[]) LANGUAGE plpgsql AS $$ DECLARE @@ -25,7 +25,7 @@ BEGIN -- Otherwise reactivate the user, but first strip if of all roles to -- account for scenarios with left-over roles if database agent crashed -- and failed to cleanup upon session termination. - CALL teleport_deactivate_user(username); + CALL pg_temp.teleport_deactivate_user(username); EXECUTE FORMAT('ALTER USER %I WITH LOGIN', username); ELSE EXECUTE FORMAT('CREATE USER %I IN ROLE "teleport-auto-user"', username); diff --git a/lib/srv/db/postgres/sql/deactivate-user.sql b/lib/srv/db/postgres/sql/deactivate-user.sql index f344f21508929..f0119539c5e8a 100644 --- a/lib/srv/db/postgres/sql/deactivate-user.sql +++ b/lib/srv/db/postgres/sql/deactivate-user.sql @@ -1,4 +1,4 @@ -CREATE OR REPLACE PROCEDURE teleport_deactivate_user(username varchar) +CREATE OR REPLACE PROCEDURE pg_temp.teleport_deactivate_user(username varchar) LANGUAGE plpgsql AS $$ DECLARE diff --git a/lib/srv/db/postgres/sql/delete-user.sql b/lib/srv/db/postgres/sql/delete-user.sql index ff5a63a06b259..c2f3ea59d18cd 100644 --- a/lib/srv/db/postgres/sql/delete-user.sql +++ b/lib/srv/db/postgres/sql/delete-user.sql @@ -1,4 +1,4 @@ -CREATE OR REPLACE PROCEDURE teleport_delete_user(username varchar, inout state varchar default 'TP003') +CREATE OR REPLACE PROCEDURE pg_temp.teleport_delete_user(username varchar, inout state varchar default 'TP003') LANGUAGE plpgsql AS $$ DECLARE @@ -15,7 +15,7 @@ BEGIN state := 'TP004'; -- Drop user/role will fail if user has dependent objects. -- In this scenario, fallback into disabling the user. - CALL teleport_deactivate_user(username); + CALL pg_temp.teleport_deactivate_user(username); END; END IF; END;$$; diff --git a/lib/srv/db/postgres/test.go b/lib/srv/db/postgres/test.go index 04e4c79e0149f..5b46a7731efe2 100644 --- a/lib/srv/db/postgres/test.go +++ b/lib/srv/db/postgres/test.go @@ -83,7 +83,7 @@ type TestServer struct { // parametersCh receives startup message connection parameters. parametersCh chan map[string]string // storedProcedures are the stored procedures created on the server. - storedProcedures map[string]string + storedProcedures map[string]*storedProcedure // userEventsCh receives user activate/deactivate events. userEventsCh chan UserEvent // userPermissionEventsCh receives user permission change events. @@ -133,6 +133,12 @@ type UserPermissionEvent struct { Permissions Permissions } +// storedProcedure represents a stored procedure. +type storedProcedure struct { + query string + argsCount int +} + // NewTestServer returns a new instance of a test Postgres server. func NewTestServer(config common.TestServerConfig) (svr *TestServer, err error) { err = config.CheckAndSetDefaults() @@ -166,7 +172,7 @@ func NewTestServer(config common.TestServerConfig) (svr *TestServer, err error) }), parametersCh: make(chan map[string]string, 100), pids: make(map[uint32]*pidHandle), - storedProcedures: make(map[string]string), + storedProcedures: make(map[string]*storedProcedure), userEventsCh: make(chan UserEvent, 100), userPermissionEventsCh: make(chan UserPermissionEvent, 100), allowedUsers: &allowedUsers, @@ -298,23 +304,35 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend, startupMessage *pgp // Following messages are for handling Postgres extended query // protocol flow used by prepared statements. case *pgproto3.Parse: - switch msg.Query { - case activateQuery: - if err := s.handleActivateUser(client); err != nil { - s.log.WithError(err).Error("Failed to handle user activation.") - } - case deleteQuery: - if err := s.handleDeactivateUser(client, true); err != nil { - s.log.WithError(err).Error("Failed to handle user deletion.") + schema, procName, argsCount, ok := processProcedureCall(msg.Query) + if ok { + if !s.hasProcedure(pid, schema, procName, argsCount) { + return trace.BadParameter("procedure %q on schema %q wasn't created before the call for PID %d", procName, schema, pid) } - case deactivateQuery: - if err := s.handleDeactivateUser(client, false); err != nil { - s.log.WithError(err).Error("Failed to handle user deactivation.") - } - case updatePermissionsQuery: - if err := s.handleUpdatePermissions(client); err != nil { - s.log.WithError(err).Error("Failed to handle user permissions update.") + + switch procName { + case activateProcName: + if err := s.handleActivateUser(client); err != nil { + s.log.WithError(err).Error("Failed to handle user activation.") + } + case deleteProcName: + if err := s.handleDeactivateUser(client, true); err != nil { + s.log.WithError(err).Error("Failed to handle user deletion.") + } + case deactivateProcName: + if err := s.handleDeactivateUser(client, false); err != nil { + s.log.WithError(err).Error("Failed to handle user deactivation.") + } + case updatePermissionsProcName: + if err := s.handleUpdatePermissions(client); err != nil { + s.log.WithError(err).Error("Failed to handle user permissions update.") + } } + + continue + } + + switch msg.Query { case schemaInfoQuery: if err := s.handleSchemaInfo(client); err != nil { s.log.WithError(err).Error("Failed to handle schema info query.") @@ -378,7 +396,7 @@ func (s *TestServer) handleQuery(client *pgproto3.Backend, query string, pid uin return trace.Wrap(s.fakeLongRunningQuery(client, pid)) } if strings.Contains(strings.ToUpper(query), "CREATE OR REPLACE PROCEDURE") { - if err := s.handleCreateStoredProcedure(query); err != nil { + if err := s.handleCreateStoredProcedure(query, pid); err != nil { return trace.Wrap(err) } } @@ -420,25 +438,65 @@ func (s *TestServer) handleQueryWithError(client *pgproto3.Backend) error { return nil } -func (s *TestServer) handleCreateStoredProcedure(query string) error { +func (s *TestServer) handleCreateStoredProcedure(query string, pid uint32) error { match := storedProcedureRe.FindStringSubmatch(query) - if len(match) != 2 { + if match == nil { return trace.BadParameter("failed to extract stored procedure name from query") } - if _, ok := ephemeralProcs[match[1]]; !ok { - if _, ok := procs[match[1]]; !ok { - return trace.BadParameter("test server doesn't support stored procedure %q", match[1]) + if _, ok := procs[match[storedProcedureRe.SubexpIndex("ProcName")]]; !ok { + return trace.BadParameter("test server doesn't support stored procedure %q", match[1]) + } + + procName := storedProcedureName(pid, match[storedProcedureRe.SubexpIndex("Schema")], match[storedProcedureRe.SubexpIndex("ProcName")]) + var argsCount int + args := strings.Split(match[storedProcedureRe.SubexpIndex("Args")], ",") + for _, arg := range args { + // Skip arguments that have a default value. + if !strings.Contains(strings.ToLower(arg), "default") { + argsCount++ } } - s.log.Debugf("Created stored procedure %q.", match[1]) + s.log.Debugf("Created stored procedure %q.", procName) s.mu.Lock() defer s.mu.Unlock() - s.storedProcedures[match[1]] = query + s.storedProcedures[procName] = &storedProcedure{query: query, argsCount: argsCount} return nil } +func (s *TestServer) hasProcedure(pid uint32, schema, procName string, argsCount int) bool { + s.mu.Lock() + defer s.mu.Unlock() + + storedProcedure, ok := s.storedProcedures[storedProcedureName(pid, schema, procName)] + if !ok { + s.log.Errorf("Procedure %q not found on schema %q", procName, schema) + return false + } + + if argsCount != storedProcedure.argsCount { + s.log.Errorf("Wrong number of arguments for procedure %q call. Expected %d but got %d", procName, storedProcedure.argsCount, argsCount) + return false + } + + return true +} + +func storedProcedureName(pid uint32, schema, procName string) string { + var name string + switch strings.ToLower(schema) { + case "pg_temp": + name = fmt.Sprintf("%d.%s", pid, procName) + case "": + name = procName + default: + name = fmt.Sprintf("%s.%s", schema, procName) + } + + return strings.ToLower(name) +} + // multiMessage wraps *pgproto3.DataRow and implements pgproto3.BackendMessage by writing multiple copies of this message in Encode. type multiMessage struct { singleMessage *pgproto3.DataRow @@ -1071,7 +1129,31 @@ const userParameterName = "user" // storedProcedureRe is the regex for capturing stored procedure name from its // creation query. -var storedProcedureRe = regexp.MustCompile(`(?i)create or replace procedure (.+)\(`) +var storedProcedureRe = regexp.MustCompile(`(?i)create or replace procedure (?:(?P\w+)\.)?(?P.+)\((?P.+)?\)`) // selectBenchmarkRe is the regex for capturing the parameters from the select query used for read benchmark. var selectBenchmarkRe = regexp.MustCompile(`SELECT \* FROM bench\_(\d+) LIMIT (\d+)`) + +// callProcedureRe is the regex for caputuring the schema name, and procedure +// name from the procedure call query. +// Examples: +// - call pg_temp.hello($1) +// - call pg_temp.hello1() +// - call pg_temp.hello2($1, $2, $3) +// - call hello3($1::jsonb) +var callProcedureRe = regexp.MustCompile(`(?i)^call (?:(?P\w+)\.)?(?P\w+)\((?P.+)?\)`) + +// processProcedureCall parses a query and returns the information about the +// the procedure call. +func processProcedureCall(query string) (schema string, procName string, argsCount int, ok bool) { + procMatches := callProcedureRe.FindStringSubmatch(query) + if procMatches == nil { + return + } + + ok = true + schema = procMatches[callProcedureRe.SubexpIndex("Schema")] + procName = procMatches[callProcedureRe.SubexpIndex("ProcName")] + argsCount = len(strings.Split(procMatches[callProcedureRe.SubexpIndex("Args")], ",")) + return +} diff --git a/lib/srv/db/postgres/users.go b/lib/srv/db/postgres/users.go index fdae1d1299823..88ac65d8213f9 100644 --- a/lib/srv/db/postgres/users.go +++ b/lib/srv/db/postgres/users.go @@ -77,12 +77,16 @@ func (e *Engine) ActivateUser(ctx context.Context, sessionCtx *common.Session) e // bookkeeping group or stored procedures get deleted or changed offband. logger := e.Log.With("user", sessionCtx.DatabaseUser) err = withRetry(ctx, logger, func() error { - return trace.Wrap(e.initAutoUsers(ctx, sessionCtx, conn)) + return trace.Wrap(e.updateAutoUsersRole(ctx, conn)) }) if err != nil { return trace.Wrap(err) } + if err := e.createProcedures(ctx, sessionCtx, conn, []string{activateProcName, deactivateProcName}); err != nil { + return trace.Wrap(err) + } + roles, err := prepareRoles(sessionCtx) if err != nil { return trace.Wrap(err) @@ -90,8 +94,7 @@ func (e *Engine) ActivateUser(ctx context.Context, sessionCtx *common.Session) e logger.InfoContext(ctx, "Activating PostgreSQL user", "roles", roles) err = withRetry(ctx, logger, func() error { - _, err = conn.Exec(ctx, activateQuery, sessionCtx.DatabaseUser, roles) - return trace.Wrap(err) + return trace.Wrap(e.callProcedure(ctx, sessionCtx, conn, activateProcName, sessionCtx.DatabaseUser, roles)) }) if err != nil { logger.DebugContext(ctx, "Call teleport_activate_user failed.", "error", err) @@ -101,13 +104,6 @@ func (e *Engine) ActivateUser(ctx context.Context, sessionCtx *common.Session) e } e.Audit.OnDatabaseUserCreate(ctx, sessionCtx, nil) - if err != nil { - if strings.Contains(err.Error(), "already exists") { - return trace.AlreadyExists("user %q already exists in this PostgreSQL database and is not managed by Teleport", sessionCtx.DatabaseUser) - } - return trace.Wrap(err) - } - err = e.applyPermissions(ctx, sessionCtx) if err != nil { logger.WarnContext(e.Context, "Failed to apply permissions.", "error", err) @@ -241,22 +237,11 @@ func (e *Engine) applyPermissions(ctx context.Context, sessionCtx *common.Sessio return trace.Wrap(err) } - // teleport_remove_permissions and teleport_update_permissions are created in pg_temp table of the session database. - // teleport_remove_permissions gets called by teleport_update_permissions as needed. - _, err = conn.Exec(ctx, removePermissionsProc) - if err != nil { - e.Log.ErrorContext(e.Context, "Creating temporary stored procedure failed.", "procedure", removePermissionsProcName, "error", err) - return trace.Wrap(err) - } - - _, err = conn.Exec(ctx, updatePermissionsProc) - if err != nil { - e.Log.ErrorContext(e.Context, "Creating temporary stored procedure failed.", "procedure", updatePermissionsProcName, "error", err) + if err := e.createProcedures(ctx, sessionCtx, conn, []string{removePermissionsProcName, updatePermissionsProcName}); err != nil { return trace.Wrap(err) } - _, err = conn.Exec(ctx, updatePermissionsQuery, sessionCtx.DatabaseUser, perms) - if err != nil { + if err := e.callProcedure(ctx, sessionCtx, conn, updatePermissionsProcName, sessionCtx.DatabaseUser, perms); err != nil { var pgErr *pq.Error if errors.As(err, &pgErr) { if pgErr.Code == common.SQLStatePermissionsChanged { @@ -285,14 +270,11 @@ func (e *Engine) removePermissions(ctx context.Context, sessionCtx *common.Sessi defer conn.Close(ctx) // teleport_remove_permissions is created in pg_temp table of the session database. - _, err = conn.Exec(ctx, removePermissionsProc) - if err != nil { - logger.ErrorContext(e.Context, "Creating temporary stored procedure failed.", "procedure", removePermissionsProcName, "error", err) + if err := e.createProcedures(ctx, sessionCtx, conn, []string{removePermissionsProcName}); err != nil { return trace.Wrap(err) } - _, err = conn.Exec(ctx, removePermissionsQuery, sessionCtx.DatabaseUser) - if err != nil { + if err := e.callProcedure(ctx, sessionCtx, conn, removePermissionsProcName, sessionCtx.DatabaseUser); err != nil { logger.ErrorContext(ctx, "Removing permissions from user failed.", "error", err) return trace.Wrap(err) } @@ -314,11 +296,14 @@ func (e *Engine) DeactivateUser(ctx context.Context, sessionCtx *common.Session) } defer conn.Close(ctx) + if err := e.createProcedures(ctx, sessionCtx, conn, []string{deactivateProcName}); err != nil { + return trace.Wrap(err) + } + logger := e.Log.With("user", sessionCtx.DatabaseUser) logger.InfoContext(ctx, "Deactivating PostgreSQL user.") err = withRetry(ctx, logger, func() error { - _, err = conn.Exec(ctx, deactivateQuery, sessionCtx.DatabaseUser) - return trace.Wrap(err) + return trace.Wrap(e.callProcedure(ctx, sessionCtx, conn, deactivateProcName, sessionCtx.DatabaseUser)) }) if err != nil { e.Audit.OnDatabaseUserDeactivate(ctx, sessionCtx, false, err) @@ -344,6 +329,10 @@ func (e *Engine) DeleteUser(ctx context.Context, sessionCtx *common.Session) err } defer conn.Close(ctx) + if err := e.createProcedures(ctx, sessionCtx, conn, []string{deleteProcName, deactivateProcName}); err != nil { + return trace.Wrap(err) + } + logger := e.Log.With("user", sessionCtx.DatabaseUser) logger.InfoContext(ctx, "Deleting PostgreSQL user.") @@ -353,6 +342,10 @@ func (e *Engine) DeleteUser(ctx context.Context, sessionCtx *common.Session) err case sessionCtx.Database.IsRedshift(): return trace.Wrap(e.deleteUserRedshift(ctx, sessionCtx, conn, &state)) default: + deleteQuery, err := buildCallQuery(sessionCtx, deleteProcName) + if err != nil { + return trace.Wrap(err) + } return trace.Wrap(conn.QueryRow(ctx, deleteQuery, sessionCtx.DatabaseUser).Scan(&state)) } }) @@ -382,7 +375,7 @@ func (e *Engine) DeleteUser(ctx context.Context, sessionCtx *common.Session) err // into the returned error instead of doing this on state returned (like regular // PostgreSQL). func (e *Engine) deleteUserRedshift(ctx context.Context, sessionCtx *common.Session, conn *pgx.Conn, state *string) error { - _, err := conn.Exec(ctx, deleteQuery, sessionCtx.DatabaseUser) + err := e.callProcedure(ctx, sessionCtx, conn, deleteProcName, sessionCtx.DatabaseUser) if err == nil { *state = common.SQLStateUserDropped return nil @@ -399,10 +392,9 @@ func (e *Engine) deleteUserRedshift(ctx context.Context, sessionCtx *common.Sess return trace.Wrap(err) } -// initAutoUsers installs procedures for activating and deactivating users and -// creates the bookkeeping role for auto-provisioned users. -func (e *Engine) initAutoUsers(ctx context.Context, sessionCtx *common.Session, conn *pgx.Conn) error { - // Create a role/group which all auto-created users will be a part of. +// updateAutoUsersRole ensures the bookkeeping role for auto-provisioned users +// is present. +func (e *Engine) updateAutoUsersRole(ctx context.Context, conn *pgx.Conn) error { _, err := conn.Exec(ctx, fmt.Sprintf("create role %q", teleportAutoUserRole)) if err != nil { if !strings.Contains(err.Error(), "already exists") { @@ -413,14 +405,6 @@ func (e *Engine) initAutoUsers(ctx context.Context, sessionCtx *common.Session, e.Log.DebugContext(ctx, "Created PostgreSQL role.", "role", teleportAutoUserRole) } - // Install stored procedures for creating and disabling database users. - for name, sql := range pickProcedures(sessionCtx) { - _, err := conn.Exec(ctx, sql) - if err != nil { - return trace.Wrap(err) - } - e.Log.DebugContext(ctx, "Installed PostgreSQL stored procedure.", "procedure", name) - } return nil } @@ -439,6 +423,84 @@ func (e *Engine) pgxConnect(ctx context.Context, sessionCtx *common.Session) (*p return pgx.ConnectConfig(ctx, pgxConf) } +// callProcedure calls the procedure with the provided arguments. +func (e *Engine) callProcedure(ctx context.Context, sessionCtx *common.Session, conn *pgx.Conn, procName string, args ...any) error { + query, err := buildCallQuery(sessionCtx, procName) + if err != nil { + return trace.Wrap(err) + } + + _, err = conn.Exec(ctx, query, args...) + return trace.Wrap(err) +} + +// createProcedures executes the create procedures for the provided list of +// procedures. +func (e *Engine) createProcedures(ctx context.Context, sessionCtx *common.Session, conn *pgx.Conn, procNames []string) error { + selectedProcs := pickProcedures(sessionCtx) + + for _, procName := range procNames { + proc, ok := selectedProcs[procName] + if !ok { + return trace.NotImplemented("procedure %q is not available for %s databases", procName, sessionCtx.Database.GetType()) + } + + logger := e.Log.With("procedure", procName) + err := withRetry(ctx, logger, func() error { + _, err := conn.Exec(ctx, proc) + return trace.Wrap(err) + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to install procedure.") + return trace.Wrap(err) + } + + logger.DebugContext(ctx, "Installed procedure.") + } + + return nil +} + +// buildCallQuery builds the call query based on the procedure name and session. +func buildCallQuery(sessionCtx *common.Session, procName string) (string, error) { + if _, ok := pickProcedures(sessionCtx)[procName]; !ok { + return "", trace.NotImplemented("procedure %q is not available for %s databases", procName, sessionCtx.Database.GetType()) + } + + var schema string + switch { + case sessionCtx.Database.IsRedshift(): + // TODO(gabrielcorado): support customizing the schema the procedures + // will be stored on RedShift. For now, let the database decide where + // to store them. + schema = "" + default: + // Always use `pg_temp` if the database type supports it. This reduces + // the number of permissions required by the admin user. + schema = "pg_temp" + } + + var procCall string + switch procName { + case activateProcName: + procCall = activateProcCall + case deactivateProcName: + procCall = deactivateProcCall + case deleteProcName: + procCall = deleteProcCall + case updatePermissionsProcName: + procCall = updatePermissionsProcCall + case removePermissionsProcName: + procCall = removePermissionsProcCall + } + + if schema != "" { + return fmt.Sprintf("call %s.%s", schema, procCall), nil + } + + return "call " + procCall, nil +} + func prepareRoles(sessionCtx *common.Session) (any, error) { switch sessionCtx.Database.GetType() { case types.DatabaseTypeRDS: @@ -489,10 +551,10 @@ const ( deleteProcName = "teleport_delete_user" // updatePermissionsProcName is the name of the stored procedure Teleport will use // to automatically update database permissions. - updatePermissionsProcName = "pg_temp.teleport_update_permissions" + updatePermissionsProcName = "teleport_update_permissions" // removePermissionsProcName is the name of the stored procedure Teleport will use // to automatically remove all database permissions. - removePermissionsProcName = "pg_temp.teleport_remove_permissions" + removePermissionsProcName = "teleport_remove_permissions" // teleportAutoUserRole is the name of a PostgreSQL role that all Teleport // managed users will be a part of. teleportAutoUserRole = "teleport-auto-user" @@ -501,18 +563,21 @@ const ( var ( //go:embed sql/activate-user.sql activateProc string - // activateQuery is the query for calling user activation procedure. - activateQuery = fmt.Sprintf(`call %v($1, $2)`, activateProcName) + // activateProcCall contains the procedure name and arguments used to call + // the activate user procedure. + activateProcCall = fmt.Sprintf(`%v($1, $2)`, activateProcName) //go:embed sql/deactivate-user.sql deactivateProc string - // deactivateQuery is the query for calling user deactivation procedure. - deactivateQuery = fmt.Sprintf(`call %v($1)`, deactivateProcName) + // deactivateProcCall contains the procedure name and arguments used to call + // the deactivate user procedure. + deactivateProcCall = fmt.Sprintf(`%v($1)`, deactivateProcName) //go:embed sql/delete-user.sql deleteProc string - // deleteQuery is the query for calling user deletion procedure. - deleteQuery = fmt.Sprintf(`call %v($1)`, deleteProcName) + // deleteProcCall contains the procedure name and arguments used to call + // the delete user procedure. + deleteProcCall = fmt.Sprintf(`%v($1)`, deleteProcName) //go:embed sql/redshift-activate-user.sql redshiftActivateProc string @@ -523,20 +588,22 @@ var ( //go:embed sql/update-permissions.sql updatePermissionsProc string - // updatePermissionsQuery is the query for calling update permissions procedure. - // the procedure is created on demand in the pg_temp table in the session database. - updatePermissionsQuery = fmt.Sprintf(`call %v($1, $2::jsonb)`, updatePermissionsProcName) + // updatePermissionsProcCall contains the procedure name and arguments used + // to call the update permissions procedure. + updatePermissionsProcCall = fmt.Sprintf(`%v($1, $2::jsonb)`, updatePermissionsProcName) //go:embed sql/remove-permissions.sql removePermissionsProc string - // removePermissionsQuery is the query for calling update permissions procedure. - // the procedure is created on demand in the pg_temp table in the session database. - removePermissionsQuery = fmt.Sprintf(`call %v($1)`, removePermissionsProcName) + // removePermissionsProcCall contains the procedure name and arguments used + // to call the remove permissions procedure. + removePermissionsProcCall = fmt.Sprintf(`%v($1)`, removePermissionsProcName) procs = map[string]string{ - activateProcName: activateProc, - deactivateProcName: deactivateProc, - deleteProcName: deleteProc, + activateProcName: activateProc, + deactivateProcName: deactivateProc, + deleteProcName: deleteProc, + updatePermissionsProcName: updatePermissionsProc, + removePermissionsProcName: removePermissionsProc, } redshiftProcs = map[string]string{ @@ -544,11 +611,6 @@ var ( deactivateProcName: redshiftDeactivateProc, deleteProcName: redshiftDeleteProc, } - - ephemeralProcs = map[string]string{ - updatePermissionsProcName: updatePermissionsProc, - removePermissionsProcName: removePermissionsProc, - } ) // withRetry is a helper for auto user operations that runs a given func a From 950731c3655e7fe75c38ed66c6cc716b9c9f5da1 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Thu, 18 Jul 2024 12:16:42 -0300 Subject: [PATCH 2/6] refactor(postgres): code review suggestions --- lib/srv/db/postgres/test.go | 5 +++++ lib/srv/db/postgres/users.go | 7 ++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/srv/db/postgres/test.go b/lib/srv/db/postgres/test.go index 5b46a7731efe2..5a43aa294d556 100644 --- a/lib/srv/db/postgres/test.go +++ b/lib/srv/db/postgres/test.go @@ -1145,6 +1145,11 @@ var callProcedureRe = regexp.MustCompile(`(?i)^call (?:(?P\w+)\.)?(?P Date: Thu, 18 Jul 2024 12:31:32 -0300 Subject: [PATCH 3/6] test(postgres): replace logrus with slog --- lib/srv/db/postgres/test.go | 69 ++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/lib/srv/db/postgres/test.go b/lib/srv/db/postgres/test.go index 5a43aa294d556..3adec81931ed2 100644 --- a/lib/srv/db/postgres/test.go +++ b/lib/srv/db/postgres/test.go @@ -24,6 +24,7 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "log/slog" "net" "regexp" "strconv" @@ -36,7 +37,6 @@ import ( "github.com/jackc/pgerrcode" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" @@ -77,7 +77,7 @@ type TestServer struct { listener net.Listener port string tlsConfig *tls.Config - log logrus.FieldLogger + log *slog.Logger // queryCount keeps track of the number of queries the server has received. queryCount uint32 // parametersCh receives startup message connection parameters. @@ -166,10 +166,10 @@ func NewTestServer(config common.TestServerConfig) (svr *TestServer, err error) listener: config.Listener, port: port, tlsConfig: tlsConfig, - log: logrus.WithFields(logrus.Fields{ - teleport.ComponentKey: defaults.ProtocolPostgres, - "name": config.Name, - }), + log: slog.Default().With( + teleport.ComponentKey, defaults.ProtocolPostgres, + "name", config.Name, + ), parametersCh: make(chan map[string]string, 100), pids: make(map[uint32]*pidHandle), storedProcedures: make(map[string]*storedProcedure), @@ -182,7 +182,7 @@ func NewTestServer(config common.TestServerConfig) (svr *TestServer, err error) // Serve starts serving client connections. func (s *TestServer) Serve() error { - s.log.Debugf("Starting test Postgres server on %v.", s.listener.Addr()) + s.log.Debug("Starting test Postgres server.", "address", s.listener.Addr()) defer s.log.Debug("Test Postgres server stopped.") for { conn, err := s.listener.Accept() @@ -190,7 +190,7 @@ func (s *TestServer) Serve() error { if utils.IsOKNetworkError(err) { return nil } - s.log.WithError(err).Error("Failed to accept connection.") + s.log.Error("Failed to accept connection.", "error", err) continue } s.log.Debug("Accepted connection.") @@ -199,8 +199,7 @@ func (s *TestServer) Serve() error { defer conn.Close() err = s.handleConnection(conn) if err != nil { - s.log.Errorf("Failed to handle connection: %v.", - trace.DebugReport(err)) + s.log.Error("Failed to handle connection.", "debug_report", trace.DebugReport(err)) } }() } @@ -216,7 +215,7 @@ func (s *TestServer) handleConnection(conn net.Conn) error { if err != nil { return trace.Wrap(err) } - s.log.Debugf("Received %#v.", startupMessage) + s.log.Debug("Received.", "message", fmt.Sprintf("%#v", startupMessage)) switch msg := startupMessage.(type) { case *pgproto3.StartupMessage: return s.handleStartup(client, msg) @@ -238,7 +237,7 @@ func (s *TestServer) startTLS(conn net.Conn) (*pgproto3.Backend, error) { if _, ok := startupMessage.(*pgproto3.SSLRequest); !ok { return nil, trace.BadParameter("expected *pgproto3.SSLRequest, got: %#v", startupMessage) } - s.log.Debugf("Received %#v.", startupMessage) + s.log.Debug("Received.", "message", fmt.Sprintf("%#v", startupMessage)) // Reply with 'S' to indicate TLS support. if _, err := conn.Write([]byte("S")); err != nil { return nil, trace.Wrap(err) @@ -299,7 +298,7 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend, startupMessage *pgp switch msg := message.(type) { case *pgproto3.Query: if err := s.handleQuery(client, msg.String, pid); err != nil { - s.log.WithError(err).Error("Failed to handle query.") + s.log.Error("Failed to handle query.", "error", err) } // Following messages are for handling Postgres extended query // protocol flow used by prepared statements. @@ -313,19 +312,19 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend, startupMessage *pgp switch procName { case activateProcName: if err := s.handleActivateUser(client); err != nil { - s.log.WithError(err).Error("Failed to handle user activation.") + s.log.Error("Failed to handle user activation.", "error", err) } case deleteProcName: if err := s.handleDeactivateUser(client, true); err != nil { - s.log.WithError(err).Error("Failed to handle user deletion.") + s.log.Error("Failed to handle user deletion.", "error", err) } case deactivateProcName: if err := s.handleDeactivateUser(client, false); err != nil { - s.log.WithError(err).Error("Failed to handle user deactivation.") + s.log.Error("Failed to handle user deactivation.", "error", err) } case updatePermissionsProcName: if err := s.handleUpdatePermissions(client); err != nil { - s.log.WithError(err).Error("Failed to handle user permissions update.") + s.log.Error("Failed to handle user permissions update.", "error", err) } } @@ -335,21 +334,21 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend, startupMessage *pgp switch msg.Query { case schemaInfoQuery: if err := s.handleSchemaInfo(client); err != nil { - s.log.WithError(err).Error("Failed to handle schema info query.") + s.log.Error("Failed to handle schema info query.", "error", err) } default: - s.log.Warnf("Ignoring PARSE message for query %q", msg.Query) + s.log.Warn("Ignoring PARSE message", "query", msg.Query) } case *pgproto3.Bind: case *pgproto3.Describe: case *pgproto3.Sync: if err := s.handleSync(client); err != nil { - s.log.WithError(err).Error("Failed to handle sync.") + s.log.Error("Failed to handle sync.", "error", err) } case *pgproto3.Execute: // Execute executes prepared statement. if err := s.handleQuery(client, "", pid); err != nil { - s.log.WithError(err).Error("Failed to handle query.") + s.log.Error("Failed to handle query.", "error", err) } case *pgproto3.Terminate: return nil @@ -414,7 +413,7 @@ func (s *TestServer) handleQuery(client *pgproto3.Backend, query string, pid uin &pgproto3.ReadyForQuery{}, } for _, message := range messages { - s.log.Debugf("Sending %#v.", message) + s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) @@ -428,7 +427,7 @@ func (s *TestServer) handleQueryWithError(client *pgproto3.Backend) error { &pgproto3.ErrorResponse{Severity: "ERROR", Code: "42703", Message: "error"}, &pgproto3.ReadyForQuery{}, } { - s.log.Debugf("Sending %#v.", message) + s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) @@ -458,7 +457,7 @@ func (s *TestServer) handleCreateStoredProcedure(query string, pid uint32) error } } - s.log.Debugf("Created stored procedure %q.", procName) + s.log.Debug("Created stored procedure.", "procedure", procName) s.mu.Lock() defer s.mu.Unlock() s.storedProcedures[procName] = &storedProcedure{query: query, argsCount: argsCount} @@ -471,12 +470,12 @@ func (s *TestServer) hasProcedure(pid uint32, schema, procName string, argsCount storedProcedure, ok := s.storedProcedures[storedProcedureName(pid, schema, procName)] if !ok { - s.log.Errorf("Procedure %q not found on schema %q", procName, schema) + s.log.Error("Procedure not found", "procedure", procName, "schema", schema) return false } if argsCount != storedProcedure.argsCount { - s.log.Errorf("Wrong number of arguments for procedure %q call. Expected %d but got %d", procName, storedProcedure.argsCount, argsCount) + s.log.Error("Wrong number of arguments for procedure call", "procedure", procName, "expected_args", storedProcedure.argsCount, "args_provided", argsCount) return false } @@ -571,7 +570,7 @@ func (s *TestServer) handleBenchmarkQuery(query string, client *pgproto3.Backend return trace.Wrap(err) } - s.log.Debugf("Responding to query %q, will send %v messages, total length %v", query, repeats, len(mm.payload)) + s.log.Debug("Responding to query", query, "repeat", repeats, "length", len(mm.payload)) // preamble err = client.Send(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("dummy")}}}) @@ -596,7 +595,7 @@ func (s *TestServer) handleBenchmarkQuery(query string, client *pgproto3.Backend return trace.Wrap(err) } - s.log.Debugf("Finished handling query %q", query) + s.log.Debug("Finished handling query", "query", query) return nil } @@ -661,7 +660,7 @@ func (s *TestServer) handleActivateUser(client *pgproto3.Backend) error { return trace.Wrap(err) } // Mark the user as active. - s.log.Debugf("Activated user %q with roles %v.", name, roles) + s.log.Debug("Activated user.", "user", name, "roles", roles) s.userEventsCh <- UserEvent{Name: name, Roles: roles, Active: true} s.allowedUsers.Store(name, struct{}{}) return nil @@ -734,7 +733,7 @@ func (s *TestServer) handleDeactivateUser(client *pgproto3.Backend, sendDeleteRe return trace.Wrap(err) } // Mark the user as active. - s.log.Debugf("Deactivated user %q.", name) + s.log.Debug("Deactivated user.", "user", name) s.userEventsCh <- UserEvent{Name: name, Active: false} s.allowedUsers.Delete(name) return nil @@ -800,7 +799,7 @@ func (s *TestServer) handleUpdatePermissions(client *pgproto3.Backend) error { return trace.Wrap(err) } // Mark the user as active. - s.log.Debugf("Updated permissions for user %q with permissions %#v.", name, perms) + s.log.Debug("Updated permissions for user.", "user", name, "permissions", fmt.Sprintf("%#v", perms)) s.userPermissionEventsCh <- UserPermissionEvent{Name: name, Permissions: perms} return nil } @@ -932,7 +931,7 @@ func (s *TestServer) receiveFrontendMessage(client *pgproto3.Backend) (pgproto3. if err != nil { return nil, trace.Wrap(err) } - s.log.Debugf("Received %#v.", message) + s.log.Debug("Received.", "message", fmt.Sprintf("%#v", message)) return message, nil } @@ -978,7 +977,7 @@ func getJSONB[T any](formatCode int16, src []byte) (T, error) { func (s *TestServer) sendMessages(client *pgproto3.Backend, messages ...pgproto3.BackendMessage) error { for _, message := range messages { - s.log.Debugf("Sending %#v.", message) + s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) @@ -1002,7 +1001,7 @@ func (s *TestServer) fakeLongRunningQuery(client *pgproto3.Backend, pid uint32) &pgproto3.ReadyForQuery{}, } for _, message := range messages { - s.log.Debugf("Sending %#v.", message) + s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) @@ -1013,7 +1012,7 @@ func (s *TestServer) fakeLongRunningQuery(client *pgproto3.Backend, pid uint32) func (s *TestServer) handleSync(client *pgproto3.Backend) error { message := &pgproto3.ReadyForQuery{} - s.log.Debugf("Sending %#v.", message) + s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) From 51672652dfd1462e7eb2411ddcb95b0b1389dfb5 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Thu, 18 Jul 2024 12:38:24 -0300 Subject: [PATCH 4/6] chore(postgres): move switch statement to a map --- lib/srv/db/postgres/users.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lib/srv/db/postgres/users.go b/lib/srv/db/postgres/users.go index 33756a462cd0e..355e384f4c695 100644 --- a/lib/srv/db/postgres/users.go +++ b/lib/srv/db/postgres/users.go @@ -477,18 +477,9 @@ func buildCallQuery(sessionCtx *common.Session, procName string) (string, error) schema = "pg_temp" } - var procCall string - switch procName { - case activateProcName: - procCall = activateProcCall - case deactivateProcName: - procCall = deactivateProcCall - case deleteProcName: - procCall = deleteProcCall - case updatePermissionsProcName: - procCall = updatePermissionsProcCall - case removePermissionsProcName: - procCall = removePermissionsProcCall + procCall, ok := procsCall[procName] + if !ok { + return "", trace.BadParameter("procedure %q doesn't have a call statement", procName) } if schema != "" { @@ -608,6 +599,15 @@ var ( deactivateProcName: redshiftDeactivateProc, deleteProcName: redshiftDeleteProc, } + + // procsCall maps procedures names to their call statements. + procsCall = map[string]string{ + activateProcName: activateProcCall, + deactivateProcName: deactivateProcCall, + deleteProcName: deleteProcCall, + updatePermissionsProcName: updatePermissionsProcCall, + removePermissionsProcName: removePermissionsProcCall, + } ) // withRetry is a helper for auto user operations that runs a given func a From 3a32c4cb4e5813630fe185400a4f9f6f5aa1e0ac Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Thu, 18 Jul 2024 14:08:11 -0300 Subject: [PATCH 5/6] test(postgres): with wrong slog call --- lib/srv/db/postgres/test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/srv/db/postgres/test.go b/lib/srv/db/postgres/test.go index 3adec81931ed2..980474c635642 100644 --- a/lib/srv/db/postgres/test.go +++ b/lib/srv/db/postgres/test.go @@ -570,7 +570,7 @@ func (s *TestServer) handleBenchmarkQuery(query string, client *pgproto3.Backend return trace.Wrap(err) } - s.log.Debug("Responding to query", query, "repeat", repeats, "length", len(mm.payload)) + s.log.Debug("Responding to query", "query", query, "repeat", repeats, "length", len(mm.payload)) // preamble err = client.Send(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("dummy")}}}) From 33026412910e747ef53bd135657fc3e954cda807 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Thu, 18 Jul 2024 15:10:35 -0300 Subject: [PATCH 6/6] test(postgres): use slog context functions --- lib/srv/db/postgres/test.go | 62 ++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/lib/srv/db/postgres/test.go b/lib/srv/db/postgres/test.go index 980474c635642..955bb39e2c7a2 100644 --- a/lib/srv/db/postgres/test.go +++ b/lib/srv/db/postgres/test.go @@ -182,24 +182,24 @@ func NewTestServer(config common.TestServerConfig) (svr *TestServer, err error) // Serve starts serving client connections. func (s *TestServer) Serve() error { - s.log.Debug("Starting test Postgres server.", "address", s.listener.Addr()) - defer s.log.Debug("Test Postgres server stopped.") + s.log.DebugContext(context.Background(), "Starting test Postgres server.", "address", s.listener.Addr()) + defer s.log.DebugContext(context.Background(), "Test Postgres server stopped.") for { conn, err := s.listener.Accept() if err != nil { if utils.IsOKNetworkError(err) { return nil } - s.log.Error("Failed to accept connection.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to accept connection.", "error", err) continue } - s.log.Debug("Accepted connection.") + s.log.DebugContext(context.Background(), "Accepted connection.") go func() { - defer s.log.Debug("Connection done.") + defer s.log.DebugContext(context.Background(), "Connection done.") defer conn.Close() err = s.handleConnection(conn) if err != nil { - s.log.Error("Failed to handle connection.", "debug_report", trace.DebugReport(err)) + s.log.ErrorContext(context.Background(), "Failed to handle connection.", "debug_report", trace.DebugReport(err)) } }() } @@ -215,7 +215,7 @@ func (s *TestServer) handleConnection(conn net.Conn) error { if err != nil { return trace.Wrap(err) } - s.log.Debug("Received.", "message", fmt.Sprintf("%#v", startupMessage)) + s.log.DebugContext(context.Background(), "Received.", "message", fmt.Sprintf("%#v", startupMessage)) switch msg := startupMessage.(type) { case *pgproto3.StartupMessage: return s.handleStartup(client, msg) @@ -237,7 +237,7 @@ func (s *TestServer) startTLS(conn net.Conn) (*pgproto3.Backend, error) { if _, ok := startupMessage.(*pgproto3.SSLRequest); !ok { return nil, trace.BadParameter("expected *pgproto3.SSLRequest, got: %#v", startupMessage) } - s.log.Debug("Received.", "message", fmt.Sprintf("%#v", startupMessage)) + s.log.DebugContext(context.Background(), "Received.", "message", fmt.Sprintf("%#v", startupMessage)) // Reply with 'S' to indicate TLS support. if _, err := conn.Write([]byte("S")); err != nil { return nil, trace.Wrap(err) @@ -298,7 +298,7 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend, startupMessage *pgp switch msg := message.(type) { case *pgproto3.Query: if err := s.handleQuery(client, msg.String, pid); err != nil { - s.log.Error("Failed to handle query.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to handle query.", "error", err) } // Following messages are for handling Postgres extended query // protocol flow used by prepared statements. @@ -312,19 +312,19 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend, startupMessage *pgp switch procName { case activateProcName: if err := s.handleActivateUser(client); err != nil { - s.log.Error("Failed to handle user activation.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to handle user activation.", "error", err) } case deleteProcName: if err := s.handleDeactivateUser(client, true); err != nil { - s.log.Error("Failed to handle user deletion.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to handle user deletion.", "error", err) } case deactivateProcName: if err := s.handleDeactivateUser(client, false); err != nil { - s.log.Error("Failed to handle user deactivation.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to handle user deactivation.", "error", err) } case updatePermissionsProcName: if err := s.handleUpdatePermissions(client); err != nil { - s.log.Error("Failed to handle user permissions update.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to handle user permissions update.", "error", err) } } @@ -334,21 +334,21 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend, startupMessage *pgp switch msg.Query { case schemaInfoQuery: if err := s.handleSchemaInfo(client); err != nil { - s.log.Error("Failed to handle schema info query.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to handle schema info query.", "error", err) } default: - s.log.Warn("Ignoring PARSE message", "query", msg.Query) + s.log.WarnContext(context.Background(), "Ignoring PARSE message", "query", msg.Query) } case *pgproto3.Bind: case *pgproto3.Describe: case *pgproto3.Sync: if err := s.handleSync(client); err != nil { - s.log.Error("Failed to handle sync.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to handle sync.", "error", err) } case *pgproto3.Execute: // Execute executes prepared statement. if err := s.handleQuery(client, "", pid); err != nil { - s.log.Error("Failed to handle query.", "error", err) + s.log.ErrorContext(context.Background(), "Failed to handle query.", "error", err) } case *pgproto3.Terminate: return nil @@ -413,7 +413,7 @@ func (s *TestServer) handleQuery(client *pgproto3.Backend, query string, pid uin &pgproto3.ReadyForQuery{}, } for _, message := range messages { - s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) + s.log.DebugContext(context.Background(), "Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) @@ -427,7 +427,7 @@ func (s *TestServer) handleQueryWithError(client *pgproto3.Backend) error { &pgproto3.ErrorResponse{Severity: "ERROR", Code: "42703", Message: "error"}, &pgproto3.ReadyForQuery{}, } { - s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) + s.log.DebugContext(context.Background(), "Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) @@ -457,7 +457,7 @@ func (s *TestServer) handleCreateStoredProcedure(query string, pid uint32) error } } - s.log.Debug("Created stored procedure.", "procedure", procName) + s.log.DebugContext(context.Background(), "Created stored procedure.", "procedure", procName) s.mu.Lock() defer s.mu.Unlock() s.storedProcedures[procName] = &storedProcedure{query: query, argsCount: argsCount} @@ -470,12 +470,12 @@ func (s *TestServer) hasProcedure(pid uint32, schema, procName string, argsCount storedProcedure, ok := s.storedProcedures[storedProcedureName(pid, schema, procName)] if !ok { - s.log.Error("Procedure not found", "procedure", procName, "schema", schema) + s.log.ErrorContext(context.Background(), "Procedure not found", "procedure", procName, "schema", schema) return false } if argsCount != storedProcedure.argsCount { - s.log.Error("Wrong number of arguments for procedure call", "procedure", procName, "expected_args", storedProcedure.argsCount, "args_provided", argsCount) + s.log.ErrorContext(context.Background(), "Wrong number of arguments for procedure call", "procedure", procName, "expected_args", storedProcedure.argsCount, "args_provided", argsCount) return false } @@ -570,7 +570,7 @@ func (s *TestServer) handleBenchmarkQuery(query string, client *pgproto3.Backend return trace.Wrap(err) } - s.log.Debug("Responding to query", "query", query, "repeat", repeats, "length", len(mm.payload)) + s.log.DebugContext(context.Background(), "Responding to query", "query", query, "repeat", repeats, "length", len(mm.payload)) // preamble err = client.Send(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("dummy")}}}) @@ -595,7 +595,7 @@ func (s *TestServer) handleBenchmarkQuery(query string, client *pgproto3.Backend return trace.Wrap(err) } - s.log.Debug("Finished handling query", "query", query) + s.log.DebugContext(context.Background(), "Finished handling query", "query", query) return nil } @@ -660,7 +660,7 @@ func (s *TestServer) handleActivateUser(client *pgproto3.Backend) error { return trace.Wrap(err) } // Mark the user as active. - s.log.Debug("Activated user.", "user", name, "roles", roles) + s.log.DebugContext(context.Background(), "Activated user.", "user", name, "roles", roles) s.userEventsCh <- UserEvent{Name: name, Roles: roles, Active: true} s.allowedUsers.Store(name, struct{}{}) return nil @@ -733,7 +733,7 @@ func (s *TestServer) handleDeactivateUser(client *pgproto3.Backend, sendDeleteRe return trace.Wrap(err) } // Mark the user as active. - s.log.Debug("Deactivated user.", "user", name) + s.log.DebugContext(context.Background(), "Deactivated user.", "user", name) s.userEventsCh <- UserEvent{Name: name, Active: false} s.allowedUsers.Delete(name) return nil @@ -799,7 +799,7 @@ func (s *TestServer) handleUpdatePermissions(client *pgproto3.Backend) error { return trace.Wrap(err) } // Mark the user as active. - s.log.Debug("Updated permissions for user.", "user", name, "permissions", fmt.Sprintf("%#v", perms)) + s.log.DebugContext(context.Background(), "Updated permissions for user.", "user", name, "permissions", fmt.Sprintf("%#v", perms)) s.userPermissionEventsCh <- UserPermissionEvent{Name: name, Permissions: perms} return nil } @@ -931,7 +931,7 @@ func (s *TestServer) receiveFrontendMessage(client *pgproto3.Backend) (pgproto3. if err != nil { return nil, trace.Wrap(err) } - s.log.Debug("Received.", "message", fmt.Sprintf("%#v", message)) + s.log.DebugContext(context.Background(), "Received.", "message", fmt.Sprintf("%#v", message)) return message, nil } @@ -977,7 +977,7 @@ func getJSONB[T any](formatCode int16, src []byte) (T, error) { func (s *TestServer) sendMessages(client *pgproto3.Backend, messages ...pgproto3.BackendMessage) error { for _, message := range messages { - s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) + s.log.DebugContext(context.Background(), "Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) @@ -1001,7 +1001,7 @@ func (s *TestServer) fakeLongRunningQuery(client *pgproto3.Backend, pid uint32) &pgproto3.ReadyForQuery{}, } for _, message := range messages { - s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) + s.log.DebugContext(context.Background(), "Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err) @@ -1012,7 +1012,7 @@ func (s *TestServer) fakeLongRunningQuery(client *pgproto3.Backend, pid uint32) func (s *TestServer) handleSync(client *pgproto3.Backend) error { message := &pgproto3.ReadyForQuery{} - s.log.Debug("Sending.", "message", fmt.Sprintf("%#v", message)) + s.log.DebugContext(context.Background(), "Sending.", "message", fmt.Sprintf("%#v", message)) err := client.Send(message) if err != nil { return trace.Wrap(err)