diff --git a/lib/srv/db/audit_test.go b/lib/srv/db/audit_test.go index 563005112f5cf..2908c30dbf99b 100644 --- a/lib/srv/db/audit_test.go +++ b/lib/srv/db/audit_test.go @@ -75,6 +75,68 @@ func TestAuditPostgres(t *testing.T) { requireEvent(t, testCtx, libevents.PostgresBindCode) requireEvent(t, testCtx, libevents.PostgresExecuteCode) + bindTests := []struct { + desc string + sql string + params [][]byte + formatCodes []int16 + wantParams []string + }{ + { + desc: "zero format codes applies text format to all params", + sql: "select $1, $2", + params: [][]byte{[]byte("fish"), []byte("cat")}, + wantParams: []string{"fish", "cat"}, + }, + { + desc: "one text format codes applies text format to all params", + sql: "select $1, $2", + params: [][]byte{[]byte("fish"), []byte("cat")}, + formatCodes: []int16{0}, // text format. + wantParams: []string{"fish", "cat"}, + }, + { + desc: "one binary format codes applies binary format to all params", + sql: "select $1, $2", + params: [][]byte{[]byte("fish"), []byte("cat")}, + formatCodes: []int16{1}, // binary format. + // event should encode binary as base64 strings. + wantParams: []string{"ZmlzaA==", "Y2F0"}, + }, + { + desc: "apply corresponding format code to each param", + sql: "select $1, $2, $3", + params: [][]byte{[]byte("fish"), []byte("cat"), []byte("dog")}, + formatCodes: []int16{1, 0, 0}, // binary, text, text format. + wantParams: []string{"ZmlzaA==", "cat", "dog"}, + }, + { + desc: "more than one format codes but fewer than params is invalid bind", + sql: "select $1, $2, $3", + params: [][]byte{[]byte("fish"), []byte("cat"), []byte("dog")}, + formatCodes: []int16{1, 0}, // binary, text. + wantParams: nil, // don't log params for invalid bind. + }, + { + desc: "more format codes than params is invalid bind", + sql: "select $1, $2", + params: [][]byte{[]byte("fish"), []byte("cat")}, + formatCodes: []int16{1, 0, 0}, // binary, text, text(missing) + wantParams: nil, // don't log params for invalid bind. + }, + } + for _, test := range bindTests { + t.Run(test.desc, func(t *testing.T) { + resultUnnamed := psql.ExecParams(ctx, test.sql, test.params, nil, test.formatCodes, nil).Read() + require.NotNil(t, resultUnnamed) + require.NoError(t, resultUnnamed.Err) + requireEvent(t, testCtx, libevents.PostgresParseCode) + event := requireBindEvent(t, testCtx) + require.Equal(t, test.wantParams, event.Parameters) + requireEvent(t, testCtx, libevents.PostgresExecuteCode) + }) + } + // Closing connection should trigger session end event. err = psql.Close(ctx) require.NoError(t, err) @@ -280,23 +342,37 @@ func TestAuditClickHouseHTTP(t *testing.T) { } func assertDatabaseQueryFromAuditEvent(t *testing.T, event events.AuditEvent, wantQuery string) { + t.Helper() query, ok := event.(*events.DatabaseSessionQuery) require.True(t, ok) require.Equal(t, wantQuery, query.DatabaseQuery) } -func requireEvent(t *testing.T, testCtx *testContext, code string) { +func requireBindEvent(t *testing.T, testCtx *testContext) *events.PostgresBind { + t.Helper() + event := requireEvent(t, testCtx, libevents.PostgresBindCode) + bindEvent, ok := event.(*events.PostgresBind) + require.True(t, ok) + require.NotNil(t, bindEvent) + return bindEvent +} + +func requireEvent(t *testing.T, testCtx *testContext, code string) events.AuditEvent { + t.Helper() event := waitForAnyEvent(t, testCtx) require.Equal(t, code, event.GetCode()) + return event } func requireQueryEvent(t *testing.T, testCtx *testContext, code, query string) { + t.Helper() event := waitForAnyEvent(t, testCtx) require.Equal(t, code, event.GetCode()) require.Equal(t, query, event.(*events.DatabaseSessionQuery).DatabaseQuery) } func waitForAnyEvent(t *testing.T, testCtx *testContext) events.AuditEvent { + t.Helper() select { case event := <-testCtx.emitter.C(): return event @@ -308,6 +384,7 @@ func waitForAnyEvent(t *testing.T, testCtx *testContext) events.AuditEvent { // waitForEvent waits for particular event code ignoring other events. func waitForEvent(t *testing.T, testCtx *testContext, code string) events.AuditEvent { + t.Helper() for { select { case event := <-testCtx.emitter.C(): diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index f6d3d1dd50740..e68979c6b137c 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -19,6 +19,7 @@ package postgres import ( "context" "crypto/tls" + "encoding/base64" "errors" "fmt" "io" @@ -603,9 +604,12 @@ func formatParameters(parameters [][]byte, formatCodes []int16) (formatted []str // by "parameter format codes" in the Bind message (0 - text, 1 - binary). // // Be a bit paranoid and make sure that number of format codes matches the - // number of parameters, or there are no format codes in which case all - // parameters will be text. - if len(formatCodes) != 0 && len(formatCodes) != len(parameters) { + // number of parameters, or there are zero or one format codes. + // zero format codes applies text format to all params. + // one format code applies the same format code to all params. + // https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BIND + // https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-FUNCTIONCALL + if len(formatCodes) > 1 && len(formatCodes) != len(parameters) { logrus.Warnf("Postgres parameter format codes and parameters don't match: %#v %#v.", parameters, formatCodes) return formatted @@ -614,23 +618,30 @@ func formatParameters(parameters [][]byte, formatCodes []int16) (formatted []str // According to Bind message documentation, if there are no parameter // format codes, it may mean that either there are no parameters, or // that all parameters use default text format. - if len(formatCodes) == 0 { - formatted = append(formatted, string(p)) - continue + var formatCode int16 + switch len(formatCodes) { + case 0: + // use default 0 (text) format for all params. + case 1: + // apply the same format code to all params. + formatCode = formatCodes[0] + default: + // apply format code corresponding to this param. + formatCode = formatCodes[i] } - switch formatCodes[i] { + + switch formatCode { case parameterFormatCodeText: // Text parameters can just be converted to their string // representation. formatted = append(formatted, string(p)) case parameterFormatCodeBinary: - // For binary parameters, just put a placeholder to avoid - // spamming the audit log with unreadable info. - formatted = append(formatted, "") + // For binary parameters, encode the parameter as a base64 string. + formatted = append(formatted, base64.StdEncoding.EncodeToString(p)) default: // Should never happen but... logrus.Warnf("Unknown Postgres parameter format code: %#v.", - formatCodes[i]) + formatCode) formatted = append(formatted, "") } }