diff --git a/lib/srv/db/autousers_test.go b/lib/srv/db/autousers_test.go index 8752823c660cf..bc88a19f78f5b 100644 --- a/lib/srv/db/autousers_test.go +++ b/lib/srv/db/autousers_test.go @@ -28,9 +28,11 @@ import ( dbobjectimportrulev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/dbobjectimportrule/v1" "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/label" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth" + libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/srv/db/common/databaseobjectimportrule" "github.com/gravitational/teleport/lib/srv/db/mongodb" "github.com/gravitational/teleport/lib/srv/db/postgres" @@ -214,6 +216,10 @@ func TestAutoUsersPostgres(t *testing.T) { case <-time.After(5 * time.Second): t.Fatal("user not deactivated after 5s") } + + ev := waitForDatabaseUserDeactivateEvent(t, testCtx) + require.Equal(t, "alice", ev.User) + require.Equal(t, "alice", ev.DatabaseUser) }) } } @@ -368,6 +374,10 @@ func TestAutoUsersMySQL(t *testing.T) { case <-time.After(5 * time.Second): t.Fatal("user not deactivated after 5s") } + + ev := waitForDatabaseUserDeactivateEvent(t, testCtx) + require.Equal(t, tc.teleportUser, ev.User) + require.Equal(t, tc.expectDatabaseUser, ev.DatabaseUser) }) } } @@ -451,6 +461,21 @@ func TestAutoUsersMongoDB(t *testing.T) { case <-time.After(5 * time.Second): t.Fatal("user not deactivated after 5s") } + + ev := waitForDatabaseUserDeactivateEvent(t, testCtx) + require.Equal(t, username, ev.User) + require.Equal(t, "alice", ev.DatabaseUser) }) } } + +func waitForDatabaseUserDeactivateEvent(t *testing.T, testCtx *testContext) *apievents.DatabaseUserDeactivate { + t.Helper() + const code = libevents.DatabaseSessionUserDeactivateCode + event := waitForEvent(t, testCtx, code) + require.Equal(t, code, event.GetCode()) + + ev, ok := event.(*apievents.DatabaseUserDeactivate) + require.True(t, ok) + return ev +} diff --git a/lib/srv/db/mysql/autousers.go b/lib/srv/db/mysql/autousers.go index 8460b1a831487..4c2d071d1c2f1 100644 --- a/lib/srv/db/mysql/autousers.go +++ b/lib/srv/db/mysql/autousers.go @@ -520,11 +520,11 @@ func doTransaction(conn *clientConn, do func() error) error { } func readDeleteUserResult(res *mysql.Result) string { - if len(res.Values) != 1 && len(res.Values[0]) != 1 { + if res == nil || res.Resultset == nil || + len(res.Resultset.Values) != 1 || len(res.Resultset.Values[0]) != 1 { return "" } - - return string(res.Values[0][0].AsString()) + return string(res.Resultset.Values[0][0].AsString()) } func getCreateProcedureCommand(conn *clientConn, procedureName string) (string, bool) { diff --git a/lib/srv/db/mysql/test.go b/lib/srv/db/mysql/test.go index 04c1751c07f03..7bf38c09b52cc 100644 --- a/lib/srv/db/mysql/test.go +++ b/lib/srv/db/mysql/test.go @@ -320,7 +320,7 @@ func (h *testHandler) HandleStmtPrepare(prepare string) (int, int, interface{}, return params, 0, nil, nil } func (h *testHandler) HandleStmtExecute(_ interface{}, query string, args []interface{}) (*mysql.Result, error) { - h.log.Debugf("Received execute %q with args %+v.", args) + h.log.Debugf("Received execute %q with args %+v.", query, args) if strings.HasPrefix(query, "CALL ") { return h.handleCallProcedure(query, args) } diff --git a/lib/srv/db/postgres/test.go b/lib/srv/db/postgres/test.go index c04d7d08ee5a6..2658b02901840 100644 --- a/lib/srv/db/postgres/test.go +++ b/lib/srv/db/postgres/test.go @@ -303,8 +303,12 @@ func (s *TestServer) handleStartup(client *pgproto3.Backend, startupMessage *pgp if err := s.handleActivateUser(client); err != nil { s.log.WithError(err).Error("Failed to handle user activation.") } - case deactivateQuery, deleteQuery: - if err := s.handleDeactivateUser(client); err != nil { + case deleteQuery: + if err := s.handleDeactivateUser(client, true); err != nil { + s.log.WithError(err).Error("Failed to handle user deletion.") + } + case deactivateQuery: + if err := s.handleDeactivateUser(client, false); err != nil { s.log.WithError(err).Error("Failed to handle user deactivation.") } case updatePermissionsQuery: @@ -587,7 +591,7 @@ func (s *TestServer) handleActivateUser(client *pgproto3.Backend) error { return nil } -func (s *TestServer) handleDeactivateUser(client *pgproto3.Backend) error { +func (s *TestServer) handleDeactivateUser(client *pgproto3.Backend, sendDeleteResponse bool) error { // Expect Describe message. _, err := s.receiveDescribeMessage(client) if err != nil { @@ -633,11 +637,23 @@ func (s *TestServer) handleDeactivateUser(client *pgproto3.Backend) error { return trace.Wrap(err) } // Respond to Bind message. - err = s.sendMessages(client, + messages := []pgproto3.BackendMessage{ &pgproto3.BindComplete{}, - &pgproto3.NoData{}, + } + if sendDeleteResponse { + messages = append(messages, + &pgproto3.RowDescription{Fields: TestDeleteUserResponse.FieldDescriptions}, + &pgproto3.DataRow{Values: TestDeleteUserResponse.Rows[0]}, + ) + } else { + messages = append(messages, &pgproto3.NoData{}) + } + messages = append(messages, &pgproto3.CommandComplete{}, - &pgproto3.ReadyForQuery{}) + &pgproto3.ReadyForQuery{}, + ) + + err = s.sendMessages(client, messages...) if err != nil { return trace.Wrap(err) } @@ -1012,6 +1028,13 @@ var TestQueryResponse = &pgconn.Result{ CommandTag: pgconn.CommandTag("select 1"), } +// TestDeleteUserResponse is the response test Postgres server sends to every +// query that calls the auto user deletion procedure. +var TestDeleteUserResponse = &pgconn.Result{ + FieldDescriptions: []pgproto3.FieldDescription{{Name: []byte("state")}}, + Rows: [][][]byte{{[]byte("TP003")}}, +} + // TestLongRunningQuery is a stub SQL query clients can use to simulate a long // running query that can be only be stopped by a cancel request. const TestLongRunningQuery = "pg_sleep(forever)"