Skip to content

Commit aba7b01

Browse files
committed
feat: KIP-368 support periodic re-auth
Allow SASL Connections to Periodically Re-Authenticate [KIP-368](https://cwiki.apache.org/confluence/display/KAFKA/KIP-368%3A+Allow+SASL+Connections+to+Periodically+Re-Authenticate)
1 parent 1776783 commit aba7b01

7 files changed

+298
-34
lines changed

broker.go

+56-16
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"fmt"
88
"io"
9+
"math/rand"
910
"net"
1011
"sort"
1112
"strconv"
@@ -52,7 +53,8 @@ type Broker struct {
5253
brokerRequestsInFlight metrics.Counter
5354
brokerThrottleTime metrics.Histogram
5455

55-
kerberosAuthenticator GSSAPIKerberosAuth
56+
kerberosAuthenticator GSSAPIKerberosAuth
57+
clientSessionReauthenticationTimeMs int64
5658
}
5759

5860
// SASLMechanism specifies the SASL mechanism the client uses to authenticate with the broker
@@ -923,6 +925,13 @@ func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) erro
923925
return ErrNotConnected
924926
}
925927

928+
if b.clientSessionReauthenticationTimeMs > 0 && currentUnixMilli() > b.clientSessionReauthenticationTimeMs {
929+
err := b.authenticateViaSASL()
930+
if err != nil {
931+
return err
932+
}
933+
}
934+
926935
if !b.conf.Version.IsAtLeast(rb.requiredVersion()) {
927936
return ErrUnsupportedVersion
928937
}
@@ -1263,7 +1272,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
12631272

12641273
// Will be decremented in updateIncomingCommunicationMetrics (except error)
12651274
b.addRequestInFlightMetrics(1)
1266-
bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID)
1275+
bytesWritten, resVersion, err := b.sendSASLPlainAuthClientResponse(correlationID)
12671276
b.updateOutgoingCommunicationMetrics(bytesWritten)
12681277

12691278
if err != nil {
@@ -1274,7 +1283,8 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
12741283

12751284
b.correlationID++
12761285

1277-
bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID)
1286+
res := &SaslAuthenticateResponse{}
1287+
bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)
12781288
b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))
12791289

12801290
// With v1 sasl we get an error message set in the response we can return
@@ -1288,6 +1298,10 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
12881298
return nil
12891299
}
12901300

1301+
func currentUnixMilli() int64 {
1302+
return time.Now().UnixNano() / int64(time.Millisecond)
1303+
}
1304+
12911305
// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
12921306
// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
12931307
func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
@@ -1327,7 +1341,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
13271341
b.addRequestInFlightMetrics(1)
13281342
correlationID := b.correlationID
13291343

1330-
bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
1344+
bytesWritten, resVersion, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
13311345
b.updateOutgoingCommunicationMetrics(bytesWritten)
13321346
if err != nil {
13331347
b.addRequestInFlightMetrics(-1)
@@ -1337,7 +1351,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
13371351
b.correlationID++
13381352

13391353
res := &SaslAuthenticateResponse{}
1340-
bytesRead, err := b.receiveSASLServerResponse(res, correlationID)
1354+
bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)
13411355

13421356
requestLatency := time.Since(requestTime)
13431357
b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
@@ -1464,7 +1478,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
14641478
}
14651479

14661480
func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
1467-
rb := &SaslAuthenticateRequest{msg}
1481+
rb := b.createSaslAuthenticateRequest(msg)
14681482
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
14691483
buf, err := encode(req, b.conf.MetricRegistry)
14701484
if err != nil {
@@ -1474,6 +1488,15 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i
14741488
return b.write(buf)
14751489
}
14761490

1491+
func (b *Broker) createSaslAuthenticateRequest(msg []byte) *SaslAuthenticateRequest {
1492+
authenticateRequest := SaslAuthenticateRequest{SaslAuthBytes: msg}
1493+
if b.conf.Version.IsAtLeast(V2_2_0_0) {
1494+
authenticateRequest.Version = 1
1495+
}
1496+
1497+
return &authenticateRequest
1498+
}
1499+
14771500
func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
14781501
buf := make([]byte, responseLengthSize+correlationIDSize)
14791502
_, err := b.readFull(buf)
@@ -1538,32 +1561,34 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string)
15381561
return strings.Join(buf, elemSep)
15391562
}
15401563

1541-
func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) {
1564+
func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, int16, error) {
15421565
authBytes := []byte(b.conf.Net.SASL.AuthIdentity + "\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
1543-
rb := &SaslAuthenticateRequest{authBytes}
1566+
rb := b.createSaslAuthenticateRequest(authBytes)
15441567
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
15451568
buf, err := encode(req, b.conf.MetricRegistry)
15461569
if err != nil {
1547-
return 0, err
1570+
return 0, rb.Version, err
15481571
}
15491572

1550-
return b.write(buf)
1573+
write, err := b.write(buf)
1574+
return write, rb.Version, err
15511575
}
15521576

1553-
func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {
1554-
rb := &SaslAuthenticateRequest{initialResp}
1577+
func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, int16, error) {
1578+
rb := b.createSaslAuthenticateRequest(initialResp)
15551579

15561580
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
15571581

15581582
buf, err := encode(req, b.conf.MetricRegistry)
15591583
if err != nil {
1560-
return 0, err
1584+
return 0, rb.version(), err
15611585
}
15621586

1563-
return b.write(buf)
1587+
write, err := b.write(buf)
1588+
return write, rb.version(), err
15641589
}
15651590

1566-
func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {
1591+
func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32, resVersion int16) (int, error) {
15671592
buf := make([]byte, responseLengthSize+correlationIDSize)
15681593
bytesRead, err := b.readFull(buf)
15691594
if err != nil {
@@ -1587,7 +1612,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
15871612
return bytesRead, err
15881613
}
15891614

1590-
if err := versionedDecode(buf, res, 0); err != nil {
1615+
if err := versionedDecode(buf, res, resVersion); err != nil {
15911616
return bytesRead, err
15921617
}
15931618

@@ -1599,6 +1624,21 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
15991624
return bytesRead, err
16001625
}
16011626

1627+
if res.SessionLifetimeMs > 0 {
1628+
// Follows the Java Kafka implementation from SaslClientAuthenticator.ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes
1629+
// pick a random percentage between 85% and 95% for session re-authentication
1630+
positiveSessionLifetimeMs := res.SessionLifetimeMs
1631+
authenticationEndMs := currentUnixMilli()
1632+
pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount := 0.85
1633+
pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously := 0.10
1634+
pctToUse := pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + rand.Float64()*pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously
1635+
sessionLifetimeMsToUse := int64(float64(positiveSessionLifetimeMs) * pctToUse)
1636+
DebugLogger.Printf("Session expiration in %d ms and session re-authentication on or after %d ms", positiveSessionLifetimeMs, sessionLifetimeMsToUse)
1637+
b.clientSessionReauthenticationTimeMs = authenticationEndMs + sessionLifetimeMsToUse
1638+
} else {
1639+
b.clientSessionReauthenticationTimeMs = 0
1640+
}
1641+
16021642
return bytesRead, nil
16031643
}
16041644

broker_test.go

+159
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,165 @@ func TestBuildClientFirstMessage(t *testing.T) {
828828
}
829829
}
830830

831+
func TestKip368ReAuthenticationSuccess(t *testing.T) {
832+
sessionLifetimeMs := int64(100)
833+
834+
mockBroker := NewMockBroker(t, 0)
835+
836+
countSaslAuthRequests := func() (count int) {
837+
for _, rr := range mockBroker.History() {
838+
switch rr.Request.(type) {
839+
case *SaslAuthenticateRequest:
840+
count++
841+
}
842+
}
843+
return
844+
}
845+
846+
mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
847+
SetAuthBytes([]byte(`response_payload`)).
848+
SetSessionLifetimeMs(sessionLifetimeMs)
849+
850+
mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
851+
SetEnabledMechanisms([]string{SASLTypePlaintext})
852+
853+
mockApiVersions := NewMockApiVersionsResponse(t)
854+
855+
mockBroker.SetHandlerByMap(map[string]MockResponse{
856+
"SaslAuthenticateRequest": mockSASLAuthResponse,
857+
"SaslHandshakeRequest": mockSASLHandshakeResponse,
858+
"ApiVersionsRequest": mockApiVersions,
859+
})
860+
861+
broker := NewBroker(mockBroker.Addr())
862+
863+
conf := NewTestConfig()
864+
conf.Net.SASL.Enable = true
865+
conf.Net.SASL.Mechanism = SASLTypePlaintext
866+
conf.Net.SASL.Version = SASLHandshakeV1
867+
conf.Net.SASL.AuthIdentity = "authid"
868+
conf.Net.SASL.User = "token"
869+
conf.Net.SASL.Password = "password"
870+
871+
broker.conf = conf
872+
broker.conf.Version = V2_2_0_0
873+
874+
err := broker.Open(conf)
875+
if err != nil {
876+
t.Fatal(err)
877+
}
878+
t.Cleanup(func() { _ = broker.Close() })
879+
880+
connected, err := broker.Connected()
881+
if err != nil || !connected {
882+
t.Fatal(err)
883+
}
884+
885+
actualSaslAuthRequests := countSaslAuthRequests()
886+
if actualSaslAuthRequests != 1 {
887+
t.Fatalf("unexpected number of SaslAuthRequests during initial authentication: %d", actualSaslAuthRequests)
888+
}
889+
890+
timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond)
891+
892+
loop:
893+
for actualSaslAuthRequests < 2 {
894+
select {
895+
case <-timeout:
896+
break loop
897+
default:
898+
time.Sleep(10 * time.Millisecond)
899+
// put some traffic on the wire
900+
_, err = broker.ApiVersions(&ApiVersionsRequest{})
901+
if err != nil {
902+
t.Fatal(err)
903+
}
904+
actualSaslAuthRequests = countSaslAuthRequests()
905+
}
906+
}
907+
908+
if actualSaslAuthRequests < 2 {
909+
t.Fatalf("sasl reauth has not occurred within expected timeframe")
910+
}
911+
912+
mockBroker.Close()
913+
}
914+
915+
func TestKip368ReAuthenticationFailure(t *testing.T) {
916+
sessionLifetimeMs := int64(100)
917+
918+
mockBroker := NewMockBroker(t, 0)
919+
920+
mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
921+
SetAuthBytes([]byte(`response_payload`)).
922+
SetSessionLifetimeMs(sessionLifetimeMs)
923+
924+
mockSASLAuthErrorResponse := NewMockSaslAuthenticateResponse(t).
925+
SetError(ErrSASLAuthenticationFailed)
926+
927+
mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
928+
SetEnabledMechanisms([]string{SASLTypePlaintext})
929+
930+
mockApiVersions := NewMockApiVersionsResponse(t)
931+
932+
mockBroker.SetHandlerByMap(map[string]MockResponse{
933+
"SaslAuthenticateRequest": mockSASLAuthResponse,
934+
"SaslHandshakeRequest": mockSASLHandshakeResponse,
935+
"ApiVersionsRequest": mockApiVersions,
936+
})
937+
938+
broker := NewBroker(mockBroker.Addr())
939+
940+
conf := NewTestConfig()
941+
conf.Net.SASL.Enable = true
942+
conf.Net.SASL.Mechanism = SASLTypePlaintext
943+
conf.Net.SASL.Version = SASLHandshakeV1
944+
conf.Net.SASL.AuthIdentity = "authid"
945+
conf.Net.SASL.User = "token"
946+
conf.Net.SASL.Password = "password"
947+
948+
broker.conf = conf
949+
broker.conf.Version = V2_2_0_0
950+
951+
err := broker.Open(conf)
952+
if err != nil {
953+
t.Fatal(err)
954+
}
955+
t.Cleanup(func() { _ = broker.Close() })
956+
957+
connected, err := broker.Connected()
958+
if err != nil || !connected {
959+
t.Fatal(err)
960+
}
961+
962+
mockBroker.SetHandlerByMap(map[string]MockResponse{
963+
"SaslAuthenticateRequest": mockSASLAuthErrorResponse,
964+
"SaslHandshakeRequest": mockSASLHandshakeResponse,
965+
"ApiVersionsRequest": mockApiVersions,
966+
})
967+
968+
timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond)
969+
970+
var apiVersionError error
971+
loop:
972+
for apiVersionError == nil {
973+
select {
974+
case <-timeout:
975+
break loop
976+
default:
977+
time.Sleep(10 * time.Millisecond)
978+
// put some traffic on the wire
979+
_, apiVersionError = broker.ApiVersions(&ApiVersionsRequest{})
980+
}
981+
}
982+
983+
if !errors.Is(apiVersionError, ErrSASLAuthenticationFailed) {
984+
t.Fatalf("sasl reauth has not failed in the expected way %v", apiVersionError)
985+
}
986+
987+
mockBroker.Close()
988+
}
989+
831990
// We're not testing encoding/decoding here, so most of the requests/responses will be empty for simplicity's sake
832991
var brokerTestTable = []struct {
833992
version KafkaVersion

mockresponses.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -1057,19 +1057,23 @@ func (mr *MockListAclsResponse) For(reqBody versionedDecoder) encoderWithHeader
10571057
}
10581058

10591059
type MockSaslAuthenticateResponse struct {
1060-
t TestReporter
1061-
kerror KError
1062-
saslAuthBytes []byte
1060+
t TestReporter
1061+
kerror KError
1062+
saslAuthBytes []byte
1063+
sessionLifetimeMs int64
10631064
}
10641065

10651066
func NewMockSaslAuthenticateResponse(t TestReporter) *MockSaslAuthenticateResponse {
10661067
return &MockSaslAuthenticateResponse{t: t}
10671068
}
10681069

10691070
func (msar *MockSaslAuthenticateResponse) For(reqBody versionedDecoder) encoderWithHeader {
1071+
req := reqBody.(*SaslAuthenticateRequest)
10701072
res := &SaslAuthenticateResponse{}
1073+
res.Version = req.Version
10711074
res.Err = msar.kerror
10721075
res.SaslAuthBytes = msar.saslAuthBytes
1076+
res.SessionLifetimeMs = msar.sessionLifetimeMs
10731077
return res
10741078
}
10751079

@@ -1083,6 +1087,11 @@ func (msar *MockSaslAuthenticateResponse) SetAuthBytes(saslAuthBytes []byte) *Mo
10831087
return msar
10841088
}
10851089

1090+
func (msar *MockSaslAuthenticateResponse) SetSessionLifetimeMs(sessionLifetimeMs int64) *MockSaslAuthenticateResponse {
1091+
msar.sessionLifetimeMs = sessionLifetimeMs
1092+
return msar
1093+
}
1094+
10861095
type MockDeleteAclsResponse struct {
10871096
t TestReporter
10881097
}

0 commit comments

Comments
 (0)