@@ -12,6 +12,7 @@ import (
1212 "errors"
1313 "fmt"
1414 "log/slog"
15+ "sync"
1516 "sync/atomic"
1617 "time"
1718
@@ -39,6 +40,12 @@ type OnCertificateReceivedCallback func(ctx context.Context, senderPublicKey *ec
3940// The callback receives the sender's public key and the requested certificate set.
4041type OnCertificateRequestReceivedCallback func (ctx context.Context , senderPublicKey * ec.PublicKey , requestedCertificates utils.RequestedCertificateSet ) error
4142
43+ // InitialResponseCallback holds a callback function and associated session nonce for initial response handling.
44+ type InitialResponseCallback struct {
45+ Callback func (sessionNonce string ) error
46+ SessionNonce string
47+ }
48+
4249// Peer represents a peer capable of performing mutual authentication.
4350// It manages sessions, handles authentication handshakes, certificate requests and responses,
4451// and sending and receiving general messages over a transport layer.
@@ -51,14 +58,12 @@ type Peer struct {
5158 onGeneralMessageReceivedCallbacks map [int32 ]OnGeneralMessageReceivedCallback
5259 onCertificateReceivedCallbacks map [int32 ]OnCertificateReceivedCallback
5360 onCertificateRequestReceivedCallbacks map [int32 ]OnCertificateRequestReceivedCallback
54- onInitialResponseReceivedCallbacks map [int32 ]struct {
55- Callback func (sessionNonce string ) error
56- SessionNonce string
57- }
58- callbackIdCounter atomic.Int32
59- autoPersistLastSession bool
60- lastInteractedWithPeer * ec.PublicKey
61- logger * slog.Logger // Logger for debug messages
61+ onInitialResponseReceivedCallbacks map [int32 ]InitialResponseCallback
62+ callbacksMu sync.RWMutex
63+ callbackIdCounter atomic.Int32
64+ autoPersistLastSession bool
65+ lastInteractedWithPeer * ec.PublicKey
66+ logger * slog.Logger // Logger for debug messages
6267}
6368
6469// PeerOptions contains configuration options for creating a new Peer instance.
@@ -80,11 +85,8 @@ func NewPeer(cfg *PeerOptions) *Peer {
8085 onGeneralMessageReceivedCallbacks : make (map [int32 ]OnGeneralMessageReceivedCallback ),
8186 onCertificateReceivedCallbacks : make (map [int32 ]OnCertificateReceivedCallback ),
8287 onCertificateRequestReceivedCallbacks : make (map [int32 ]OnCertificateRequestReceivedCallback ),
83- onInitialResponseReceivedCallbacks : make (map [int32 ]struct {
84- Callback func (sessionNonce string ) error
85- SessionNonce string
86- }),
87- logger : cfg .Logger ,
88+ onInitialResponseReceivedCallbacks : make (map [int32 ]InitialResponseCallback ),
89+ logger : cfg .Logger ,
8890 }
8991
9092 // Use default logger if none provided
@@ -154,37 +156,67 @@ func (p *Peer) Stop() error {
154156// ListenForGeneralMessages registers a callback for general messages
155157func (p * Peer ) ListenForGeneralMessages (callback OnGeneralMessageReceivedCallback ) int32 {
156158 callbackID := p .callbackIdCounter .Add (1 )
159+ p .callbacksMu .Lock ()
157160 p .onGeneralMessageReceivedCallbacks [callbackID ] = callback
161+ p .callbacksMu .Unlock ()
158162 return callbackID
159163}
160164
161165// StopListeningForGeneralMessages removes a general message listener
162166func (p * Peer ) StopListeningForGeneralMessages (callbackID int32 ) {
167+ p .callbacksMu .Lock ()
163168 delete (p .onGeneralMessageReceivedCallbacks , callbackID )
169+ p .callbacksMu .Unlock ()
164170}
165171
166172// ListenForCertificatesReceived registers a callback for certificate reception
167173func (p * Peer ) ListenForCertificatesReceived (callback OnCertificateReceivedCallback ) int32 {
168174 callbackID := p .callbackIdCounter .Add (1 )
175+ p .callbacksMu .Lock ()
169176 p .onCertificateReceivedCallbacks [callbackID ] = callback
177+ p .callbacksMu .Unlock ()
170178 return callbackID
171179}
172180
173181// StopListeningForCertificatesReceived removes a certificate reception listener
174182func (p * Peer ) StopListeningForCertificatesReceived (callbackID int32 ) {
183+ p .callbacksMu .Lock ()
175184 delete (p .onCertificateReceivedCallbacks , callbackID )
185+ p .callbacksMu .Unlock ()
176186}
177187
178188// ListenForCertificatesRequested registers a callback for certificate requests
179189func (p * Peer ) ListenForCertificatesRequested (callback OnCertificateRequestReceivedCallback ) int32 {
180190 callbackID := p .callbackIdCounter .Add (1 )
191+ p .callbacksMu .Lock ()
181192 p .onCertificateRequestReceivedCallbacks [callbackID ] = callback
193+ p .callbacksMu .Unlock ()
182194 return callbackID
183195}
184196
185197// StopListeningForCertificatesRequested removes a certificate request listener
186198func (p * Peer ) StopListeningForCertificatesRequested (callbackID int32 ) {
199+ p .callbacksMu .Lock ()
187200 delete (p .onCertificateRequestReceivedCallbacks , callbackID )
201+ p .callbacksMu .Unlock ()
202+ }
203+
204+ // StopListeningForInitialResponse removes a certificate initial response listener
205+ func (p * Peer ) StopListeningForInitialResponse (callbackID int32 ) {
206+ p .callbacksMu .Lock ()
207+ defer p .callbacksMu .Unlock ()
208+ delete (p .onInitialResponseReceivedCallbacks , callbackID )
209+ }
210+
211+ // getInitialResponseCallbacks retrieves the initial response callbacks
212+ func (p * Peer ) getInitialResponseCallbacks () map [int32 ]InitialResponseCallback {
213+ p .callbacksMu .RLock ()
214+ defer p .callbacksMu .RUnlock ()
215+ callbacks := make (map [int32 ]InitialResponseCallback )
216+ for k , v := range p .onInitialResponseReceivedCallbacks {
217+ callbacks [k ] = v
218+ }
219+ return callbacks
188220}
189221
190222// ToPeer sends a message to a peer, initiating authentication if needed
@@ -335,16 +367,15 @@ func (p *Peer) initiateHandshake(ctx context.Context, peerIdentityKey *ec.Public
335367 // Register a callback for the response
336368 callbackID := p .callbackIdCounter .Add (1 )
337369
338- p .onInitialResponseReceivedCallbacks [callbackID ] = struct {
339- Callback func (sessionNonce string ) error
340- SessionNonce string
341- }{
370+ p .callbacksMu .Lock ()
371+ p .onInitialResponseReceivedCallbacks [callbackID ] = InitialResponseCallback {
342372 Callback : func (peerNonce string ) error {
343373 responseChan <- struct {}{}
344374 return nil
345375 },
346376 SessionNonce : sessionNonce ,
347377 }
378+ p .callbacksMu .Unlock ()
348379
349380 // TODO: replace maxWait with simply context with timeout
350381 ctxWithTimeout , cancel := context .WithTimeout (ctx , time .Duration (maxWaitTimeMs )* time .Millisecond )
@@ -354,18 +385,18 @@ func (p *Peer) initiateHandshake(ctx context.Context, peerIdentityKey *ec.Public
354385 err = p .transport .Send (ctx , initialRequest )
355386 if err != nil {
356387 close (responseChan )
357- delete ( p . onInitialResponseReceivedCallbacks , callbackID )
388+ p . StopListeningForInitialResponse ( callbackID )
358389 return nil , NewAuthError ("failed to send initial request" , err )
359390 }
360391
361392 // Wait for response or timeout
362393 select {
363394 case <- responseChan :
364395 close (responseChan )
365- delete ( p . onInitialResponseReceivedCallbacks , callbackID )
396+ p . StopListeningForInitialResponse ( callbackID )
366397 return session , nil
367398 case <- ctxWithTimeout .Done ():
368- delete ( p . onInitialResponseReceivedCallbacks , callbackID )
399+ p . StopListeningForInitialResponse ( callbackID )
369400 return nil , ErrTimeout
370401 }
371402}
@@ -619,7 +650,14 @@ func (p *Peer) handleInitialResponse(ctx context.Context, message *AuthMessage,
619650 // Certificates validated successfully, authenticate the session
620651 session .IsAuthenticated = true
621652
653+ p .callbacksMu .RLock ()
654+ callbacks := make ([]OnCertificateReceivedCallback , 0 , len (p .onCertificateReceivedCallbacks ))
622655 for _ , callback := range p .onCertificateReceivedCallbacks {
656+ callbacks = append (callbacks , callback )
657+ }
658+ p .callbacksMu .RUnlock ()
659+
660+ for _ , callback := range callbacks {
623661 err := callback (ctx , senderPublicKey , message .Certificates )
624662 if err != nil {
625663 return NewAuthError ("certificate received callback error" , err )
@@ -634,11 +672,11 @@ func (p *Peer) handleInitialResponse(ctx context.Context, message *AuthMessage,
634672
635673 p .lastInteractedWithPeer = message .IdentityKey
636674
637- for id , callback := range p .onInitialResponseReceivedCallbacks {
675+ for id , callback := range p .getInitialResponseCallbacks () {
638676 if callback .SessionNonce == session .SessionNonce {
639677 // Call the initial response callback with the peer's nonce
640678 err := callback .Callback (session .SessionNonce )
641- delete ( p . onInitialResponseReceivedCallbacks , id )
679+ p . StopListeningForInitialResponse ( id )
642680 if err != nil {
643681 return NewAuthError ("initial response received callback error" , err )
644682 }
@@ -657,16 +695,24 @@ func (p *Peer) handleInitialResponse(ctx context.Context, message *AuthMessage,
657695}
658696
659697func (p * Peer ) sendCertificates (ctx context.Context , message * AuthMessage ) error {
660- if len (p .onCertificateRequestReceivedCallbacks ) > 0 {
698+ p .callbacksMu .RLock ()
699+ hasCallbacks := len (p .onCertificateRequestReceivedCallbacks ) > 0
700+ if hasCallbacks {
701+ callbacks := make ([]OnCertificateRequestReceivedCallback , 0 , len (p .onCertificateRequestReceivedCallbacks ))
661702 for _ , callback := range p .onCertificateRequestReceivedCallbacks {
703+ callbacks = append (callbacks , callback )
704+ }
705+ p .callbacksMu .RUnlock ()
706+
707+ for _ , callback := range callbacks {
662708 err := callback (ctx , message .IdentityKey , message .RequestedCertificates )
663709 if err != nil {
664- // Log callback error but continue
665710 return fmt .Errorf ("on certificate request callback failed: %w" , err )
666711 }
667712 }
668713 return nil
669714 }
715+ p .callbacksMu .RUnlock ()
670716
671717 certs , err := utils .GetVerifiableCertificates (
672718 ctx ,
@@ -690,12 +736,6 @@ func (p *Peer) sendCertificates(ctx context.Context, message *AuthMessage) error
690736
691737// handleCertificateRequest processes a certificate request message
692738func (p * Peer ) handleCertificateRequest (ctx context.Context , message * AuthMessage , senderPublicKey * ec.PublicKey ) error {
693- // Validate the session exists and is authenticated
694- session , err := p .sessionManager .GetSession (senderPublicKey .ToDERHex ())
695- if err != nil || session == nil {
696- return ErrSessionNotFound
697- }
698-
699739 valid , err := utils .VerifyNonce (ctx , message .YourNonce , p .wallet , wallet.Counterparty {Type : wallet .CounterpartyTypeSelf })
700740 if err != nil {
701741 return fmt .Errorf ("failed to validate nonce: %w" , err )
@@ -704,6 +744,14 @@ func (p *Peer) handleCertificateRequest(ctx context.Context, message *AuthMessag
704744 return ErrInvalidNonce
705745 }
706746
747+ // Validate the session exists and is authenticated
748+ // Use YourNonce to look up the session, which uniquely identifies the correct session
749+ // even when multiple devices share the same identity key
750+ session , err := p .sessionManager .GetSession (message .YourNonce )
751+ if err != nil || session == nil {
752+ return ErrSessionNotFound
753+ }
754+
707755 // Update session timestamp
708756 session .LastUpdate = time .Now ().UnixMilli ()
709757 p .sessionManager .UpdateSession (session )
@@ -755,12 +803,6 @@ func (p *Peer) handleCertificateRequest(ctx context.Context, message *AuthMessag
755803
756804// handleCertificateResponse processes a certificate response message
757805func (p * Peer ) handleCertificateResponse (ctx context.Context , message * AuthMessage , senderPublicKey * ec.PublicKey ) error {
758- // Validate the session exists and is authenticated
759- session , err := p .sessionManager .GetSession (senderPublicKey .ToDERHex ())
760- if err != nil || session == nil {
761- return ErrSessionNotFound
762- }
763-
764806 valid , err := utils .VerifyNonce (ctx , message .YourNonce , p .wallet , wallet.Counterparty {Type : wallet .CounterpartyTypeSelf })
765807 if err != nil {
766808 return fmt .Errorf ("failed to validate nonce: %w" , err )
@@ -769,6 +811,14 @@ func (p *Peer) handleCertificateResponse(ctx context.Context, message *AuthMessa
769811 return ErrInvalidNonce
770812 }
771813
814+ // Validate the session exists and is authenticated
815+ // Use YourNonce to look up the session, which uniquely identifies the correct session
816+ // even when multiple devices share the same identity key
817+ session , err := p .sessionManager .GetSession (message .YourNonce )
818+ if err != nil || session == nil {
819+ return ErrSessionNotFound
820+ }
821+
772822 // Update session timestamp
773823 session .LastUpdate = time .Now ().UnixMilli ()
774824 p .sessionManager .UpdateSession (session )
@@ -846,7 +896,14 @@ func (p *Peer) handleCertificateResponse(ctx context.Context, message *AuthMessa
846896
847897 // TODO: maybe it should by default (if no callback) check if there are all required certificates
848898 // Notify certificate listeners
899+ p .callbacksMu .RLock ()
900+ callbacks := make ([]OnCertificateReceivedCallback , 0 , len (p .onCertificateReceivedCallbacks ))
849901 for _ , callback := range p .onCertificateReceivedCallbacks {
902+ callbacks = append (callbacks , callback )
903+ }
904+ p .callbacksMu .RUnlock ()
905+
906+ for _ , callback := range callbacks {
850907 err := callback (ctx , senderPublicKey , message .Certificates )
851908 if err != nil {
852909 return fmt .Errorf ("certificate received callback error: %w" , err )
@@ -868,7 +925,9 @@ func (p *Peer) handleGeneralMessage(ctx context.Context, message *AuthMessage, s
868925 }
869926
870927 // Validate the session exists and is authenticated
871- session , err := p .sessionManager .GetSession (senderPublicKey .ToDERHex ())
928+ // Use YourNonce to look up the session, which uniquely identifies the correct session
929+ // even when multiple devices share the same identity key
930+ session , err := p .sessionManager .GetSession (message .YourNonce )
872931 if err != nil || session == nil {
873932 return ErrSessionNotFound
874933 }
@@ -918,7 +977,14 @@ func (p *Peer) handleGeneralMessage(ctx context.Context, message *AuthMessage, s
918977 }
919978
920979 // Notify general message listeners
980+ p .callbacksMu .RLock ()
981+ callbacks := make ([]OnGeneralMessageReceivedCallback , 0 , len (p .onGeneralMessageReceivedCallbacks ))
921982 for _ , callback := range p .onGeneralMessageReceivedCallbacks {
983+ callbacks = append (callbacks , callback )
984+ }
985+ p .callbacksMu .RUnlock ()
986+
987+ for _ , callback := range callbacks {
922988 err := callback (ctx , senderPublicKey , message .Payload )
923989 if err != nil {
924990 // Log callback error but continue
0 commit comments