Skip to content

Commit 6a1e536

Browse files
committed
Merge branch 'fix/253-auth-callback-race-condition' into bopen-master
2 parents 1791e19 + 43f7d32 commit 6a1e536

File tree

1 file changed

+103
-37
lines changed

1 file changed

+103
-37
lines changed

auth/peer.go

Lines changed: 103 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"errors"
1313
"fmt"
1414
"log/slog"
15+
"sync"
1516
"sync/atomic"
1617
"time"
1718

@@ -39,6 +40,12 @@ type OnCertificateReceivedCallback func(ctx context.Context, senderPublicKey *ec
3940
// The callback receives the sender's public key and the requested certificate set.
4041
type OnCertificateRequestReceivedCallback func(ctx context.Context, senderPublicKey *ec.PublicKey, requestedCertificates utils.RequestedCertificateSet) error
4142

43+
// InitialResponseCallback holds a callback function and associated session nonce for initial response handling.
44+
type InitialResponseCallback struct {
45+
Callback func(sessionNonce string) error
46+
SessionNonce string
47+
}
48+
4249
// Peer represents a peer capable of performing mutual authentication.
4350
// It manages sessions, handles authentication handshakes, certificate requests and responses,
4451
// and sending and receiving general messages over a transport layer.
@@ -51,14 +58,12 @@ type Peer struct {
5158
onGeneralMessageReceivedCallbacks map[int32]OnGeneralMessageReceivedCallback
5259
onCertificateReceivedCallbacks map[int32]OnCertificateReceivedCallback
5360
onCertificateRequestReceivedCallbacks map[int32]OnCertificateRequestReceivedCallback
54-
onInitialResponseReceivedCallbacks map[int32]struct {
55-
Callback func(sessionNonce string) error
56-
SessionNonce string
57-
}
58-
callbackIdCounter atomic.Int32
59-
autoPersistLastSession bool
60-
lastInteractedWithPeer *ec.PublicKey
61-
logger *slog.Logger // Logger for debug messages
61+
onInitialResponseReceivedCallbacks map[int32]InitialResponseCallback
62+
callbacksMu sync.RWMutex
63+
callbackIdCounter atomic.Int32
64+
autoPersistLastSession bool
65+
lastInteractedWithPeer *ec.PublicKey
66+
logger *slog.Logger // Logger for debug messages
6267
}
6368

6469
// PeerOptions contains configuration options for creating a new Peer instance.
@@ -80,11 +85,8 @@ func NewPeer(cfg *PeerOptions) *Peer {
8085
onGeneralMessageReceivedCallbacks: make(map[int32]OnGeneralMessageReceivedCallback),
8186
onCertificateReceivedCallbacks: make(map[int32]OnCertificateReceivedCallback),
8287
onCertificateRequestReceivedCallbacks: make(map[int32]OnCertificateRequestReceivedCallback),
83-
onInitialResponseReceivedCallbacks: make(map[int32]struct {
84-
Callback func(sessionNonce string) error
85-
SessionNonce string
86-
}),
87-
logger: cfg.Logger,
88+
onInitialResponseReceivedCallbacks: make(map[int32]InitialResponseCallback),
89+
logger: cfg.Logger,
8890
}
8991

9092
// Use default logger if none provided
@@ -154,37 +156,67 @@ func (p *Peer) Stop() error {
154156
// ListenForGeneralMessages registers a callback for general messages
155157
func (p *Peer) ListenForGeneralMessages(callback OnGeneralMessageReceivedCallback) int32 {
156158
callbackID := p.callbackIdCounter.Add(1)
159+
p.callbacksMu.Lock()
157160
p.onGeneralMessageReceivedCallbacks[callbackID] = callback
161+
p.callbacksMu.Unlock()
158162
return callbackID
159163
}
160164

161165
// StopListeningForGeneralMessages removes a general message listener
162166
func (p *Peer) StopListeningForGeneralMessages(callbackID int32) {
167+
p.callbacksMu.Lock()
163168
delete(p.onGeneralMessageReceivedCallbacks, callbackID)
169+
p.callbacksMu.Unlock()
164170
}
165171

166172
// ListenForCertificatesReceived registers a callback for certificate reception
167173
func (p *Peer) ListenForCertificatesReceived(callback OnCertificateReceivedCallback) int32 {
168174
callbackID := p.callbackIdCounter.Add(1)
175+
p.callbacksMu.Lock()
169176
p.onCertificateReceivedCallbacks[callbackID] = callback
177+
p.callbacksMu.Unlock()
170178
return callbackID
171179
}
172180

173181
// StopListeningForCertificatesReceived removes a certificate reception listener
174182
func (p *Peer) StopListeningForCertificatesReceived(callbackID int32) {
183+
p.callbacksMu.Lock()
175184
delete(p.onCertificateReceivedCallbacks, callbackID)
185+
p.callbacksMu.Unlock()
176186
}
177187

178188
// ListenForCertificatesRequested registers a callback for certificate requests
179189
func (p *Peer) ListenForCertificatesRequested(callback OnCertificateRequestReceivedCallback) int32 {
180190
callbackID := p.callbackIdCounter.Add(1)
191+
p.callbacksMu.Lock()
181192
p.onCertificateRequestReceivedCallbacks[callbackID] = callback
193+
p.callbacksMu.Unlock()
182194
return callbackID
183195
}
184196

185197
// StopListeningForCertificatesRequested removes a certificate request listener
186198
func (p *Peer) StopListeningForCertificatesRequested(callbackID int32) {
199+
p.callbacksMu.Lock()
187200
delete(p.onCertificateRequestReceivedCallbacks, callbackID)
201+
p.callbacksMu.Unlock()
202+
}
203+
204+
// StopListeningForInitialResponse removes a certificate initial response listener
205+
func (p *Peer) StopListeningForInitialResponse(callbackID int32) {
206+
p.callbacksMu.Lock()
207+
defer p.callbacksMu.Unlock()
208+
delete(p.onInitialResponseReceivedCallbacks, callbackID)
209+
}
210+
211+
// getInitialResponseCallbacks retrieves the initial response callbacks
212+
func (p *Peer) getInitialResponseCallbacks() map[int32]InitialResponseCallback {
213+
p.callbacksMu.RLock()
214+
defer p.callbacksMu.RUnlock()
215+
callbacks := make(map[int32]InitialResponseCallback)
216+
for k, v := range p.onInitialResponseReceivedCallbacks {
217+
callbacks[k] = v
218+
}
219+
return callbacks
188220
}
189221

190222
// ToPeer sends a message to a peer, initiating authentication if needed
@@ -335,16 +367,15 @@ func (p *Peer) initiateHandshake(ctx context.Context, peerIdentityKey *ec.Public
335367
// Register a callback for the response
336368
callbackID := p.callbackIdCounter.Add(1)
337369

338-
p.onInitialResponseReceivedCallbacks[callbackID] = struct {
339-
Callback func(sessionNonce string) error
340-
SessionNonce string
341-
}{
370+
p.callbacksMu.Lock()
371+
p.onInitialResponseReceivedCallbacks[callbackID] = InitialResponseCallback{
342372
Callback: func(peerNonce string) error {
343373
responseChan <- struct{}{}
344374
return nil
345375
},
346376
SessionNonce: sessionNonce,
347377
}
378+
p.callbacksMu.Unlock()
348379

349380
// TODO: replace maxWait with simply context with timeout
350381
ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(maxWaitTimeMs)*time.Millisecond)
@@ -354,18 +385,18 @@ func (p *Peer) initiateHandshake(ctx context.Context, peerIdentityKey *ec.Public
354385
err = p.transport.Send(ctx, initialRequest)
355386
if err != nil {
356387
close(responseChan)
357-
delete(p.onInitialResponseReceivedCallbacks, callbackID)
388+
p.StopListeningForInitialResponse(callbackID)
358389
return nil, NewAuthError("failed to send initial request", err)
359390
}
360391

361392
// Wait for response or timeout
362393
select {
363394
case <-responseChan:
364395
close(responseChan)
365-
delete(p.onInitialResponseReceivedCallbacks, callbackID)
396+
p.StopListeningForInitialResponse(callbackID)
366397
return session, nil
367398
case <-ctxWithTimeout.Done():
368-
delete(p.onInitialResponseReceivedCallbacks, callbackID)
399+
p.StopListeningForInitialResponse(callbackID)
369400
return nil, ErrTimeout
370401
}
371402
}
@@ -619,7 +650,14 @@ func (p *Peer) handleInitialResponse(ctx context.Context, message *AuthMessage,
619650
// Certificates validated successfully, authenticate the session
620651
session.IsAuthenticated = true
621652

653+
p.callbacksMu.RLock()
654+
callbacks := make([]OnCertificateReceivedCallback, 0, len(p.onCertificateReceivedCallbacks))
622655
for _, callback := range p.onCertificateReceivedCallbacks {
656+
callbacks = append(callbacks, callback)
657+
}
658+
p.callbacksMu.RUnlock()
659+
660+
for _, callback := range callbacks {
623661
err := callback(ctx, senderPublicKey, message.Certificates)
624662
if err != nil {
625663
return NewAuthError("certificate received callback error", err)
@@ -634,11 +672,11 @@ func (p *Peer) handleInitialResponse(ctx context.Context, message *AuthMessage,
634672

635673
p.lastInteractedWithPeer = message.IdentityKey
636674

637-
for id, callback := range p.onInitialResponseReceivedCallbacks {
675+
for id, callback := range p.getInitialResponseCallbacks() {
638676
if callback.SessionNonce == session.SessionNonce {
639677
// Call the initial response callback with the peer's nonce
640678
err := callback.Callback(session.SessionNonce)
641-
delete(p.onInitialResponseReceivedCallbacks, id)
679+
p.StopListeningForInitialResponse(id)
642680
if err != nil {
643681
return NewAuthError("initial response received callback error", err)
644682
}
@@ -657,16 +695,24 @@ func (p *Peer) handleInitialResponse(ctx context.Context, message *AuthMessage,
657695
}
658696

659697
func (p *Peer) sendCertificates(ctx context.Context, message *AuthMessage) error {
660-
if len(p.onCertificateRequestReceivedCallbacks) > 0 {
698+
p.callbacksMu.RLock()
699+
hasCallbacks := len(p.onCertificateRequestReceivedCallbacks) > 0
700+
if hasCallbacks {
701+
callbacks := make([]OnCertificateRequestReceivedCallback, 0, len(p.onCertificateRequestReceivedCallbacks))
661702
for _, callback := range p.onCertificateRequestReceivedCallbacks {
703+
callbacks = append(callbacks, callback)
704+
}
705+
p.callbacksMu.RUnlock()
706+
707+
for _, callback := range callbacks {
662708
err := callback(ctx, message.IdentityKey, message.RequestedCertificates)
663709
if err != nil {
664-
// Log callback error but continue
665710
return fmt.Errorf("on certificate request callback failed: %w", err)
666711
}
667712
}
668713
return nil
669714
}
715+
p.callbacksMu.RUnlock()
670716

671717
certs, err := utils.GetVerifiableCertificates(
672718
ctx,
@@ -690,12 +736,6 @@ func (p *Peer) sendCertificates(ctx context.Context, message *AuthMessage) error
690736

691737
// handleCertificateRequest processes a certificate request message
692738
func (p *Peer) handleCertificateRequest(ctx context.Context, message *AuthMessage, senderPublicKey *ec.PublicKey) error {
693-
// Validate the session exists and is authenticated
694-
session, err := p.sessionManager.GetSession(senderPublicKey.ToDERHex())
695-
if err != nil || session == nil {
696-
return ErrSessionNotFound
697-
}
698-
699739
valid, err := utils.VerifyNonce(ctx, message.YourNonce, p.wallet, wallet.Counterparty{Type: wallet.CounterpartyTypeSelf})
700740
if err != nil {
701741
return fmt.Errorf("failed to validate nonce: %w", err)
@@ -704,6 +744,14 @@ func (p *Peer) handleCertificateRequest(ctx context.Context, message *AuthMessag
704744
return ErrInvalidNonce
705745
}
706746

747+
// Validate the session exists and is authenticated
748+
// Use YourNonce to look up the session, which uniquely identifies the correct session
749+
// even when multiple devices share the same identity key
750+
session, err := p.sessionManager.GetSession(message.YourNonce)
751+
if err != nil || session == nil {
752+
return ErrSessionNotFound
753+
}
754+
707755
// Update session timestamp
708756
session.LastUpdate = time.Now().UnixMilli()
709757
p.sessionManager.UpdateSession(session)
@@ -755,12 +803,6 @@ func (p *Peer) handleCertificateRequest(ctx context.Context, message *AuthMessag
755803

756804
// handleCertificateResponse processes a certificate response message
757805
func (p *Peer) handleCertificateResponse(ctx context.Context, message *AuthMessage, senderPublicKey *ec.PublicKey) error {
758-
// Validate the session exists and is authenticated
759-
session, err := p.sessionManager.GetSession(senderPublicKey.ToDERHex())
760-
if err != nil || session == nil {
761-
return ErrSessionNotFound
762-
}
763-
764806
valid, err := utils.VerifyNonce(ctx, message.YourNonce, p.wallet, wallet.Counterparty{Type: wallet.CounterpartyTypeSelf})
765807
if err != nil {
766808
return fmt.Errorf("failed to validate nonce: %w", err)
@@ -769,6 +811,14 @@ func (p *Peer) handleCertificateResponse(ctx context.Context, message *AuthMessa
769811
return ErrInvalidNonce
770812
}
771813

814+
// Validate the session exists and is authenticated
815+
// Use YourNonce to look up the session, which uniquely identifies the correct session
816+
// even when multiple devices share the same identity key
817+
session, err := p.sessionManager.GetSession(message.YourNonce)
818+
if err != nil || session == nil {
819+
return ErrSessionNotFound
820+
}
821+
772822
// Update session timestamp
773823
session.LastUpdate = time.Now().UnixMilli()
774824
p.sessionManager.UpdateSession(session)
@@ -846,7 +896,14 @@ func (p *Peer) handleCertificateResponse(ctx context.Context, message *AuthMessa
846896

847897
// TODO: maybe it should by default (if no callback) check if there are all required certificates
848898
// Notify certificate listeners
899+
p.callbacksMu.RLock()
900+
callbacks := make([]OnCertificateReceivedCallback, 0, len(p.onCertificateReceivedCallbacks))
849901
for _, callback := range p.onCertificateReceivedCallbacks {
902+
callbacks = append(callbacks, callback)
903+
}
904+
p.callbacksMu.RUnlock()
905+
906+
for _, callback := range callbacks {
850907
err := callback(ctx, senderPublicKey, message.Certificates)
851908
if err != nil {
852909
return fmt.Errorf("certificate received callback error: %w", err)
@@ -868,7 +925,9 @@ func (p *Peer) handleGeneralMessage(ctx context.Context, message *AuthMessage, s
868925
}
869926

870927
// Validate the session exists and is authenticated
871-
session, err := p.sessionManager.GetSession(senderPublicKey.ToDERHex())
928+
// Use YourNonce to look up the session, which uniquely identifies the correct session
929+
// even when multiple devices share the same identity key
930+
session, err := p.sessionManager.GetSession(message.YourNonce)
872931
if err != nil || session == nil {
873932
return ErrSessionNotFound
874933
}
@@ -918,7 +977,14 @@ func (p *Peer) handleGeneralMessage(ctx context.Context, message *AuthMessage, s
918977
}
919978

920979
// Notify general message listeners
980+
p.callbacksMu.RLock()
981+
callbacks := make([]OnGeneralMessageReceivedCallback, 0, len(p.onGeneralMessageReceivedCallbacks))
921982
for _, callback := range p.onGeneralMessageReceivedCallbacks {
983+
callbacks = append(callbacks, callback)
984+
}
985+
p.callbacksMu.RUnlock()
986+
987+
for _, callback := range callbacks {
922988
err := callback(ctx, senderPublicKey, message.Payload)
923989
if err != nil {
924990
// Log callback error but continue

0 commit comments

Comments
 (0)