diff --git a/server/client.go b/server/client.go
index b5bc04dce71..2addf8e372a 100644
--- a/server/client.go
+++ b/server/client.go
@@ -154,6 +154,7 @@ const (
 	compressionNegotiated // Marks if this connection has negotiated compression level with remote.
 	didTLSFirst           // Marks if this connection requested and was accepted doing the TLS handshake first (prior to INFO).
 	isSlowConsumer        // Marks connection as a slow consumer.
+	firstPong             // Marks if this is the first PONG received
 )
 
 // set the flag (would be equivalent to set the boolean to true)
@@ -2587,6 +2588,14 @@ func (c *client) processPong() {
 	c.rtt = computeRTT(c.rttStart)
 	srv := c.srv
 	reorderGWs := c.kind == GATEWAY && c.gw.outbound
+	firstPong := c.flags.setIfNotSet(firstPong)
+	var ri *routeInfo
+	// When receiving the first PONG, for a route with pooling, we may be
+	// instructed to start a new route.
+	if firstPong && c.kind == ROUTER && c.route != nil {
+		ri = c.route.startNewRoute
+		c.route.startNewRoute = nil
+	}
 	// If compression is currently active for a route/leaf connection, if the
 	// compression configuration is s2_auto, check if we should change
 	// the compression level.
@@ -2605,6 +2614,11 @@ func (c *client) processPong() {
 	if reorderGWs {
 		srv.gateway.orderOutboundConnections()
 	}
+	if ri != nil {
+		srv.startGoRoutine(func() {
+			srv.connectToRoute(ri.url, ri.rtype, true, ri.gossipMode, _EMPTY_)
+		})
+	}
 }
 
 // Select the s2 compression level based on the client's current RTT and the configured
diff --git a/server/route.go b/server/route.go
index 588418f5145..008e6ede51c 100644
--- a/server/route.go
+++ b/server/route.go
@@ -87,6 +87,17 @@ type route struct {
 	// Transient value used to set the Info.GossipMode when initiating
 	// an implicit route and sending to the remote.
 	gossipMode byte
+	// This will be set in case of pooling so that a route can trigger
+	// the creation of the next after receiving the first PONG, ensuring
+	// that authentication did not fail.
+	startNewRoute *routeInfo
+}
+
+// This contains the information required to create a new route.
+type routeInfo struct {
+	url        *url.URL
+	rtype      RouteType
+	gossipMode byte
 }
 
 // Do not change the values/order since they are exchanged between servers.
@@ -2380,20 +2391,18 @@ func (s *Server) addRoute(c *client, didSolicit, sendDelayedInfo bool, gossipMod
 
 		// Send the subscriptions interest.
 		s.sendSubsToRoute(c, idx, _EMPTY_)
-		// In pool mode, if we did not yet reach the cap, try to connect a new connection
+		// In pool mode, if we did not yet reach the cap, try to connect a new connection,
+		// but do so only after receiving the first PONG to our PING, which will ensure
+		// that we have proper authentication.
 		if pool && didSolicit && sz != effectivePoolSize {
-			s.startGoRoutine(func() {
-				select {
-				case <-time.After(time.Duration(rand.Intn(100)) * time.Millisecond):
-				case <-s.quitCh:
-					// Doing this here and not as a defer because connectToRoute is also
-					// calling s.grWG.Done() on exit, so we do this only if we don't
-					// invoke connectToRoute().
-					s.grWG.Done()
-					return
-				}
-				s.connectToRoute(url, rtype, true, gossipMode, _EMPTY_)
-			})
+			c.mu.Lock()
+			c.route.startNewRoute = &routeInfo{
+				url:        url,
+				rtype:      rtype,
+				gossipMode: gossipMode,
+			}
+			c.sendPing()
+			c.mu.Unlock()
 		}
 	}
 	s.mu.Unlock()
diff --git a/server/routes_test.go b/server/routes_test.go
index be117215113..add45703d86 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -328,10 +328,12 @@ func checkClusterFormed(t testing.TB, servers ...*Server) {
 			if a == b {
 				continue
 			}
-			if b.getOpts().Cluster.PoolSize < 0 {
+			bo := b.getOpts()
+			if ps := bo.Cluster.PoolSize; ps < 0 {
 				total++
 			} else {
-				total += nr
+				bps := ps + len(bo.Cluster.PinnedAccounts)
+				total += max(nr, bps)
 			}
 		}
 		enr = append(enr, total)
@@ -3740,6 +3742,62 @@ func TestRoutePoolWithOlderServerConnectAndReconnect(t *testing.T) {
 	checkRepeatConnect()
 }
 
+func TestRoutePoolBadAuthNoRunawayCreateRoute(t *testing.T) {
+	conf1 := createConfFile(t, []byte(`
+		server_name: "S1"
+		listen: "127.0.0.1:-1"
+		cluster {
+			name: "local"
+			listen: "127.0.0.1:-1"
+			pool_size: 4
+			authorization {
+				user: "correct"
+				password: "correct"
+				timeout: 5
+			}
+		}
+	`))
+	s1, o1 := RunServerWithConfig(conf1)
+	defer s1.Shutdown()
+
+	l := &captureErrorLogger{errCh: make(chan string, 100)}
+	s1.SetLogger(l, false, false)
+
+	tmpl := `
+		server_name: "S2"
+		listen: "127.0.0.1:-1"
+		cluster {
+			name: "local"
+			listen: "127.0.0.1:-1"
+			pool_size: 5
+			routes: ["nats://%s@127.0.0.1:%d"]
+		}
+	`
+	conf2 := createConfFile(t, fmt.Appendf(nil, tmpl, "incorrect:incorrect", o1.Cluster.Port))
+	s2, _ := RunServerWithConfig(conf2)
+	defer s2.Shutdown()
+
+	deadline := time.Now().Add(2 * time.Second)
+	var errors int
+	for time.Now().Before(deadline) {
+		select {
+		case <-l.errCh:
+			errors++
+		default:
+		}
+	}
+	// We should not get that many errors now. In the past, we would get more
+	// than 200 for the 2 sec wait.
+	if errors > 10 {
+		t.Fatalf("Unexpected number of errors: %v", errors)
+	}
+
+	// Reload with proper credentials.
+	reloadUpdateConfig(t, s2, conf2, fmt.Sprintf(tmpl, "correct:correct", o1.Cluster.Port))
+	// Ensure we can connect.
+	checkClusterFormed(t, s1, s2)
+}
+
 func TestRouteCompressionOptions(t *testing.T) {
 	org := testDefaultClusterCompression
 	testDefaultClusterCompression = _EMPTY_