diff --git a/server/accounts.go b/server/accounts.go index 7baabbb6646..55bd0774746 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -378,6 +378,21 @@ func (a *Account) getClients() []*client { return clients } +// Returns a slice of external (non-internal) clients stored in the account, or nil if none is present. +// Lock is held on entry. +func (a *Account) getExternalClientsLocked() []*client { + if len(a.clients) == 0 { + return nil + } + var clients []*client + for c := range a.clients { + if !isInternalClient(c.kind) { + clients = append(clients, c) + } + } + return clients +} + // Called to track a remote server and connections and leafnodes it // has for this account. func (a *Account) updateRemoteServer(m *AccountNumConns) []*client { @@ -398,7 +413,8 @@ func (a *Account) updateRemoteServer(m *AccountNumConns) []*client { // conservative and bit harsh here. Clients will reconnect if we over compensate. var clients []*client if mtce { - clients = a.getClientsLocked() + clients = a.getExternalClientsLocked() + // Sort in reverse chronological. slices.SortFunc(clients, func(i, j *client) int { return -i.start.Compare(j.start) }) over := (len(a.clients) - int(a.sysclients) + int(a.nrclients)) - int(a.mconns) diff --git a/server/jetstream_cluster_4_test.go b/server/jetstream_cluster_4_test.go index 105e0df52f0..516ae2459f5 100644 --- a/server/jetstream_cluster_4_test.go +++ b/server/jetstream_cluster_4_test.go @@ -36,6 +36,7 @@ import ( "testing" "time" + "github.com/nats-io/jwt/v2" "github.com/nats-io/nats.go" "github.com/nats-io/nuid" ) @@ -6904,3 +6905,176 @@ func TestJetStreamClusterMultiLeaderR3Config(t *testing.T) { }) } } + +func TestJetStreamClusterAccountMaxConnectionsReconnect(t *testing.T) { + conf := ` + listen: 127.0.0.1:-1 + http: -1 + server_name: %s + jetstream: { + store_dir: '%s', + } + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } + server_tags: ["test"] + system_account: sys + no_auth_user: js + accounts { + sys { users = [ { user: sys, pass: sys } ] } + js { + jetstream = enabled + users = [ { user: js, pass: js } ] + limits { + max_connections: 5 + } + } + } + ` + c := createJetStreamClusterWithTemplate(t, conf, "R3CONNECT", 3) + defer c.shutdown() + var conns []*nats.Conn + + disconnects := make([]chan error, 0) + for i := 1; i <= 5; i++ { + disconnectCh := make(chan error) + c, _ := jsClientConnect(t, c.servers[0], nats.UserInfo("js", "js"), nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { + disconnectCh <- err + })) + defer c.Close() + conns = append(conns, c) + disconnects = append(disconnects, disconnectCh) + // Small delay to ensure distinct start times. + time.Sleep(10 * time.Millisecond) + } + s := c.servers[0] + acc, err := s.lookupAccount("js") + require_NoError(t, err) + + acc.mu.RLock() + clients := acc.getClientsLocked() + numConnections := acc.NumConnections() + jsClients := acc.sysclients + totalClients := len(clients) + acc.mu.RUnlock() + + require_Equal(t, numConnections, 5) + require_Equal(t, jsClients, 0) + require_Equal(t, totalClients, 5) + + nc := conns[0] + js, _ := nc.JetStream() + for i := 0; i < 10; i++ { + _, err := js.AddStream(&nats.StreamConfig{ + Name: fmt.Sprintf("foo:%d", i), + Subjects: []string{fmt.Sprintf("foo.%d", i)}, + }) + require_NoError(t, err) + + _, err = js.Publish(fmt.Sprintf("foo.%d", i), []byte("hello"), nats.AckWait(5*time.Second)) + require_NoError(t, err) + } + + acc.mu.RLock() + clients = acc.getClientsLocked() + numConnections = acc.NumConnections() + jsClients = acc.sysclients + totalClients = len(clients) + acc.mu.RUnlock() + + require_Equal(t, numConnections, 5) + require_Equal(t, jsClients, 20) + require_Equal(t, totalClients, 25) + + checkFor(t, 30*time.Second, 200*time.Millisecond, func() error { + for i := 0; i < 10; i++ { + _, err := js.Publish(fmt.Sprintf("foo.%d", i), []byte("hello"), nats.AckWait(5*time.Second)) + if err != nil { + return err + } + } + return nil + }) + + // Force account update to trigger connection limit enforcement. + accClaims := jwt.NewAccountClaims(acc.Name) + accClaims.Limits.Conn = 1 + accClaims.Limits.MemoryStorage = -1 + accClaims.Limits.DiskStorage = -1 + accClaims.Limits.Streams = -1 + accClaims.Limits.Consumer = -1 + + // Update server, before this would have disconnected JS internal clients with + // 'JETSTREAM - maximum account active connections exceeded'. + s.UpdateAccountClaims(acc, accClaims) + + // Allow some time for enforcement. + time.Sleep(100 * time.Millisecond) + + acc, err = s.lookupAccount("js") + require_NoError(t, err) + + acc.mu.RLock() + clients = acc.getClientsLocked() + numConnections = acc.NumConnections() + jsClients = acc.sysclients + totalClients = len(clients) + acc.mu.RUnlock() + + // JETSTREAM internal clients should still linger after reducing connections. + require_Equal(t, numConnections, 5) + require_Equal(t, jsClients, 20) + require_Equal(t, totalClients, 20) + + // Wait for disconnections from the most recent client. + disconnectCh := disconnects[2] + select { + case <-disconnectCh: + case <-time.After(2 * time.Second): + t.Fatal("Expected newest connection to disconnect!") + } + + checkFor(t, 30*time.Second, 200*time.Millisecond, func() error { + activeConnections := 0 + for _, conn := range conns { + if !conn.IsClosed() { + activeConnections++ + } + } + if activeConnections < 5 { + return fmt.Errorf("Unexpected number of connections: %d", activeConnections) + } + return nil + }) + + // Force account update to trigger connection limit enforcement. + accClaims = jwt.NewAccountClaims(acc.Name) + accClaims.Limits.Conn = 10 + accClaims.Limits.MemoryStorage = -1 + accClaims.Limits.DiskStorage = -1 + accClaims.Limits.Streams = -1 + accClaims.Limits.Consumer = -1 + + // Update all servers then confirm that internal JS clients should work + // and clients have reconnected. + for _, s := range c.servers { + acc, err := s.lookupAccount("js") + require_NoError(t, err) + s.UpdateAccountClaims(acc, accClaims) + } + checkFor(t, 30*time.Second, 200*time.Millisecond, func() error { + for _, nc := range conns { + js, _ := nc.JetStream() + for i := 0; i < 10; i++ { + stream := fmt.Sprintf("foo.%d", i) + _, err := js.Publish(stream, []byte("hello"), nats.AckWait(5*time.Second)) + if err != nil { + return err + } + } + } + return nil + }) +}