Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions management/internals/shared/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,19 +307,21 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return mapError(ctx, err)
}

streamStartTime := time.Now().UTC()

err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort)
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
return err
}

updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
return err
}

Expand All @@ -336,7 +338,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S

s.syncSem.Add(-1)

return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, streamStartTime)
}

func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
Expand Down Expand Up @@ -404,7 +406,7 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
}

// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for {
select {
Expand All @@ -416,11 +418,11 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg

if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
Expand All @@ -429,32 +431,32 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return srv.Context().Err()
}
}
}

// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
key, err := s.secretsManager.GetWGKey()
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed processing update message")
}

encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
Expand Down Expand Up @@ -486,15 +488,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
return nil
}

func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()

s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
}

func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) {
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
Expand Down
16 changes: 14 additions & 2 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -1684,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return peer, netMap, postureChecks, dnsfwdPort, nil
}

func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
return nil
}

if peer.Status.LastSeen.After(streamStartTime) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
return nil
}

err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
Comment thread
crn4 marked this conversation as resolved.
Comment thread
crn4 marked this conversation as resolved.
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
Expand Down
2 changes: 1 addition & 1 deletion management/server/account/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ type Manager interface {
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
Expand Down
55 changes: 55 additions & 0 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,61 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}

func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")

key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()

_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
require.NoError(t, err, "unable to add peer")

t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")

peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err, "unable to get peer")
require.True(t, peer.Status.Connected, "peer should be connected")

streamStartTime := time.Now().UTC()

err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)

peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.False(t, peer.Status.Connected, "peer should be disconnected")
})

t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")

peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")

streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)

err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)

peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected,
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
})
}

func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
Expand Down
5 changes: 2 additions & 3 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,8 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}

func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
// TODO implement me
panic("implement me")
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
return nil
}

func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
Expand Down
Loading