From 37b315d634957e46edfc476829b493a976d11080 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Wed, 24 Jan 2024 11:08:04 -0500 Subject: [PATCH] Fix an issue selecting MySQL database is not reflected in the audit logs --- lib/srv/db/audit_test.go | 23 ++++++++++++++++++++++- lib/srv/db/mysql/engine.go | 6 ++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/lib/srv/db/audit_test.go b/lib/srv/db/audit_test.go index db3ac8249eedc..65dc6802b7f38 100644 --- a/lib/srv/db/audit_test.go +++ b/lib/srv/db/audit_test.go @@ -167,7 +167,17 @@ func TestAuditMySQL(t *testing.T) { // Simple query should trigger the query event. _, err = mysql.Execute("select 1") require.NoError(t, err) - requireQueryEvent(t, testCtx, libevents.DatabaseSessionQueryCode, "select 1") + requireQueryEventWithDBName(t, testCtx, libevents.DatabaseSessionQueryCode, "select 1", "") + + // Switch to another database. + err = mysql.UseDB("foo") + require.NoError(t, err) + requireEvent(t, testCtx, libevents.MySQLInitDBCode) + + // Check DatabaseName is updated. + _, err = mysql.Execute("select 2") + require.NoError(t, err) + requireQueryEventWithDBName(t, testCtx, libevents.DatabaseSessionQueryCode, "select 2", "foo") // Closing connection should trigger session end event. err = mysql.Close() @@ -373,6 +383,17 @@ func requireQueryEvent(t *testing.T, testCtx *testContext, code, query string) { require.Equal(t, query, event.(*events.DatabaseSessionQuery).DatabaseQuery) } +func requireQueryEventWithDBName(t *testing.T, testCtx *testContext, code, query, dbName string) { + t.Helper() + event := waitForAnyEvent(t, testCtx) + require.Equal(t, code, event.GetCode()) + + queryEvent, ok := event.(*events.DatabaseSessionQuery) + require.True(t, ok) + require.Equal(t, query, queryEvent.DatabaseQuery) + require.Equal(t, dbName, queryEvent.DatabaseName) +} + func waitForAnyEvent(t *testing.T, testCtx *testContext) events.AuditEvent { t.Helper() select { diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index a71d7aaec3dbd..94911d3579609 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -362,6 +362,12 @@ func (e *Engine) receiveFromClient(clientConn, serverConn net.Conn, clientErrCh return case *protocol.InitDB: + // Update DatabaseName when switching to another so the audit logs + // are up to date. E.g.: + // mysql> use foo; + // mysql> select * from users; + sessionCtx.DatabaseName = pkt.SchemaName() + e.Audit.EmitEvent(e.Context, makeInitDBEvent(sessionCtx, pkt)) case *protocol.CreateDB: e.Audit.EmitEvent(e.Context, makeCreateDBEvent(sessionCtx, pkt))