diff --git a/conf/lex.go b/conf/lex.go index 013b8838663..7fd2bf50348 100644 --- a/conf/lex.go +++ b/conf/lex.go @@ -1,4 +1,4 @@ -// Copyright 2013-2018 The NATS Authors +// Copyright 2013-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -263,7 +263,8 @@ func lexTop(lx *lexer) stateFn { switch r { case topOptStart: - return lexSkip(lx, lexTop) + lx.push(lexTop) + return lexSkip(lx, lexBlockStart) case commentHashStart: lx.push(lexTop) return lexCommentStart @@ -318,6 +319,71 @@ func lexTopValueEnd(lx *lexer) stateFn { "comment or EOF, but got '%v' instead.", r) } +func lexBlockStart(lx *lexer) stateFn { + r := lx.next() + if unicode.IsSpace(r) { + return lexSkip(lx, lexBlockStart) + } + + switch r { + case topOptStart: + lx.push(lexBlockEnd) + return lexSkip(lx, lexBlockStart) + case commentHashStart: + lx.push(lexBlockEnd) + return lexCommentStart + case commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexBlockEnd) + return lexCommentStart + } + lx.backup() + fallthrough + case eof: + if lx.pos > lx.start { + return lx.errorf("Unexpected EOF.") + } + lx.emit(itemEOF) + return nil + } + + // At this point, the only valid item can be a key, so we back up + // and let the key lexer do the rest. + lx.backup() + lx.push(lexBlockEnd) + return lexKeyStart +} + +// lexBlockEnd is entered whenever a block-level value has been consumed. +// It skips whitespace, comments and value separators, and pops back to the enclosing state upon a "}". +func lexBlockEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case r == commentHashStart: + // a comment will read to a new line for us. + lx.push(lexBlockEnd) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexBlockEnd) + return lexCommentStart + } + lx.backup() + fallthrough + case isNL(r) || isWhitespace(r): + return lexBlockEnd + case r == optValTerm || r == topOptValTerm: + lx.ignore() + return lexBlockStart + case r == topOptTerm: + lx.ignore() + return lx.pop() + } + return lx.errorf("Expected a block-level value to end with a '}', but got '%v' instead.", r) +} + // lexKeyStart consumes a key name up until the first non-whitespace character. // lexKeyStart will ignore whitespace. It will also eat enclosing quotes.
func lexKeyStart(lx *lexer) stateFn { diff --git a/conf/parse_test.go b/conf/parse_test.go index 8cf1ea98f3b..53644b0b343 100644 --- a/conf/parse_test.go +++ b/conf/parse_test.go @@ -740,3 +740,80 @@ func TestJSONParseCompat(t *testing.T) { }) } } + +func TestBlocks(t *testing.T) { + for _, test := range []struct { + name string + input string + expected map[string]any + err string + linepos string + }{ + { + "inline block", + `{ listen: 0.0.0.0:4222 }`, + map[string]any{ + "listen": "0.0.0.0:4222", + }, + "", "", + }, + { + "newline block", + `{ + listen: 0.0.0.0:4222 + }`, + map[string]any{ + "listen": "0.0.0.0:4222", + }, + "", "", + }, + { + "newline block with trailing comment", + ` + { + listen: 0.0.0.0:4222 + } + # wibble + `, + map[string]any{ + "listen": "0.0.0.0:4222", + }, + "", "", + }, + { + "nested newline blocks with trailing comment", + ` + { + { + listen: 0.0.0.0:4222 // random comment + } + # wibble1 + } + # wibble2 + `, + map[string]any{ + "listen": "0.0.0.0:4222", + }, + "", "", + }, + } { + t.Run(test.name, func(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "nats.conf-") + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(f.Name(), []byte(test.input), 066); err != nil { + t.Error(err) + } + if m, err := ParseFile(f.Name()); err == nil { + if !reflect.DeepEqual(m, test.expected) { + t.Fatalf("Not Equal:\nReceived: '%+v'\nExpected: '%+v'\n", m, test.expected) + } + } else if !strings.Contains(err.Error(), test.err) || !strings.Contains(err.Error(), test.linepos) { + t.Errorf("expected invalid conf error, got: %v", err) + } else if err != nil { + t.Error(err) + } + }) + } +} diff --git a/go.mod b/go.mod index 4209d33f181..417445174ce 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,12 @@ go 1.20 require ( github.com/klauspost/compress v1.17.7 github.com/minio/highwayhash v1.0.2 - github.com/nats-io/jwt/v2 v2.5.6 + github.com/nats-io/jwt/v2 v2.5.7 github.com/nats-io/nats.go v1.34.1 github.com/nats-io/nkeys v0.4.7 github.com/nats-io/nuid v1.0.1 go.uber.org/automaxprocs v1.5.3 - golang.org/x/crypto v0.22.0 + golang.org/x/crypto v0.23.0 golang.org/x/sys v0.20.0 golang.org/x/time v0.5.0 ) diff --git a/go.sum b/go.sum index d634e831677..42b7aba0a03 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,8 @@ github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLA github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/nats-io/jwt/v2 v2.5.6 h1:Cp618+z4q042sWqHiSoIHFT08OZtAskui0hTmRfmGGQ= -github.com/nats-io/jwt/v2 v2.5.6/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= +github.com/nats-io/jwt/v2 v2.5.7 h1:j5lH1fUXCnJnY8SsQeB/a/z9Azgu2bYIDvtPVNdxe2c= +github.com/nats-io/jwt/v2 v2.5.7/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= github.com/nats-io/nats.go v1.34.1 h1:syWey5xaNHZgicYBemv0nohUPPmaLteiBEUT6Q5+F/4= github.com/nats-io/nats.go v1.34.1/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= @@ -16,8 +16,8 @@ github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4 github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= go.uber.org/automaxprocs v1.5.3/go.mod 
h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/server/auth.go b/server/auth.go index 97106343450..700e5741442 100644 --- a/server/auth.go +++ b/server/auth.go @@ -1464,7 +1464,8 @@ func validateAllowedConnectionTypes(m map[string]struct{}) error { switch ctuc { case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS, - jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS: + jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS, + jwt.ConnectionTypeInProcess: default: return fmt.Errorf("unknown connection type %q", ct) } diff --git a/server/client.go b/server/client.go index 66c8d1e7679..acaa36ac51b 100644 --- a/server/client.go +++ b/server/client.go @@ -279,6 +279,7 @@ type client struct { trace bool echo bool noIcb bool + iproc bool // In-Process connection, set at creation and immutable. tags jwt.TagList nameTag string @@ -2349,24 +2350,11 @@ func (c *client) generateClientInfoJSON(info Info) []byte { info.MaxPayload = c.mpay if c.isWebsocket() { info.ClientConnectURLs = info.WSConnectURLs - if c.srv != nil { // Otherwise lame duck info can panic - c.srv.websocket.mu.RLock() - info.TLSAvailable = c.srv.websocket.tls - if c.srv.websocket.tls && c.srv.websocket.server != nil { - if tc := c.srv.websocket.server.TLSConfig; tc != nil { - info.TLSRequired = !tc.InsecureSkipVerify - } - } - if c.srv.websocket.listener != nil { - laddr := c.srv.websocket.listener.Addr().String() - if h, p, err := net.SplitHostPort(laddr); err == nil { - if p, err := strconv.Atoi(p); err == nil { - info.Host = h - info.Port = p - } - } - } - c.srv.websocket.mu.RUnlock() + // Otherwise lame duck info can panic + if c.srv != nil { + ws := &c.srv.websocket + info.TLSAvailable, info.TLSRequired = ws.tls, ws.tls + info.Host, info.Port = ws.host, ws.port } } info.WSConnectURLs = nil @@ -5745,7 +5733,8 @@ func convertAllowedConnectionTypes(cts []string) (map[string]struct{}, error) { switch i { case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS, - jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS: + jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS, + jwt.ConnectionTypeInProcess: m[i] = struct{}{} default: unknown = append(unknown, i) @@ -5772,7 +5761,11 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { case CLIENT: switch c.clientType() { case NATS: - want = jwt.ConnectionTypeStandard + if c.iproc { + want = jwt.ConnectionTypeInProcess + } else { + want = jwt.ConnectionTypeStandard + } case WS: want = jwt.ConnectionTypeWebsocket case MQTT: diff --git a/server/client_test.go b/server/client_test.go index ed42b9568cc..a88fd5e2a0c 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -2962,3 +2962,96 @@ func TestRemoveHeaderIfPrefixPresent(t *testing.T) { t.Fatalf("Expected headers to be stripped, got %q", hdr) } } + +func TestInProcessAllowedConnectionType(t *testing.T) { + tmpl := 
` + listen: "127.0.0.1:-1" + accounts { + A { users: [{user: "test", password: "pwd", allowed_connection_types: ["%s"]}] } + } + write_deadline: "500ms" + ` + for _, test := range []struct { + name string + ct string + inProcessOnly bool + }{ + {"conf inprocess", jwt.ConnectionTypeInProcess, true}, + {"conf standard", jwt.ConnectionTypeStandard, false}, + } { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, test.ct))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + // Create standard connection + nc, err := nats.Connect(s.ClientURL(), nats.UserInfo("test", "pwd")) + if test.inProcessOnly && err == nil { + nc.Close() + t.Fatal("Expected standard connection to fail, it did not") + } + // Works if nc is nil (which it will be if only in-process connections are allowed) + nc.Close() + + // Create inProcess connection + nc, err = nats.Connect(_EMPTY_, nats.UserInfo("test", "pwd"), nats.InProcessServer(s)) + if !test.inProcessOnly && err == nil { + nc.Close() + t.Fatal("Expected in-process connection to fail, it did not") + } + // Works if nc is nil (which it will be if only standard connections are allowed) + nc.Close() + }) + } + for _, test := range []struct { + name string + ct string + inProcessOnly bool + }{ + {"jwt inprocess", jwt.ConnectionTypeInProcess, true}, + {"jwt standard", jwt.ConnectionTypeStandard, false}, + } { + t.Run(test.name, func(t *testing.T) { + skp, _ := nkeys.FromSeed(oSeed) + spub, _ := skp.PublicKey() + + o := defaultServerOptions + o.TrustedKeys = []string{spub} + o.WriteDeadline = 500 * time.Millisecond + s := RunServer(&o) + defer s.Shutdown() + + buildMemAccResolver(s) + + kp, _ := nkeys.CreateAccount() + aPub, _ := kp.PublicKey() + claim := jwt.NewAccountClaims(aPub) + aJwt, err := claim.Encode(oKp) + require_NoError(t, err) + + addAccountToMemResolver(s, aPub, aJwt) + + creds := createUserWithLimit(t, kp, time.Time{}, + func(j *jwt.UserPermissionLimits) { + j.AllowedConnectionTypes.Add(test.ct) + }) + // Create standard connection + nc, err := nats.Connect(s.ClientURL(), nats.UserCredentials(creds)) + if test.inProcessOnly && err == nil { + nc.Close() + t.Fatal("Expected standard connection to fail, it did not") + } + // Works if nc is nil (which it will be if only in-process connections are allowed) + nc.Close() + + // Create inProcess connection + nc, err = nats.Connect(_EMPTY_, nats.UserCredentials(creds), nats.InProcessServer(s)) + if !test.inProcessOnly && err == nil { + nc.Close() + t.Fatal("Expected in-process connection to fail, it did not") + } + // Works if nc is nil (which it will be if only standard connections are allowed) + nc.Close() + }) + } +} diff --git a/server/config_check_test.go b/server/config_check_test.go index a9ec00cf1ae..41144cd811c 100644 --- a/server/config_check_test.go +++ b/server/config_check_test.go @@ -1817,6 +1817,23 @@ func TestConfigCheck(t *testing.T) { errorLine: 9, errorPos: 9, }, + { + name: "invalid duration for remote leafnode first info timeout", + config: ` + leafnodes { + port: -1 + remotes [ + { + url: "nats://127.0.0.1:123" + first_info_timeout: abc + } + ] + } + `, + err: fmt.Errorf("error parsing first_info_timeout: time: invalid duration %q", "abc"), + errorLine: 7, + errorPos: 8, + }, { name: "show warnings on empty configs without values", config: ``, diff --git a/server/consumer.go b/server/consumer.go index fdaa0f58647..d75f6bb66a8 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -345,6 +345,7 @@ type consumer struct { rdq []uint64 rdqi avl.SequenceSet rdc map[uint64]uint64 + replies
map[uint64]string maxdc uint64 waiting *waitQueue cfg ConsumerConfig @@ -1154,6 +1155,7 @@ func (o *consumer) setLeader(isLeader bool) { o.mu.RLock() mset, closed := o.mset, o.closed movingToClustered := o.node != nil && o.pch == nil + movingToNonClustered := o.node == nil && o.pch != nil wasLeader := o.leader.Swap(isLeader) o.mu.RUnlock() @@ -1177,6 +1179,17 @@ func (o *consumer) setLeader(isLeader bool) { } } o.mu.Unlock() + } else if movingToNonClustered { + // We are moving from clustered to non-clustered now. + // Set pch to nil so if we scale back up we will recreate the loopAndForward from above. + o.mu.Lock() + pch := o.pch + o.pch = nil + select { + case pch <- struct{}{}: + default: + } + o.mu.Unlock() } return } @@ -1356,6 +1369,8 @@ func (o *consumer) setLeader(isLeader bool) { // If we were the leader make sure to drain queued up acks. if wasLeader { o.ackMsgs.drain() + // Also remove any pending replies since we should not be the one to respond at this point. + o.replies = nil } o.mu.Unlock() } @@ -1955,9 +1970,9 @@ func configsEqualSansDelivery(a, b ConsumerConfig) bool { // Helper to send a reply to an ack. func (o *consumer) sendAckReply(subj string) { - o.mu.Lock() - defer o.mu.Unlock() - o.sendAdvisory(subj, nil) + o.mu.RLock() + defer o.mu.RUnlock() + o.outq.sendMsg(subj, nil) } type jsAckMsg struct { @@ -2015,9 +2030,11 @@ func (o *consumer) processAck(subject, reply string, hdr int, rmsg []byte) { switch { case len(msg) == 0, bytes.Equal(msg, AckAck), bytes.Equal(msg, AckOK): - o.processAckMsg(sseq, dseq, dc, true) + o.processAckMsg(sseq, dseq, dc, reply, true) + // We handle replies for acks in updateAcks + skipAckReply = true case bytes.HasPrefix(msg, AckNext): - o.processAckMsg(sseq, dseq, dc, true) + o.processAckMsg(sseq, dseq, dc, _EMPTY_, true) o.processNextMsgRequest(reply, msg[len(AckNext):]) skipAckReply = true case bytes.HasPrefix(msg, AckNak): @@ -2029,7 +2046,9 @@ func (o *consumer) processAck(subject, reply string, hdr int, rmsg []byte) { if buf := msg[len(AckTerm):]; len(buf) > 0 { reason = string(bytes.TrimSpace(buf)) } - o.processTerm(sseq, dseq, dc, reason) + o.processTerm(sseq, dseq, dc, reason, reply) + // We handle replies for acks in updateAcks + skipAckReply = true } // Ack the ack if requested. @@ -2064,6 +2083,13 @@ func (o *consumer) updateSkipped(seq uint64) { } func (o *consumer) loopAndForwardProposals(qch chan struct{}) { + // On exit make sure we nil out pch. + defer func() { + o.mu.Lock() + o.pch = nil + o.mu.Unlock() + }() + o.mu.RLock() node, pch := o.node, o.pch o.mu.RUnlock() @@ -2074,7 +2100,7 @@ func (o *consumer) loopAndForwardProposals(qch chan struct{}) { forwardProposals := func() error { o.mu.Lock() - if o.node != node || node.State() != Leader { + if o.node == nil || o.node.State() != Leader { o.mu.Unlock() return errors.New("no longer leader") } @@ -2161,8 +2187,17 @@ func (o *consumer) updateDelivered(dseq, sseq, dc uint64, ts int64) { o.ldt = time.Now() } +// Used to remember a pending ack reply in a replicated consumer. +// Lock should be held. +func (o *consumer) addAckReply(sseq uint64, reply string) { + if o.replies == nil { + o.replies = make(map[uint64]string) + } + o.replies[sseq] = reply +} + // Lock should be held. -func (o *consumer) updateAcks(dseq, sseq uint64) { +func (o *consumer) updateAcks(dseq, sseq uint64, reply string) { if o.node != nil { // Inline for now, use variable compression. 
var b [2*binary.MaxVarintLen64 + 1]byte @@ -2171,8 +2206,15 @@ func (o *consumer) updateAcks(dseq, sseq uint64) { n += binary.PutUvarint(b[n:], dseq) n += binary.PutUvarint(b[n:], sseq) o.propose(b[:n]) + if reply != _EMPTY_ { + o.addAckReply(sseq, reply) + } } else if o.store != nil { o.store.UpdateAcks(dseq, sseq) + if reply != _EMPTY_ { + // Already locked so send direct. + o.outq.sendMsg(reply, nil) + } } // Update activity. o.lat = time.Now() @@ -2362,9 +2404,9 @@ func (o *consumer) processNak(sseq, dseq, dc uint64, nak []byte) { } // Process a TERM -func (o *consumer) processTerm(sseq, dseq, dc uint64, reason string) { +func (o *consumer) processTerm(sseq, dseq, dc uint64, reason, reply string) { // Treat like an ack to suppress redelivery. - o.processAckMsg(sseq, dseq, dc, false) + o.processAckMsg(sseq, dseq, dc, reply, false) o.mu.Lock() defer o.mu.Unlock() @@ -2467,6 +2509,7 @@ func (o *consumer) applyState(state *ConsumerState) { // This is on startup or leader change. We want to check pending // sooner in case there are inconsistencies etc. Pick between 500ms - 1.5s delay := 500*time.Millisecond + time.Duration(rand.Int63n(1000))*time.Millisecond + // If normal is lower than this just use that. if o.cfg.AckWait < delay { delay = o.ackWait(0) @@ -2692,7 +2735,7 @@ func (o *consumer) sampleAck(sseq, dseq, dc uint64) { o.sendAdvisory(o.ackEventT, j) } -func (o *consumer) processAckMsg(sseq, dseq, dc uint64, doSample bool) { +func (o *consumer) processAckMsg(sseq, dseq, dc uint64, reply string, doSample bool) { o.mu.Lock() if o.closed { o.mu.Unlock() @@ -2738,7 +2781,6 @@ func (o *consumer) processAckMsg(sseq, dseq, dc uint64, doSample bool) { o.adflr = o.dseq - 1 } } - // We do these regardless. delete(o.rdc, sseq) o.removeFromRedeliverQueue(sseq) case AckAll: @@ -2764,7 +2806,7 @@ func (o *consumer) processAckMsg(sseq, dseq, dc uint64, doSample bool) { } // Update underlying store. - o.updateAcks(dseq, sseq) + o.updateAcks(dseq, sseq, reply) clustered := o.node != nil @@ -3656,7 +3698,7 @@ func (o *consumer) checkAckFloor() { o.mu.RUnlock() // If it was pending for us, get rid of it. if isPending { - o.processTerm(seq, p.Sequence, rdc, ackTermLimitsReason) + o.processTerm(seq, p.Sequence, rdc, ackTermLimitsReason, _EMPTY_) } } } else if numPending > 0 { @@ -3681,7 +3723,7 @@ func (o *consumer) checkAckFloor() { for i := 0; i < len(toTerm); i += 3 { seq, dseq, rdc := toTerm[i], toTerm[i+1], toTerm[i+2] - o.processTerm(seq, dseq, rdc, ackTermLimitsReason) + o.processTerm(seq, dseq, rdc, ackTermLimitsReason, _EMPTY_) } } @@ -3728,6 +3770,7 @@ func (o *consumer) processInboundAcks(qch chan struct{}) { o.mu.RLock() s, mset := o.srv, o.mset hasInactiveThresh := o.cfg.InactiveThreshold > 0 + o.mu.RUnlock() if s == nil || mset == nil { @@ -3866,7 +3909,7 @@ func (o *consumer) loopAndGatherMsgs(qch chan struct{}) { o.mu.Lock() // consumer is closed when mset is set to nil. - if o.mset == nil { + if o.closed || o.mset == nil { o.mu.Unlock() return } @@ -4262,7 +4305,7 @@ func (o *consumer) deliverMsg(dsubj, ackReply string, pmsg *jsPubMsg, dc uint64, if o.node == nil || o.cfg.Direct { mset.ackq.push(seq) } else { - o.updateAcks(dseq, seq) + o.updateAcks(dseq, seq, _EMPTY_) } } } @@ -5247,7 +5290,7 @@ func (o *consumer) decStreamPending(sseq uint64, subj string) { if wasPending { // We could have the lock for the stream so do this in a go routine. // TODO(dlc) - We should do this with ipq vs naked go routines. 
- go o.processTerm(sseq, p.Sequence, rdc, ackTermUnackedLimitsReason) + go o.processTerm(sseq, p.Sequence, rdc, ackTermUnackedLimitsReason, _EMPTY_) } } diff --git a/server/filestore.go b/server/filestore.go index 976796910c1..a71237c61e8 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -3983,8 +3983,9 @@ func (mb *msgBlock) compact() { if !isDeleted(seq) { // Check for tombstones. if seq&tbit != 0 { - // If we are last mb we should consider to keep these unless the tombstone reflects a seq in this mb. - if mb == mb.fs.lmb && seq < fseq { + seq = seq &^ tbit + // If this entry is for a lower seq than ours then keep around. + if seq < fseq { nbuf = append(nbuf, buf[index:index+rl]...) } } else { @@ -4040,6 +4041,9 @@ func (mb *msgBlock) compact() { return } + // Make sure to sync + mb.needSync = true + // Capture the updated rbytes. mb.rbytes = uint64(len(nbuf)) @@ -6881,6 +6885,9 @@ func (fs *fileStore) Compact(seq uint64) (uint64, error) { if smb != fs.lmb { smb.dirtyCloseWithRemove(true) deleted++ + } else { + // Make sure to sync changes. + smb.needSync = true } // Update fs first here as well. fs.state.FirstSeq = atomic.LoadUint64(&smb.last.seq) + 1 diff --git a/server/filestore_test.go b/server/filestore_test.go index 547d839aed5..5b119971c6b 100644 --- a/server/filestore_test.go +++ b/server/filestore_test.go @@ -6796,6 +6796,53 @@ func TestFileStoreFSSExpireNumPending(t *testing.T) { fs.mu.RUnlock() } +// We want to ensure that recovery of deleted messages survives no index.db and compactions. +func TestFileStoreRecoverWithRemovesAndNoIndexDB(t *testing.T) { + sd := t.TempDir() + fs, err := newFileStore( + FileStoreConfig{StoreDir: sd, BlockSize: 250}, + StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage}) + require_NoError(t, err) + defer fs.Stop() + + msg := []byte("abc") + for i := 1; i <= 10; i++ { + fs.StoreMsg(fmt.Sprintf("foo.%d", i), nil, msg) + } + fs.RemoveMsg(1) + fs.RemoveMsg(2) + fs.RemoveMsg(8) + + var ss StreamState + fs.FastState(&ss) + require_Equal(t, ss.FirstSeq, 3) + require_Equal(t, ss.LastSeq, 10) + require_Equal(t, ss.Msgs, 7) + + // Compact last block. + fs.mu.RLock() + lmb := fs.lmb + fs.mu.RUnlock() + lmb.mu.Lock() + lmb.compact() + lmb.mu.Unlock() + // Stop but remove index.db + sfile := filepath.Join(sd, msgDir, streamStreamStateFile) + fs.Stop() + os.Remove(sfile) + + fs, err = newFileStore( + FileStoreConfig{StoreDir: sd}, + StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage}) + require_NoError(t, err) + defer fs.Stop() + + fs.FastState(&ss) + require_Equal(t, ss.FirstSeq, 3) + require_Equal(t, ss.LastSeq, 10) + require_Equal(t, ss.Msgs, 7) +} + /////////////////////////////////////////////////////////////////////////// // Benchmarks /////////////////////////////////////////////////////////////////////////// diff --git a/server/jetstream_cluster.go b/server/jetstream_cluster.go index 6b62ed0939f..614fb487d4f 100644 --- a/server/jetstream_cluster.go +++ b/server/jetstream_cluster.go @@ -3571,7 +3571,6 @@ func (js *jetStream) processClusterUpdateStream(acc *Account, osa, sa *streamAss var needsSetLeader bool if !alreadyRunning && numReplicas > 1 { if needsNode { - mset.setLeader(false) js.createRaftGroup(acc.GetName(), rg, storage, pprofLabels{ "type": "stream", "account": mset.accName(), @@ -3591,10 +3590,14 @@ func (js *jetStream) processClusterUpdateStream(acc *Account, osa, sa *streamAss } else if numReplicas == 1 && alreadyRunning { // We downgraded to R1. 
Make sure we cleanup the raft node and the stream monitor. mset.removeNode() - // Make sure we are leader now that we are R1. - needsSetLeader = true // In case we need to shutdown the cluster specific subs, etc. - mset.setLeader(false) + mset.mu.Lock() + // Stop responding to sync requests. + mset.stopClusterSubs() + // Clear catchup state + mset.clearAllCatchupPeers() + mset.mu.Unlock() + // Remove from meta layer. js.mu.Lock() rg.node = nil js.mu.Unlock() @@ -4783,9 +4786,9 @@ func (js *jetStream) monitorConsumer(o *consumer, ca *consumerAssignment) { o.checkStateForInterestStream() // Do a snapshot. doSnapshot(true) - // Synchronize followers to our state. Only send out if we have state. + // Synchronize followers to our state. Only send out if we have state and nothing pending. if n != nil { - if _, _, applied := n.Progress(); applied > 0 { + if _, _, applied := n.Progress(); applied > 0 && aq.len() == 0 { if snap, err := o.store.EncodedState(); err == nil { n.SendSnapshot(snap) } @@ -5008,6 +5011,13 @@ var errConsumerClosed = errors.New("consumer closed") func (o *consumer) processReplicatedAck(dseq, sseq uint64) error { o.mu.Lock() + // Update activity. + o.lat = time.Now() + + // Do actual ack update to store. + // Always do this to have it recorded. + o.store.UpdateAcks(dseq, sseq) + mset := o.mset if o.closed || mset == nil { o.mu.Unlock() @@ -5018,11 +5028,11 @@ func (o *consumer) processReplicatedAck(dseq, sseq uint64) error { return errStreamClosed } - // Update activity. - o.lat = time.Now() - - // Do actual ack update to store. - o.store.UpdateAcks(dseq, sseq) + // Check if we have a reply that was requested. + if reply := o.replies[sseq]; reply != _EMPTY_ { + o.outq.sendMsg(reply, nil) + delete(o.replies, sseq) + } if o.retention == LimitsPolicy { o.mu.Unlock() @@ -7654,7 +7664,7 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ s, js, jsa, st, r, tierName, outq, node := mset.srv, mset.js, mset.jsa, mset.cfg.Storage, mset.cfg.Replicas, mset.tier, mset.outq, mset.node maxMsgSize, lseq := int(mset.cfg.MaxMsgSize), mset.lseq interestPolicy, discard, maxMsgs, maxBytes := mset.cfg.Retention != LimitsPolicy, mset.cfg.Discard, mset.cfg.MaxMsgs, mset.cfg.MaxBytes - isLeader, isSealed := mset.isLeader(), mset.cfg.Sealed + isLeader, isSealed, compressOK := mset.isLeader(), mset.cfg.Sealed, mset.compressOK mset.mu.RUnlock() // This should not happen but possible now that we allow scale up, and scale down where this could trigger. @@ -7842,7 +7852,7 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ } } - esm := encodeStreamMsgAllowCompress(subject, reply, hdr, msg, mset.clseq, time.Now().UnixNano(), mset.compressOK) + esm := encodeStreamMsgAllowCompress(subject, reply, hdr, msg, mset.clseq, time.Now().UnixNano(), compressOK) // Do proposal. 
err := node.Propose(esm) if err == nil { diff --git a/server/jetstream_cluster_1_test.go b/server/jetstream_cluster_1_test.go index 767290c9c64..6b6ade74448 100644 --- a/server/jetstream_cluster_1_test.go +++ b/server/jetstream_cluster_1_test.go @@ -21,6 +21,7 @@ import ( "context" crand "crypto/rand" "encoding/json" + "errors" "fmt" "math/rand" "os" @@ -4061,8 +4062,10 @@ func TestJetStreamClusterScaleConsumer(t *testing.T) { checkFor(t, time.Second*30, time.Millisecond*250, func() error { if ci, err = js.ConsumerInfo("TEST", "DUR"); err != nil { return err + } else if ci.Cluster == nil { + return errors.New("no cluster info") } else if ci.Cluster.Leader == _EMPTY_ { - return fmt.Errorf("no leader") + return errors.New("no leader") } else if len(ci.Cluster.Replicas) != r-1 { return fmt.Errorf("not enough replica, got %d wanted %d", len(ci.Cluster.Replicas), r-1) } else { diff --git a/server/jetstream_cluster_2_test.go b/server/jetstream_cluster_2_test.go index 5927f6a0e08..099764e8704 100644 --- a/server/jetstream_cluster_2_test.go +++ b/server/jetstream_cluster_2_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2023 The NATS Authors +// Copyright 2020-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -2672,8 +2672,10 @@ func TestJetStreamClusterStreamCatchupNoState(t *testing.T) { // For both make sure we have no raft snapshots. snapDir := filepath.Join(lconfig.StoreDir, "$SYS", "_js_", gname, "snapshots") os.RemoveAll(snapDir) - snapDir = filepath.Join(config.StoreDir, "$SYS", "_js_", gname, "snapshots") - os.RemoveAll(snapDir) + // Remove all our raft state, we do not want to hold onto our term and index which + // results in a coin toss for who becomes the leader. + raftDir := filepath.Join(config.StoreDir, "$SYS", "_js_", gname) + os.RemoveAll(raftDir) // Now restart. 
c.restartAll() diff --git a/server/jetstream_cluster_4_test.go b/server/jetstream_cluster_4_test.go index d10fc24bcbf..94c605bda73 100644 --- a/server/jetstream_cluster_4_test.go +++ b/server/jetstream_cluster_4_test.go @@ -26,6 +26,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -1006,3 +1007,690 @@ func TestClusteredInterestConsumerFilterEdit(t *testing.T) { t.Fatalf("expected 1 message got %d", nfo.State.Msgs) } } + +func TestJetStreamClusterDoubleAckRedelivery(t *testing.T) { + conf := ` + listen: 127.0.0.1:-1 + server_name: %s + jetstream: { + store_dir: '%s', + } + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } + server_tags: ["test"] + system_account: sys + no_auth_user: js + accounts { + sys { users = [ { user: sys, pass: sys } ] } + js { + jetstream = enabled + users = [ { user: js, pass: js } ] + } + }` + c := createJetStreamClusterWithTemplate(t, conf, "R3F", 3) + defer c.shutdown() + for _, s := range c.servers { + s.optsMu.Lock() + s.opts.LameDuckDuration = 15 * time.Second + s.opts.LameDuckGracePeriod = -15 * time.Second + s.optsMu.Unlock() + } + s := c.randomNonLeader() + + nc, js := jsClientConnect(t, s) + defer nc.Close() + + sc, err := js.AddStream(&nats.StreamConfig{ + Name: "LIMITS", + Subjects: []string{"foo.>"}, + Replicas: 3, + Storage: nats.FileStorage, + }) + require_NoError(t, err) + + stepDown := func() { + _, err = nc.Request(fmt.Sprintf(JSApiStreamLeaderStepDownT, sc.Config.Name), nil, time.Second) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + producer := func(name string) { + wg.Add(1) + nc, js := jsClientConnect(t, s) + defer nc.Close() + defer wg.Done() + + i := 0 + payload := []byte(strings.Repeat("Z", 1024)) + for range time.NewTicker(1 * time.Millisecond).C { + select { + case <-ctx.Done(): + return + default: + } + msgID := nats.MsgId(fmt.Sprintf("%s:%d", name, i)) + js.PublishAsync("foo.bar", payload, msgID, nats.RetryAttempts(10)) + i++ + } + } + go producer("A") + go producer("B") + go producer("C") + + sub, err := js.PullSubscribe("foo.bar", "ABC", nats.AckWait(5*time.Second), nats.MaxAckPending(1000), nats.PullMaxWaiting(1000)) + if err != nil { + t.Fatal(err) + } + + type ackResult struct { + ack *nats.Msg + original *nats.Msg + redelivered *nats.Msg + } + received := make(map[string]int64) + acked := make(map[string]*ackResult) + errors := make(map[string]error) + extraRedeliveries := 0 + + wg.Add(1) + go func() { + nc, js = jsClientConnect(t, s) + defer nc.Close() + defer wg.Done() + + fetch := func(t *testing.T, batchSize int) { + msgs, err := sub.Fetch(batchSize, nats.MaxWait(500*time.Millisecond)) + if err != nil { + return + } + + for _, msg := range msgs { + meta, err := msg.Metadata() + if err != nil { + t.Error(err) + continue + } + + msgID := msg.Header.Get(nats.MsgIdHdr) + if meta.NumDelivered > 1 { + if err, ok := errors[msgID]; ok { + t.Logf("Redelivery after failed Ack Sync: %+v - %+v - error: %v", msg.Reply, msg.Header, err) + } else { + t.Logf("Redelivery: %+v - %+v", msg.Reply, msg.Header) + } + if resp, ok := acked[msgID]; ok { + t.Errorf("Redelivery after successful Ack Sync: msgID:%v - redelivered:%v - original:%+v - ack:%+v", + msgID, msg.Reply, resp.original.Reply, resp.ack) + resp.redelivered = msg + extraRedeliveries++ + } + } + received[msgID]++ + resp, err := nc.Request(msg.Reply, []byte("+ACK"), 500*time.Millisecond) + if err != nil { + errors[msgID] = err + } else { + acked[msgID] = &ackResult{resp, msg, nil} + } 
+ } + } + + for { + select { + case <-ctx.Done(): + return + default: + } + fetch(t, 1) + fetch(t, 50) + } + }() + + // Cause a couple of step downs before the restarts as well. + time.AfterFunc(5*time.Second, func() { stepDown() }) + time.AfterFunc(10*time.Second, func() { stepDown() }) + + // Let messages be produced, and then restart the servers. + <-time.After(15 * time.Second) + +NextServer: + for _, s := range c.servers { + s.lameDuckMode() + s.WaitForShutdown() + s = c.restartServer(s) + + hctx, hcancel := context.WithTimeout(ctx, 60*time.Second) + defer hcancel() + for range time.NewTicker(2 * time.Second).C { + select { + case <-hctx.Done(): + t.Logf("WRN: Timed out waiting for healthz from %s", s) + continue NextServer + default: + } + + status := s.healthz(nil) + if status.StatusCode == 200 { + continue NextServer + } + } + // Pause in-between server restarts. + time.Sleep(10 * time.Second) + } + + // Stop all producer and consumer goroutines to check results. + cancel() + select { + case <-ctx.Done(): + case <-time.After(10 * time.Second): + } + wg.Wait() + if extraRedeliveries > 0 { + t.Fatalf("Received %v redeliveries after a successful ack", extraRedeliveries) + } +} + +func TestJetStreamClusterBusyStreams(t *testing.T) { + t.Skip("Too long for CI at the moment") + type streamSetup struct { + config *nats.StreamConfig + consumers []*nats.ConsumerConfig + subjects []string + } + type job func(t *testing.T, nc *nats.Conn, js nats.JetStreamContext, c *cluster) + type testParams struct { + cluster string + streams []*streamSetup + producers int + consumers int + restartAny bool + restartWait time.Duration + ldmRestart bool + rolloutRestart bool + restarts int + checkHealthz bool + jobs []job + expect job + duration time.Duration + producerMsgs int + producerMsgSize int + } + test := func(t *testing.T, test *testParams) { + conf := ` + listen: 127.0.0.1:-1 + http: 127.0.0.1:-1 + server_name: %s + jetstream: { + domain: "cloud" + store_dir: '%s', + } + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } + server_tags: ["test"] + system_account: sys + + no_auth_user: js + accounts { + sys { users = [ { user: sys, pass: sys } ] } + + js { jetstream = enabled + users = [ { user: js, pass: js } ] + } + }` + c := createJetStreamClusterWithTemplate(t, conf, test.cluster, 3) + defer c.shutdown() + for _, s := range c.servers { + s.optsMu.Lock() + s.opts.LameDuckDuration = 15 * time.Second + s.opts.LameDuckGracePeriod = -15 * time.Second + s.optsMu.Unlock() + } + + nc, js := jsClientConnect(t, c.randomServer()) + defer nc.Close() + + var wg sync.WaitGroup + for _, stream := range test.streams { + stream := stream + wg.Add(1) + go func() { + defer wg.Done() + _, err := js.AddStream(stream.config) + require_NoError(t, err) + + for _, consumer := range stream.consumers { + _, err := js.AddConsumer(stream.config.Name, consumer) + require_NoError(t, err) + } + }() + } + wg.Wait() + + ctx, cancel := context.WithTimeout(context.Background(), test.duration) + defer cancel() + for _, stream := range test.streams { + payload := []byte(strings.Repeat("A", test.producerMsgSize)) + stream := stream + subjects := stream.subjects + + // Create publishers on different connections that send messages + // to all the consumer subjects.
+ var n atomic.Uint64 + for i := 0; i < test.producers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + nc, js := jsClientConnect(t, c.randomServer()) + defer nc.Close() + + for range time.NewTicker(1 * time.Millisecond).C { + select { + case <-ctx.Done(): + return + default: + } + + for _, subject := range subjects { + _, err := js.Publish(subject, payload, nats.AckWait(200*time.Millisecond)) + if err == nil { + if nn := n.Add(1); int(nn) >= test.producerMsgs { + return + } + } + } + } + }() + } + + // Create multiple parallel pull subscribers per consumer config. + for i := 0; i < test.consumers; i++ { + for _, consumer := range stream.consumers { + wg.Add(1) + + consumer := consumer + go func() { + defer wg.Done() + + for attempts := 0; attempts < 60; attempts++ { + _, err := js.ConsumerInfo(stream.config.Name, consumer.Name) + if err != nil { + t.Logf("WRN: Failed creating pull subscriber: %v - %v - %v - %v", + consumer.FilterSubject, stream.config.Name, consumer.Name, err) + } + } + sub, err := js.PullSubscribe(consumer.FilterSubject, "", nats.Bind(stream.config.Name, consumer.Name)) + if err != nil { + t.Logf("WRN: Failed creating pull subscriber: %v - %v - %v - %v", + consumer.FilterSubject, stream.config.Name, consumer.Name, err) + return + } + require_NoError(t, err) + + for range time.NewTicker(100 * time.Millisecond).C { + select { + case <-ctx.Done(): + return + default: + } + + msgs, err := sub.Fetch(1, nats.MaxWait(200*time.Millisecond)) + if err != nil { + continue + } + for _, msg := range msgs { + msg.Ack() + } + + msgs, err = sub.Fetch(100, nats.MaxWait(200*time.Millisecond)) + if err != nil { + continue + } + for _, msg := range msgs { + msg.Ack() + } + } + }() + } + } + } + + for _, job := range test.jobs { + go job(t, nc, js, c) + } + if test.restarts > 0 { + wg.Add(1) + time.AfterFunc(test.restartWait, func() { + defer wg.Done() + for i := 0; i < test.restarts; i++ { + switch { + case test.restartAny: + s := c.servers[rand.Intn(len(c.servers))] + if test.ldmRestart { + s.lameDuckMode() + } else { + s.Shutdown() + } + s.WaitForShutdown() + c.restartServer(s) + case test.rolloutRestart: + for _, s := range c.servers { + if test.ldmRestart { + s.lameDuckMode() + } else { + s.Shutdown() + } + s.WaitForShutdown() + s = c.restartServer(s) + + if test.checkHealthz { + hctx, hcancel := context.WithTimeout(ctx, 15*time.Second) + defer hcancel() + + Healthz: + for range time.NewTicker(2 * time.Second).C { + select { + case <-hctx.Done(): + break Healthz + default: + } + + status := s.healthz(nil) + if status.StatusCode == 200 { + break Healthz + } + } + } + } + } + c.waitOnClusterReady() + } + }) + } + test.expect(t, nc, js, c) + cancel() + wg.Wait() + } + stepDown := func(nc *nats.Conn, streamName string) { + nc.Request(fmt.Sprintf(JSApiStreamLeaderStepDownT, streamName), nil, time.Second) + } + getStreamDetails := func(t *testing.T, c *cluster, accountName, streamName string) *StreamDetail { + t.Helper() + srv := c.streamLeader(accountName, streamName) + jsz, err := srv.Jsz(&JSzOptions{Accounts: true, Streams: true, Consumer: true}) + require_NoError(t, err) + for _, acc := range jsz.AccountDetails { + if acc.Name == accountName { + for _, stream := range acc.Streams { + if stream.Name == streamName { + return &stream + } + } + } + } + t.Error("Could not find account details") + return nil + } + checkMsgsEqual := func(t *testing.T, c *cluster, accountName, streamName string) { + state := getStreamDetails(t, c, accountName, streamName).State + var msets []*stream + for _, 
s := range c.servers { + acc, err := s.LookupAccount(accountName) + require_NoError(t, err) + mset, err := acc.lookupStream(streamName) + require_NoError(t, err) + msets = append(msets, mset) + } + for seq := state.FirstSeq; seq <= state.LastSeq; seq++ { + var msgId string + var smv StoreMsg + for _, mset := range msets { + mset.mu.RLock() + sm, err := mset.store.LoadMsg(seq, &smv) + mset.mu.RUnlock() + require_NoError(t, err) + if msgId == _EMPTY_ { + msgId = string(sm.hdr) + } else if msgId != string(sm.hdr) { + t.Fatalf("MsgIds do not match for seq %d: %q vs %q", seq, msgId, sm.hdr) + } + } + } + } + checkConsumer := func(t *testing.T, c *cluster, accountName, streamName, consumerName string) { + t.Helper() + var leader string + for _, s := range c.servers { + jsz, err := s.Jsz(&JSzOptions{Accounts: true, Streams: true, Consumer: true}) + require_NoError(t, err) + for _, acc := range jsz.AccountDetails { + if acc.Name == accountName { + for _, stream := range acc.Streams { + if stream.Name == streamName { + for _, consumer := range stream.Consumer { + if leader == "" { + leader = consumer.Cluster.Leader + } else if leader != consumer.Cluster.Leader { + t.Errorf("There are two leaders for %s/%s: %s vs %s", + stream.Name, consumer.Name, leader, consumer.Cluster.Leader) + } + } + } + } + } + } + } + } + + t.Run("R1F/rescale/R3F/sources:10/limits", func(t *testing.T) { + testDuration := 3 * time.Minute + totalStreams := 10 + streams := make([]*streamSetup, totalStreams) + sources := make([]*nats.StreamSource, totalStreams) + for i := 0; i < totalStreams; i++ { + name := fmt.Sprintf("test:%d", i) + st := &streamSetup{ + config: &nats.StreamConfig{ + Name: name, + Subjects: []string{fmt.Sprintf("test.%d.*", i)}, + Replicas: 1, + Retention: nats.LimitsPolicy, + }, + } + st.subjects = append(st.subjects, fmt.Sprintf("test.%d.0", i)) + sources[i] = &nats.StreamSource{Name: name} + streams[i] = st + } + + // Create Source consumer. + sourceSetup := &streamSetup{ + config: &nats.StreamConfig{ + Name: "source-test", + Replicas: 1, + Retention: nats.LimitsPolicy, + Sources: sources, + }, + consumers: make([]*nats.ConsumerConfig, 0), + } + cc := &nats.ConsumerConfig{ + Name: "A", + Durable: "A", + FilterSubject: "test.>", + AckPolicy: nats.AckExplicitPolicy, + } + sourceSetup.consumers = append(sourceSetup.consumers, cc) + streams = append(streams, sourceSetup) + + scale := func(replicas int, wait time.Duration) job { + return func(t *testing.T, nc *nats.Conn, js nats.JetStreamContext, c *cluster) { + config := sourceSetup.config + time.AfterFunc(wait, func() { + config.Replicas = replicas + for i := 0; i < 10; i++ { + _, err := js.UpdateStream(config) + if err == nil { + return + } + time.Sleep(1 * time.Second) + } + }) + } + } + + expect := func(t *testing.T, nc *nats.Conn, js nats.JetStreamContext, c *cluster) { + // The source stream should not be stuck or be different from the other streams. + time.Sleep(testDuration + 1*time.Minute) + accName := "js" + streamName := "source-test" + + // Check a few times to see if there are no changes in the number of messages. 
+ var changed bool + var prevMsgs uint64 + for i := 0; i < 10; i++ { + sinfo, err := js.StreamInfo(streamName) + if err != nil { + t.Logf("Error: %v", err) + time.Sleep(2 * time.Second) + continue + } + prevMsgs = sinfo.State.Msgs + } + for i := 0; i < 10; i++ { + sinfo, err := js.StreamInfo(streamName) + if err != nil { + t.Logf("Error: %v", err) + time.Sleep(2 * time.Second) + continue + } + changed = prevMsgs != sinfo.State.Msgs + prevMsgs = sinfo.State.Msgs + time.Sleep(2 * time.Second) + } + if !changed { + // Doing a leader step down should not cause the messages to change. + stepDown(nc, streamName) + + for i := 0; i < 10; i++ { + sinfo, err := js.StreamInfo(streamName) + if err != nil { + t.Logf("Error: %v", err) + time.Sleep(2 * time.Second) + continue + } + changed = prevMsgs != sinfo.State.Msgs + prevMsgs = sinfo.State.Msgs + time.Sleep(2 * time.Second) + } + if changed { + t.Error("Stream msgs changed after the step down") + } + } + + ///////////////////////////////////////////////////////////////////////////////////////// + // // + // The number of messages sourced should match the count from all the other streams. // + // // + ///////////////////////////////////////////////////////////////////////////////////////// + var expectedMsgs uint64 + for i := 0; i < totalStreams; i++ { + name := fmt.Sprintf("test:%d", i) + sinfo, err := js.StreamInfo(name) + require_NoError(t, err) + expectedMsgs += sinfo.State.Msgs + } + sinfo, err := js.StreamInfo(streamName) + require_NoError(t, err) + + gotMsgs := sinfo.State.Msgs + if gotMsgs != expectedMsgs { + t.Errorf("stream with sources has %v messages, but total sourced messages should be %v", gotMsgs, expectedMsgs) + } + checkConsumer(t, c, accName, streamName, "A") + checkMsgsEqual(t, c, accName, streamName) + } + test(t, &testParams{ + cluster: t.Name(), + streams: streams, + producers: 10, + consumers: 10, + restarts: 1, + rolloutRestart: true, + ldmRestart: true, + checkHealthz: true, + // TODO(dlc) - If this overlaps with the scale jobs this test will fail. + // Leaders will be elected with partial state. 
+ restartWait: 65 * time.Second, + jobs: []job{ + scale(3, 15*time.Second), + scale(1, 30*time.Second), + scale(3, 60*time.Second), + }, + expect: expect, + duration: testDuration, + producerMsgSize: 1024, + producerMsgs: 100_000, + }) + }) + + t.Run("R3F/streams:30/limits", func(t *testing.T) { + testDuration := 3 * time.Minute + totalStreams := 30 + consumersPerStream := 5 + streams := make([]*streamSetup, totalStreams) + for i := 0; i < totalStreams; i++ { + name := fmt.Sprintf("test:%d", i) + st := &streamSetup{ + config: &nats.StreamConfig{ + Name: name, + Subjects: []string{fmt.Sprintf("test.%d.*", i)}, + Replicas: 3, + Retention: nats.LimitsPolicy, + }, + consumers: make([]*nats.ConsumerConfig, 0), + } + for j := 0; j < consumersPerStream; j++ { + subject := fmt.Sprintf("test.%d.%d", i, j) + name := fmt.Sprintf("A:%d:%d", i, j) + cc := &nats.ConsumerConfig{ + Name: name, + Durable: name, + FilterSubject: subject, + AckPolicy: nats.AckExplicitPolicy, + } + st.consumers = append(st.consumers, cc) + st.subjects = append(st.subjects, subject) + } + streams[i] = st + } + expect := func(t *testing.T, nc *nats.Conn, js nats.JetStreamContext, c *cluster) { + time.Sleep(testDuration + 1*time.Minute) + accName := "js" + for i := 0; i < totalStreams; i++ { + streamName := fmt.Sprintf("test:%d", i) + checkMsgsEqual(t, c, accName, streamName) + } + } + test(t, &testParams{ + cluster: t.Name(), + streams: streams, + producers: 10, + consumers: 10, + restarts: 1, + rolloutRestart: true, + ldmRestart: true, + checkHealthz: true, + restartWait: 45 * time.Second, + expect: expect, + duration: testDuration, + producerMsgSize: 1024, + producerMsgs: 100_000, + }) + }) +} diff --git a/server/leafnode.go b/server/leafnode.go index 8f3fe627e46..67a71590d6b 100644 --- a/server/leafnode.go +++ b/server/leafnode.go @@ -981,6 +981,7 @@ func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCf c.Noticef("Leafnode connection created%s %s", remoteSuffix, c.opts.Name) var tlsFirst bool + var infoTimeout time.Duration if remote != nil { solicited = true remote.Lock() @@ -990,6 +991,7 @@ func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCf c.leaf.isSpoke = true } tlsFirst = remote.TLSHandshakeFirst + infoTimeout = remote.FirstInfoTimeout remote.Unlock() c.acc = acc } else { @@ -1047,7 +1049,7 @@ func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCf } } // We need to wait for the info, but not for too long. - c.nc.SetReadDeadline(time.Now().Add(DEFAULT_LEAFNODE_INFO_WAIT)) + c.nc.SetReadDeadline(time.Now().Add(infoTimeout)) } // We will process the INFO from the readloop and finish by @@ -2818,6 +2820,7 @@ func (c *client) leafNodeSolicitWSConnection(opts *Options, rURL *url.URL, remot compress := remote.Websocket.Compression // By default the server will mask outbound frames, but it can be disabled with this option. noMasking := remote.Websocket.NoMasking + infoTimeout := remote.FirstInfoTimeout remote.RUnlock() // Will do the client-side TLS handshake if needed. 
tlsRequired, err := c.leafClientHandshakeIfNeeded(remote, opts) @@ -2870,6 +2873,7 @@ func (c *client) leafNodeSolicitWSConnection(opts *Options, rURL *url.URL, remot if noMasking { req.Header.Add(wsNoMaskingHeader, wsNoMaskingValue) } + c.nc.SetDeadline(time.Now().Add(infoTimeout)) if err := req.Write(c.nc); err != nil { return nil, WriteError, err } @@ -2877,7 +2881,6 @@ func (c *client) leafNodeSolicitWSConnection(opts *Options, rURL *url.URL, remot var resp *http.Response br := bufio.NewReaderSize(c.nc, MAX_CONTROL_LINE_SIZE) - c.nc.SetReadDeadline(time.Now().Add(DEFAULT_LEAFNODE_INFO_WAIT)) resp, err = http.ReadResponse(br, req) if err == nil && (resp.StatusCode != 101 || diff --git a/server/leafnode_test.go b/server/leafnode_test.go index aaabb988ef4..c46067ed129 100644 --- a/server/leafnode_test.go +++ b/server/leafnode_test.go @@ -7607,3 +7607,147 @@ func TestLeafNodeLoopDetectionOnActualLoop(t *testing.T) { t.Fatalf("Did not get any error regarding loop") } } + +func TestLeafNodeConnectionSucceedsEvenWithDelayedFirstINFO(t *testing.T) { + for _, test := range []struct { + name string + websocket bool + }{ + {"regular", false}, + {"websocket", true}, + } { + t.Run(test.name, func(t *testing.T) { + ob := DefaultOptions() + ob.ServerName = "HUB" + ob.LeafNode.Host = "127.0.0.1" + ob.LeafNode.Port = -1 + ob.LeafNode.AuthTimeout = 10 + if test.websocket { + ob.Websocket.Host = "127.0.0.1" + ob.Websocket.Port = -1 + ob.Websocket.HandshakeTimeout = 10 * time.Second + ob.Websocket.AuthTimeout = 10 + ob.Websocket.NoTLS = true + } + sb := RunServer(ob) + defer sb.Shutdown() + + var port int + var scheme string + if test.websocket { + port = ob.Websocket.Port + scheme = wsSchemePrefix + } else { + port = ob.LeafNode.Port + scheme = "nats" + } + + urlStr := fmt.Sprintf("%s://127.0.0.1:%d", scheme, port) + proxy := createNetProxy(1100*time.Millisecond, 1024*1024*1024, 1024*1024*1024, urlStr, true) + defer proxy.stop() + proxyURL := proxy.clientURL() + _, proxyPort, err := net.SplitHostPort(proxyURL[len(scheme)+3:]) + require_NoError(t, err) + + lnBURL, err := url.Parse(fmt.Sprintf("%s://127.0.0.1:%s", scheme, proxyPort)) + require_NoError(t, err) + + oa := DefaultOptions() + oa.ServerName = "SPOKE" + oa.Cluster.Name = "xyz" + remote := &RemoteLeafOpts{ + URLs: []*url.URL{lnBURL}, + FirstInfoTimeout: 3 * time.Second, + } + oa.LeafNode.Remotes = []*RemoteLeafOpts{remote} + sa := RunServer(oa) + defer sa.Shutdown() + + checkLeafNodeConnected(t, sa) + }) + } +} + +type captureLeafConnClosed struct { + DummyLogger + ch chan struct{} +} + +func (l *captureLeafConnClosed) Noticef(format string, v ...any) { + msg := fmt.Sprintf(format, v...) 
+ if strings.Contains(msg, "Leafnode connection closed: Read Error") { + select { + case l.ch <- struct{}{}: + default: + } + } +} + +func TestLeafNodeDetectsStaleConnectionIfNoInfo(t *testing.T) { + for _, test := range []struct { + name string + websocket bool + }{ + {"regular", false}, + {"websocket", true}, + } { + t.Run(test.name, func(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require_NoError(t, err) + defer l.Close() + + ch := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + c, err := l.Accept() + if err != nil { + return + } + defer c.Close() + <-ch + }() + + var scheme string + if test.websocket { + scheme = wsSchemePrefix + } else { + scheme = "nats" + } + urlStr := fmt.Sprintf("%s://%s", scheme, l.Addr()) + lnBURL, err := url.Parse(urlStr) + require_NoError(t, err) + + oa := DefaultOptions() + oa.ServerName = "SPOKE" + oa.Cluster.Name = "xyz" + remote := &RemoteLeafOpts{ + URLs: []*url.URL{lnBURL}, + FirstInfoTimeout: 250 * time.Millisecond, + } + oa.LeafNode.Remotes = []*RemoteLeafOpts{remote} + oa.DisableShortFirstPing = false + oa.NoLog = false + sa, err := NewServer(oa) + require_NoError(t, err) + defer sa.Shutdown() + + log := &captureLeafConnClosed{ch: make(chan struct{}, 1)} + sa.SetLogger(log, false, false) + sa.Start() + + select { + case <-log.ch: + // OK + case <-time.After(750 * time.Millisecond): + t.Fatalf("Connection was not closed") + } + + sa.Shutdown() + close(ch) + wg.Wait() + sa.WaitForShutdown() + }) + } +} diff --git a/server/opts.go b/server/opts.go index ac3988235fb..ad18f232f39 100644 --- a/server/opts.go +++ b/server/opts.go @@ -205,6 +205,11 @@ type RemoteLeafOpts struct { DenyImports []string `json:"-"` DenyExports []string `json:"-"` + // FirstInfoTimeout is the amount of time the server will wait for the + // initial INFO protocol from the remote server before closing the + // connection. + FirstInfoTimeout time.Duration `json:"-"` + // Compression options for this remote. Each remote could have a different // setting and also be different from the LeafNode options. Compression CompressionOpts `json:"-"` @@ -2581,6 +2586,8 @@ func parseRemoteLeafNodes(v any, errors *[]error, warnings *[]error) ([]*RemoteL *errors = append(*errors, err) continue } + case "first_info_timeout": + remote.FirstInfoTimeout = parseDuration(k, tk, v, errors, warnings) default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -5108,6 +5115,10 @@ func setBaselineOptions(opts *Options) { c.Mode = CompressionS2Auto } } + // Set default first info timeout value if not set. + if r.FirstInfoTimeout <= 0 { + r.FirstInfoTimeout = DEFAULT_LEAFNODE_INFO_WAIT + } } } diff --git a/server/raft.go b/server/raft.go index aff4ff8c9a7..975ebd068bf 100644 --- a/server/raft.go +++ b/server/raft.go @@ -470,6 +470,11 @@ func (s *Server) startRaftNode(accName string, cfg *RaftConfig, labels pprofLabe } } } + } else if n.pterm == 0 && n.pindex == 0 { + // We have recovered no state, either through our WAL or snapshots, + // so inherit the term from our tav.idx file and pindex from our last sequence. + n.pterm = n.term + n.pindex = state.LastSeq + } // Make sure to track ourselves.
@@ -3750,11 +3755,11 @@ func (n *raft) readTermVote() (term uint64, voted string, err error) { if err != nil { return 0, noVote, err } - if len(buf) < termVoteLen { - return 0, noVote, nil - } var le = binary.LittleEndian term = le.Uint64(buf[0:]) + if len(buf) < termVoteLen { + return term, noVote, nil + } voted = string(buf[8:]) return term, voted, nil } @@ -3820,6 +3825,8 @@ func (n *raft) writeTermVote() { // Stamp latest and write the term & vote file. n.wtv = b if err := writeTermVote(n.sd, n.wtv); err != nil && !n.isClosed() { + // Clear wtv since we failed. + n.wtv = nil n.setWriteErrLocked(err) n.warn("Error writing term and vote file for %q: %v", n.group, err) } diff --git a/server/server.go b/server/server.go index 89e9c7de23c..d10b26615cb 100644 --- a/server/server.go +++ b/server/server.go @@ -3048,7 +3048,16 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } now := time.Now() - c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now} + c := &client{ + srv: s, + nc: conn, + opts: defaultOpts, + mpay: maxPay, + msubs: maxSubs, + start: now, + last: now, + iproc: inProcess, + } c.registerWithAccount(s.globalAccount()) diff --git a/server/server_test.go b/server/server_test.go index 42b650ac69f..8226459c742 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,4 +1,4 @@ -// Copyright 2012-2020 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -2107,3 +2107,21 @@ func TestServerAuthBlockAndSysAccounts(t *testing.T) { _, err = nats.Connect(s.ClientURL()) require_Error(t, err, nats.ErrAuthorization, errors.New("nats: Authorization Violation")) } + +// https://github.com/nats-io/nats-server/issues/5396 +func TestServerConfigLastLineComments(t *testing.T) { + conf := createConfFile(t, []byte(` + { + "listen": "0.0.0.0:4222" + } + # wibble + `)) + + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + // This should work of course. + nc, err := nats.Connect(s.ClientURL()) + require_NoError(t, err) + defer nc.Close() +} diff --git a/server/stream.go b/server/stream.go index f74082c68d0..518fb6599db 100644 --- a/server/stream.go +++ b/server/stream.go @@ -841,9 +841,10 @@ func (mset *stream) setLeader(isLeader bool) error { // Make sure we are listening for sync requests. // TODO(dlc) - Original design was that all in sync members of the group would do DQ. mset.startClusterSubs() - // Setup subscriptions + + // Setup subscriptions if we were not already the leader. if err := mset.subscribeToStream(); err != nil { - if isLeader && mset.isClustered() { + if mset.isClustered() { // Stepdown since we have an error. mset.node.StepDown() } @@ -2791,6 +2792,11 @@ func (mset *stream) cancelSourceInfo(si *sourceInfo) { si.msgs.drain() si.msgs.unregister() } + // If we have a schedule setup go ahead and delete that. + if t := mset.sourceSetupSchedules[si.iname]; t != nil { + t.Stop() + delete(mset.sourceSetupSchedules, si.iname) + } } const sourceConsumerRetryThreshold = 2 * time.Second @@ -3107,7 +3113,7 @@ func (mset *stream) processAllSourceMsgs() { for _, im := range ims { if !mset.processInboundSourceMsg(im.si, im) { // If we are no longer leader bail. 
- if !mset.isLeader() { + if !mset.IsLeader() { cleanUp() return } @@ -3118,7 +3124,7 @@ func (mset *stream) processAllSourceMsgs() { msgs.recycle(&ims) case <-t.C: // If we are no longer leader bail. - if !mset.isLeader() { + if !mset.IsLeader() { cleanUp() return } @@ -3182,15 +3188,14 @@ func (mset *stream) handleFlowControl(m *inMsg) { // processInboundSourceMsg handles processing other stream messages bound for this stream. func (mset *stream) processInboundSourceMsg(si *sourceInfo, m *inMsg) bool { + mset.mu.Lock() // If we are no longer the leader cancel this subscriber. if !mset.isLeader() { - mset.mu.Lock() mset.cancelSourceConsumer(si.iname) mset.mu.Unlock() return false } - mset.mu.Lock() isControl := m.isControlMsg() // Ignore from old subscriptions. @@ -3449,9 +3454,11 @@ func (mset *stream) setStartingSequenceForSources(iNames map[string]struct{}) { } } -// lock should be held. // Resets the SourceInfo for all the sources +// lock should be held. func (mset *stream) resetSourceInfo() { + // Reset if needed. + mset.stopSourceConsumers() mset.sources = make(map[string]*sourceInfo) for _, ssi := range mset.cfg.Sources { @@ -3617,7 +3624,7 @@ func (mset *stream) subscribeToStream() error { mset.mirror.trs = trs // delay the actual mirror consumer creation for after a delay mset.scheduleSetupMirrorConsumerRetry() - } else if len(mset.cfg.Sources) > 0 { + } else if len(mset.cfg.Sources) > 0 && mset.sourcesConsumerSetup == nil { // Setup the initial source infos for the sources mset.resetSourceInfo() // Delay the actual source consumer(s) creation(s) for after a delay diff --git a/server/websocket.go b/server/websocket.go index e026674d9f7..ef6b5169af8 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -125,12 +125,17 @@ type srvWebsocket struct { server *http.Server listener net.Listener listenerErr error - tls bool allowedOrigins map[string]*allowedOrigin // host will be the key sameOrigin bool connectURLs []string connectURLsMap refCountedUrlSet - authOverride bool // indicate if there is auth override in websocket config + authOverride bool // indicate if there is auth override in websocket config + + // These are immutable and can be accessed without lock. + // This is the case when generating the client INFO. + tls bool // True if TLS is required (TLSConfig is specified). + host string // Host/IP the webserver is listening on (shortcut to opts.Websocket.Host). + port int // Port the webserver is listening on. This is after an ephemeral port may have been selected (shortcut to opts.Websocket.Port). } type allowedOrigin struct { @@ -1102,7 +1107,12 @@ func (s *Server) startWebsocketServer() { s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!") } - s.websocket.tls = proto == "wss" + // These 3 are immutable and will be accessed without lock by the client + // when generating/sending the INFO protocols. + s.websocket.tls = proto == wsSchemePrefixTLS + s.websocket.host, s.websocket.port = o.Host, o.Port + + // This will be updated when/if the cluster changes. 
@@ -3107,7 +3113,7 @@ func (mset *stream) processAllSourceMsgs() {
			for _, im := range ims {
				if !mset.processInboundSourceMsg(im.si, im) {
					// If we are no longer leader bail.
-					if !mset.isLeader() {
+					if !mset.IsLeader() {
						cleanUp()
						return
					}
@@ -3118,7 +3124,7 @@ func (mset *stream) processAllSourceMsgs() {
			msgs.recycle(&ims)
		case <-t.C:
			// If we are no longer leader bail.
-			if !mset.isLeader() {
+			if !mset.IsLeader() {
				cleanUp()
				return
			}
@@ -3182,15 +3188,14 @@ func (mset *stream) handleFlowControl(m *inMsg) {
 
 // processInboundSourceMsg handles processing other stream messages bound for this stream.
 func (mset *stream) processInboundSourceMsg(si *sourceInfo, m *inMsg) bool {
+	mset.mu.Lock()
 	// If we are no longer the leader cancel this subscriber.
 	if !mset.isLeader() {
-		mset.mu.Lock()
 		mset.cancelSourceConsumer(si.iname)
 		mset.mu.Unlock()
 		return false
 	}
 
-	mset.mu.Lock()
 	isControl := m.isControlMsg()
 
 	// Ignore from old subscriptions.
@@ -3449,9 +3454,11 @@ func (mset *stream) setStartingSequenceForSources(iNames map[string]struct{}) {
 	}
 }
 
-// lock should be held.
 // Resets the SourceInfo for all the sources
+// lock should be held.
 func (mset *stream) resetSourceInfo() {
+	// Reset if needed.
+	mset.stopSourceConsumers()
 	mset.sources = make(map[string]*sourceInfo)
 
 	for _, ssi := range mset.cfg.Sources {
@@ -3617,7 +3624,7 @@ func (mset *stream) subscribeToStream() error {
 		mset.mirror.trs = trs
 		// delay the actual mirror consumer creation for after a delay
 		mset.scheduleSetupMirrorConsumerRetry()
-	} else if len(mset.cfg.Sources) > 0 {
+	} else if len(mset.cfg.Sources) > 0 && mset.sourcesConsumerSetup == nil {
 		// Setup the initial source infos for the sources
 		mset.resetSourceInfo()
 		// Delay the actual source consumer(s) creation(s) for after a delay
diff --git a/server/websocket.go b/server/websocket.go
index e026674d9f7..ef6b5169af8 100644
--- a/server/websocket.go
+++ b/server/websocket.go
@@ -125,12 +125,17 @@ type srvWebsocket struct {
 	server         *http.Server
 	listener       net.Listener
 	listenerErr    error
-	tls            bool
 	allowedOrigins map[string]*allowedOrigin // host will be the key
 	sameOrigin     bool
 	connectURLs    []string
 	connectURLsMap refCountedUrlSet
-	authOverride   bool // indicate if there is auth override in websocket config
+	authOverride   bool // indicate if there is auth override in websocket config
+
+	// These are immutable and can be accessed without lock.
+	// This is the case when generating the client INFO.
+	tls  bool   // True if TLS is required (TLSConfig is specified).
+	host string // Host/IP the webserver is listening on (shortcut to opts.Websocket.Host).
+	port int    // Port the webserver is listening on. This is after an ephemeral port may have been selected (shortcut to opts.Websocket.Port).
 }
 
 type allowedOrigin struct {
@@ -1102,7 +1107,12 @@ func (s *Server) startWebsocketServer() {
 		s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!")
 	}
 
-	s.websocket.tls = proto == "wss"
+	// These 3 are immutable and will be accessed without lock by the client
+	// when generating/sending the INFO protocols.
+	s.websocket.tls = proto == wsSchemePrefixTLS
+	s.websocket.host, s.websocket.port = o.Host, o.Port
+
+	// This will be updated when/if the cluster changes.
 	s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port)
 	if err != nil {
 		s.Fatalf("Unable to get websocket connect URLs: %v", err)
@@ -1141,8 +1151,10 @@ func (s *Server) startWebsocketServer() {
 		ReadTimeout: o.HandshakeTimeout,
 		ErrorLog:    log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0),
 	}
+	s.websocket.mu.Lock()
 	s.websocket.server = hs
 	s.websocket.listener = hl
+	s.websocket.mu.Unlock()
 	go func() {
 		if err := hs.Serve(hl); err != http.ErrServerClosed {
 			s.Fatalf("websocket listener error: %v", err)
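The srvWebsocket change relies on a common Go pattern: fields that are written exactly once during startup, before any reader exists, can be read without the lock, while state that changes later (the server, listener, and connect URLs) stays guarded. A minimal sketch of that split, with illustrative names rather than the server's real types:

package main

import (
	"fmt"
	"sync"
)

type wsState struct {
	mu sync.Mutex

	// Written once at startup before any client can observe them,
	// so reads need no lock (mirrors tls/host/port above).
	tls  bool
	host string
	port int

	// Mutated after startup; guarded by mu (mirrors connectURLs).
	connectURLs []string
}

// infoLine reads only the immutable fields, so it takes no lock.
func (w *wsState) infoLine() string {
	return fmt.Sprintf("INFO tls=%v host=%s port=%d", w.tls, w.host, w.port)
}

// setConnectURLs touches mutable state and must hold the lock.
func (w *wsState) setConnectURLs(urls []string) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.connectURLs = urls
}

func main() {
	w := &wsState{tls: true, host: "127.0.0.1", port: 443}
	fmt.Println(w.infoLine())
	w.setConnectURLs([]string{"wss://127.0.0.1:443"})
	fmt.Println("urls updated")
}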
diff --git a/server/websocket_test.go b/server/websocket_test.go
index 4b87c8bb379..368a9fe8d6f 100644
--- a/server/websocket_test.go
+++ b/server/websocket_test.go
@@ -3852,59 +3852,135 @@ func TestWSJWTCookieUser(t *testing.T) {
 }
 
 func TestWSReloadTLSConfig(t *testing.T) {
+	tlsBlock := `
+		tls {
+			cert_file: '%s'
+			key_file: '%s'
+			ca_file: '../test/configs/certs/ca.pem'
+			verify: %v
+		}
+	`
 	template := `
 		listen: "127.0.0.1:-1"
 		websocket {
 			listen: "127.0.0.1:-1"
-			tls {
-				cert_file: '%s'
-				key_file: '%s'
-				ca_file: '../test/configs/certs/ca.pem'
-			}
+			%s
+			no_tls: %v
 		}
 	`
 	conf := createConfFile(t, []byte(fmt.Sprintf(template,
-		"../test/configs/certs/server-noip.pem",
-		"../test/configs/certs/server-key-noip.pem")))
+		fmt.Sprintf(tlsBlock,
+			"../test/configs/certs/server-noip.pem",
+			"../test/configs/certs/server-key-noip.pem",
+			false), false)))
 
 	s, o := RunServerWithConfig(conf)
 	defer s.Shutdown()
 
 	addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
-	wsc, err := net.Dial("tcp", addr)
-	if err != nil {
-		t.Fatalf("Error creating ws connection: %v", err)
+
+	check := func(tlsConfig *tls.Config, handshakeFail bool, errTxt string) {
+		t.Helper()
+
+		wsc, err := net.Dial("tcp", addr)
+		require_NoError(t, err)
+		defer wsc.Close()
+
+		wsc = tls.Client(wsc, tlsConfig)
+		err = wsc.(*tls.Conn).Handshake()
+		if handshakeFail {
+			require_True(t, err != nil)
+			require_Contains(t, err.Error(), errTxt)
+			return
+		}
+		require_NoError(t, err)
+
+		req := testWSCreateValidReq()
+		req.URL, _ = url.Parse(wsSchemePrefixTLS + "://" + addr)
+		err = req.Write(wsc)
+		require_NoError(t, err)
+
+		br := bufio.NewReader(wsc)
+		resp, err := http.ReadResponse(br, req)
+		if errTxt == _EMPTY_ {
+			require_NoError(t, err)
+		} else {
+			require_True(t, err != nil)
+			require_Contains(t, err.Error(), errTxt)
+			return
+		}
+		defer resp.Body.Close()
+		l := testWSReadFrame(t, br)
+		require_True(t, bytes.HasPrefix(l, []byte("INFO {")))
+		var info Info
+		err = json.Unmarshal(l[5:], &info)
+		require_NoError(t, err)
+		require_True(t, info.TLSAvailable)
+		require_True(t, info.TLSRequired)
+		require_Equal[string](t, info.Host, "127.0.0.1")
+		require_Equal[int](t, info.Port, o.Websocket.Port)
 	}
-	defer wsc.Close()
 
 	tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"}
 	tlsConfig, err := GenTLSConfig(tc)
-	if err != nil {
-		t.Fatalf("Error generating TLS config: %v", err)
-	}
+	require_NoError(t, err)
 	tlsConfig.ServerName = "127.0.0.1"
 	tlsConfig.RootCAs = tlsConfig.ClientCAs
 	tlsConfig.ClientCAs = nil
-	wsc = tls.Client(wsc, tlsConfig.Clone())
-	if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") {
-		t.Fatalf("Unexpected error: %v", err)
-	}
-	wsc.Close()
+	// Handshake should fail with error regarding SANs
+	check(tlsConfig.Clone(), true, "SAN")
+
+	// Replace certs with ones that allow IP.
 	reloadUpdateConfig(t, s, conf, fmt.Sprintf(template,
-		"../test/configs/certs/server-cert.pem",
-		"../test/configs/certs/server-key.pem"))
+		fmt.Sprintf(tlsBlock,
+			"../test/configs/certs/server-cert.pem",
+			"../test/configs/certs/server-key.pem",
+			false), false))
 
-	wsc, err = net.Dial("tcp", addr)
-	if err != nil {
-		t.Fatalf("Error creating ws connection: %v", err)
-	}
-	defer wsc.Close()
+	// Connection should succeed
+	check(tlsConfig.Clone(), false, _EMPTY_)
 
-	wsc = tls.Client(wsc, tlsConfig.Clone())
-	if err := wsc.(*tls.Conn).Handshake(); err != nil {
-		t.Fatalf("Error on TLS handshake: %v", err)
+	// Update config to require client cert.
+	reloadUpdateConfig(t, s, conf, fmt.Sprintf(template,
+		fmt.Sprintf(tlsBlock,
+			"../test/configs/certs/server-cert.pem",
+			"../test/configs/certs/server-key.pem",
+			true), false))
+
+	// Connection should fail saying that a TLS cert is required
+	check(tlsConfig.Clone(), false, "required")
+
+	// Add a client cert
+	tc = &TLSConfigOpts{
+		CertFile: "../test/configs/certs/client-cert.pem",
+		KeyFile:  "../test/configs/certs/client-key.pem",
 	}
+	tlsConfig, err = GenTLSConfig(tc)
+	require_NoError(t, err)
+	tlsConfig.InsecureSkipVerify = true
+
+	// Connection should succeed
+	check(tlsConfig.Clone(), false, _EMPTY_)
+
+	// Removing the tls{} block but with no_tls still false should fail
+	changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, _EMPTY_, false)))
+	err = s.Reload()
+	require_True(t, err != nil)
+	require_Contains(t, err.Error(), "TLS configuration")
+
+	// We should still be able to connect a TLS client
+	check(tlsConfig.Clone(), false, _EMPTY_)
+
+	// Now remove the tls{} block and set no_tls: true and that should fail
+	// since this is not supported.
+	changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, _EMPTY_, true)))
+	err = s.Reload()
+	require_True(t, err != nil)
+	require_Contains(t, err.Error(), "not supported")
+
+	// We should still be able to connect a TLS client
+	check(tlsConfig.Clone(), false, _EMPTY_)
 }
 
 type captureClientConnectedLogger struct {
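For readers unfamiliar with the client-cert step in the test above: GenTLSConfig with CertFile/KeyFile boils down to loading a key pair into a tls.Config. A rough standalone equivalent using crypto/tls directly (paths are the test's; this illustrates the shape of the result, not the server helper itself):

package main

import (
	"crypto/tls"
	"fmt"
)

// clientTLSConfig builds a client-side tls.Config carrying one client
// certificate, roughly what the test obtains from GenTLSConfig.
func clientTLSConfig(certFile, keyFile string) (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	return &tls.Config{
		Certificates:       []tls.Certificate{cert},
		InsecureSkipVerify: true, // the test also skips server verification here
	}, nil
}

func main() {
	cfg, err := clientTLSConfig(
		"../test/configs/certs/client-cert.pem",
		"../test/configs/certs/client-key.pem",
	)
	if err != nil {
		fmt.Println("load error:", err)
		return
	}
	fmt.Println("client certs loaded:", len(cfg.Certificates))
}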