diff --git a/server/certstore/certstore.go b/server/certstore/certstore.go index 3d7dfde60fd..110ea85a7d7 100644 --- a/server/certstore/certstore.go +++ b/server/certstore/certstore.go @@ -46,11 +46,13 @@ type MatchByType int const ( matchByIssuer MatchByType = iota + 1 matchBySubject + matchByThumbprint ) var MatchByMap = map[string]MatchByType{ - "issuer": matchByIssuer, - "subject": matchBySubject, + "issuer": matchByIssuer, + "subject": matchBySubject, + "thumbprint": matchByThumbprint, } var Usage = ` diff --git a/server/certstore/certstore_other.go b/server/certstore/certstore_other.go index a72df834a1a..459b8db64a3 100644 --- a/server/certstore/certstore_other.go +++ b/server/certstore/certstore_other.go @@ -1,4 +1,4 @@ -// Copyright 2022-2023 The NATS Authors +// Copyright 2022-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -26,8 +26,7 @@ var _ = MATCHBYEMPTY // otherKey implements crypto.Signer and crypto.Decrypter to satisfy linter on platforms that don't implement certstore type otherKey struct{} -func TLSConfig(certStore StoreType, certMatchBy MatchByType, certMatch string, config *tls.Config) error { - _, _, _, _ = certStore, certMatchBy, certMatch, config +func TLSConfig(_ StoreType, _ MatchByType, _ string, _ []string, _ bool, _ *tls.Config) error { return ErrOSNotCompatCertStore } diff --git a/server/certstore/certstore_windows.go b/server/certstore/certstore_windows.go index 19b9567be73..d47adb6eea3 100644 --- a/server/certstore/certstore_windows.go +++ b/server/certstore/certstore_windows.go @@ -1,4 +1,4 @@ -// Copyright 2022-2023 The NATS Authors +// Copyright 2022-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -41,26 +41,26 @@ import ( const ( // wincrypt.h constants - winAcquireCached = 0x1 // CRYPT_ACQUIRE_CACHE_FLAG - winAcquireSilent = 0x40 // CRYPT_ACQUIRE_SILENT_FLAG - winAcquireOnlyNCryptKey = 0x40000 // CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG - winEncodingX509ASN = 1 // X509_ASN_ENCODING - winEncodingPKCS7 = 65536 // PKCS_7_ASN_ENCODING - winCertStoreProvSystem = 10 // CERT_STORE_PROV_SYSTEM - winCertStoreCurrentUser = uint32(winCertStoreCurrentUserID << winCompareShift) // CERT_SYSTEM_STORE_CURRENT_USER - winCertStoreLocalMachine = uint32(winCertStoreLocalMachineID << winCompareShift) // CERT_SYSTEM_STORE_LOCAL_MACHINE - winCertStoreCurrentUserID = 1 // CERT_SYSTEM_STORE_CURRENT_USER_ID - winCertStoreLocalMachineID = 2 // CERT_SYSTEM_STORE_LOCAL_MACHINE_ID - winInfoIssuerFlag = 4 // CERT_INFO_ISSUER_FLAG - winInfoSubjectFlag = 7 // CERT_INFO_SUBJECT_FLAG - winCompareNameStrW = 8 // CERT_COMPARE_NAME_STR_A - winCompareShift = 16 // CERT_COMPARE_SHIFT + winAcquireCached = windows.CRYPT_ACQUIRE_CACHE_FLAG + winAcquireSilent = windows.CRYPT_ACQUIRE_SILENT_FLAG + winAcquireOnlyNCryptKey = windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG + winEncodingX509ASN = windows.X509_ASN_ENCODING + winEncodingPKCS7 = windows.PKCS_7_ASN_ENCODING + winCertStoreProvSystem = windows.CERT_STORE_PROV_SYSTEM + winCertStoreCurrentUser = windows.CERT_SYSTEM_STORE_CURRENT_USER + winCertStoreLocalMachine = windows.CERT_SYSTEM_STORE_LOCAL_MACHINE + winCertStoreReadOnly = windows.CERT_STORE_READONLY_FLAG + winInfoIssuerFlag = windows.CERT_INFO_ISSUER_FLAG + winInfoSubjectFlag = windows.CERT_INFO_SUBJECT_FLAG + winCompareNameStrW = windows.CERT_COMPARE_NAME_STR_W + winCompareShift = windows.CERT_COMPARE_SHIFT // Reference https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore - winFindIssuerStr = winCompareNameStrW< 0 { + execArgs = append(execArgs, args...) + } cmdImport := &exec.Cmd{ Path: psExec, @@ -41,7 +44,7 @@ func runPowershellScript(scriptFile string, args []string) error { return cmdImport.Run() } -func runConfiguredLeaf(t *testing.T, hubPort int, certStore string, matchBy string, match string, expectedLeafCount int) { +func runConfiguredLeaf(t *testing.T, hubPort int, certStore string, matchBy string, match string, caMatch string, expectedLeafCount int) { // Fire up the leaf u, err := url.Parse(fmt.Sprintf("nats://localhost:%d", hubPort)) @@ -59,18 +62,18 @@ func runConfiguredLeaf(t *testing.T, hubPort int, certStore string, matchBy stri cert_store: "%s" cert_match_by: "%s" cert_match: "%s" + ca_certs_match: %s - # Above should be equivalent to: + # Test settings that succeed should be equivalent to: # cert_file: "../test/configs/certs/tlsauth/client.pem" # key_file: "../test/configs/certs/tlsauth/client-key.pem" - - ca_file: "../test/configs/certs/tlsauth/ca.pem" + # ca_file: "../test/configs/certs/tlsauth/ca.pem" timeout: 5 } } ] } - `, u.String(), certStore, matchBy, match) + `, u.String(), certStore, matchBy, match, caMatch) leafConfig := createConfFile(t, []byte(configStr)) defer removeFile(t, leafConfig) @@ -90,7 +93,7 @@ func runConfiguredLeaf(t *testing.T, hubPort int, certStore string, matchBy stri func TestLeafTLSWindowsCertStore(t *testing.T) { // Client Identity (client.pem) - // Issuer: O = Synadia Communications Inc., OU = NATS.io, CN = localhost + // Issuer: O = NATS CA, OU = NATS.io, CN = localhost // Subject: OU = NATS.io, CN = example.com // Make sure windows cert store is reset to avoid conflict with other tests @@ -105,6 +108,11 @@ func TestLeafTLSWindowsCertStore(t *testing.T) { t.Fatalf("expected powershell provision to succeed: %s", err.Error()) } + err = runPowershellScript("../test/configs/certs/tlsauth/certstore/import-p12-ca.ps1", nil) + if err != nil { + t.Fatalf("expected powershell provision CA to succeed: %s", err.Error()) + } + // Fire up the hub hubConfig := createConfFile(t, []byte(` port: -1 @@ -140,26 +148,39 @@ func TestLeafTLSWindowsCertStore(t *testing.T) { certStore string certMatchBy string certMatch string + caCertsMatch string expectedLeafCount int }{ - {"WindowsCurrentUser", "Subject", "example.com", 1}, - {"WindowsCurrentUser", "Issuer", "Synadia Communications Inc.", 1}, - {"WindowsCurrentUser", "Issuer", "Frodo Baggins, Inc.", 0}, + // Test subject and issuer + {"WindowsCurrentUser", "Subject", "example.com", "\"NATS CA\"", 1}, + {"WindowsCurrentUser", "Issuer", "NATS CA", "\"NATS CA\"", 1}, + {"WindowsCurrentUser", "Issuer", "Frodo Baggins, Inc.", "\"NATS CA\"", 0}, + {"WindowsCurrentUser", "Thumbprint", "7e44f478114a2e29b98b00beb1b3687d8dc0e481", "\"NATS CA\"", 0}, + // Test CAs, NATS CA is valid, others are missing + {"WindowsCurrentUser", "Subject", "example.com", "[\"NATS CA\"]", 1}, + {"WindowsCurrentUser", "Subject", "example.com", "[\"GlobalSign\"]", 0}, + {"WindowsCurrentUser", "Subject", "example.com", "[\"Missing NATS Cert\"]", 0}, + {"WindowsCurrentUser", "Subject", "example.com", "[\"NATS CA\", \"Missing NATS Cert1\"]", 1}, + {"WindowsCurrentUser", "Subject", "example.com", "[\"Missing Cert2\",\"NATS CA\"]", 1}, + {"WindowsCurrentUser", "Subject", "example.com", "[\"Missing, Cert3\",\"Missing NATS Cert4\"]", 0}, } for _, tc := range testCases { - t.Run(fmt.Sprintf("%s by %s match %s", tc.certStore, tc.certMatchBy, tc.certMatch), func(t *testing.T) { + testName := fmt.Sprintf("%s by %s match %s", tc.certStore, tc.certMatchBy, tc.certMatch) + t.Run(fmt.Sprintf(testName, tc.certStore, tc.certMatchBy, tc.certMatch, tc.caCertsMatch), func(t *testing.T) { defer func() { if r := recover(); r != nil { if tc.expectedLeafCount != 0 { - t.Fatalf("did not expect panic") + t.Fatalf("did not expect panic: %s", testName) } else { if !strings.Contains(fmt.Sprintf("%v", r), "Error processing configuration file") { - t.Fatalf("did not expect unknown panic cause") + t.Fatalf("did not expect unknown panic: %s", testName) } } } }() - runConfiguredLeaf(t, hubOptions.LeafNode.Port, tc.certStore, tc.certMatchBy, tc.certMatch, tc.expectedLeafCount) + runConfiguredLeaf(t, hubOptions.LeafNode.Port, + tc.certStore, tc.certMatchBy, tc.certMatch, + tc.caCertsMatch, tc.expectedLeafCount) }) } } @@ -169,7 +190,7 @@ func TestLeafTLSWindowsCertStore(t *testing.T) { func TestServerTLSWindowsCertStore(t *testing.T) { // Server Identity (server.pem) - // Issuer: O = Synadia Communications Inc., OU = NATS.io, CN = localhost + // Issuer: O = NATS CA, OU = NATS.io, CN = localhost // Subject: OU = NATS.io Operators, CN = localhost // Make sure windows cert store is reset to avoid conflict with other tests @@ -184,6 +205,11 @@ func TestServerTLSWindowsCertStore(t *testing.T) { t.Fatalf("expected powershell provision to succeed: %s", err.Error()) } + err = runPowershellScript("../test/configs/certs/tlsauth/certstore/import-p12-ca.ps1", nil) + if err != nil { + t.Fatalf("expected powershell provision CA to succeed: %s", err.Error()) + } + // Fire up the server srvConfig := createConfFile(t, []byte(` listen: "localhost:-1" @@ -191,6 +217,7 @@ func TestServerTLSWindowsCertStore(t *testing.T) { cert_store: "WindowsCurrentUser" cert_match_by: "Subject" cert_match: "NATS.io Operators" + ca_certs_match: ["NATS CA"] timeout: 5 } `)) @@ -228,3 +255,55 @@ func TestServerTLSWindowsCertStore(t *testing.T) { }) } } + +// TestServerIgnoreExpiredCerts tests if the server skips expired certificates in configuration, and finds non-expired ones +func TestServerIgnoreExpiredCerts(t *testing.T) { + + // Server Identities: expired.pem; not-expired.pem + // Issuer: OU = NATS.io, CN = localhost + // Subject: OU = NATS.io Operators, CN = localhost + + testCases := []struct { + certFile string + expect bool + }{ + {"expired.p12", false}, + {"not-expired.p12", true}, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("Server certificate: %s", tc.certFile), func(t *testing.T) { + // Make sure windows cert store is reset to avoid conflict with other tests + err := runPowershellScript("../test/configs/certs/tlsauth/certstore/delete-cert-from-store.ps1", nil) + if err != nil { + t.Fatalf("expected powershell cert delete to succeed: %s", err.Error()) + } + + // Provision Windows cert store with server cert and secret + err = runPowershellScript("../test/configs/certs/tlsauth/certstore/import-p12-server.ps1", []string{tc.certFile}) + if err != nil { + t.Fatalf("expected powershell provision to succeed: %s", err.Error()) + } + // Fire up the server + srvConfig := createConfFile(t, []byte(` + listen: "localhost:-1" + tls { + cert_store: "WindowsCurrentUser" + cert_match_by: "Subject" + cert_match: "NATS.io Operators" + cert_match_skip_invalid: true + timeout: 5 + } + `)) + defer removeFile(t, srvConfig) + cfg, _ := ProcessConfigFile(srvConfig) + if (cfg != nil) == tc.expect { + return + } + if tc.expect == false { + t.Fatalf("expected server start to fail with expired certificate") + } else { + t.Fatalf("expected server to start with non expired certificate") + } + }) + } +} diff --git a/server/client.go b/server/client.go index f9ddbebc2d9..584c90aedca 100644 --- a/server/client.go +++ b/server/client.go @@ -4608,12 +4608,18 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, ql := _ql[:0] for i := 0; i < len(qsubs); i++ { sub = qsubs[i] - if sub.client.kind == LEAF || sub.client.kind == ROUTER { - // If we have assigned an rsub already, replace if the destination is a LEAF - // since we want to favor that compared to a ROUTER. We could make sure that - // we override only if previous was a ROUTE and not a LEAF, but we don't have to. - if rsub == nil || sub.client.kind == LEAF { + if dst := sub.client.kind; dst == LEAF || dst == ROUTER { + // If we have assigned an ROUTER rsub already, replace if + // the destination is a LEAF since we want to favor that. + if rsub == nil || (rsub.client.kind == ROUTER && dst == LEAF) { rsub = sub + } else if dst == LEAF { + // We already have a LEAF and this is another one. + // Flip a coin to see if we swap it or not. + // See https://github.com/nats-io/nats-server/issues/6040 + if fastrand.Uint32()%2 == 1 { + rsub = sub + } } } else { ql = append(ql, sub) diff --git a/server/consumer.go b/server/consumer.go index 849fb1c5362..a336c544dcf 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -3011,6 +3011,14 @@ func (o *consumer) needAck(sseq uint64, subj string) bool { return needAck } +// Used in nextReqFromMsg, since the json.Unmarshal causes the request +// struct to escape to the heap always. This should reduce GC pressure. +var jsGetNextPool = sync.Pool{ + New: func() any { + return &JSApiConsumerGetNextRequest{} + }, +} + // Helper for the next message requests. func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time.Time, error) { req := bytes.TrimSpace(msg) @@ -3020,7 +3028,11 @@ func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time. return time.Time{}, 1, 0, false, 0, time.Time{}, nil case req[0] == '{': - var cr JSApiConsumerGetNextRequest + cr := jsGetNextPool.Get().(*JSApiConsumerGetNextRequest) + defer func() { + *cr = JSApiConsumerGetNextRequest{} + jsGetNextPool.Put(cr) + }() if err := json.Unmarshal(req, &cr); err != nil { return time.Time{}, -1, 0, false, 0, time.Time{}, err } @@ -3420,6 +3432,7 @@ func (o *consumer) processNextMsgRequest(reply string, msg []byte) { if err := o.waiting.add(wr); err != nil { sendErr(409, "Exceeded MaxWaiting") + wr.recycle() return } o.signalNewMessages() diff --git a/server/filestore.go b/server/filestore.go index ec66ad28f2c..2208cd7859b 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -1739,6 +1739,7 @@ func (fs *fileStore) recoverFullState() (rerr error) { var matched bool mb := fs.lmb if mb == nil || mb.index != blkIndex { + os.Remove(fn) fs.warn("Stream state block does not exist or index mismatch") return errCorruptState } @@ -1777,6 +1778,14 @@ func (fs *fileStore) recoverFullState() (rerr error) { } } + // We check first and last seq and number of msgs and bytes. If there is a difference, + // return and error so we rebuild from the message block state on disk. + if !trackingStatesEqual(&fs.state, &mstate) { + os.Remove(fn) + fs.warn("Stream state encountered internal inconsistency on recover") + return errCorruptState + } + return nil } @@ -8037,7 +8046,7 @@ func (fs *fileStore) _writeFullState(force bool) error { // Snapshot prior dirty count. priorDirty := fs.dirty - statesEqual := trackingStatesEqual(&fs.state, &mstate) || len(fs.blks) > 0 + statesEqual := trackingStatesEqual(&fs.state, &mstate) // Release lock. fs.mu.Unlock() diff --git a/server/filestore_test.go b/server/filestore_test.go index 1f7a4b403e0..4f409d2e5af 100644 --- a/server/filestore_test.go +++ b/server/filestore_test.go @@ -22,6 +22,7 @@ import ( "crypto/hmac" crand "crypto/rand" "crypto/sha256" + "encoding/binary" "encoding/hex" "encoding/json" "errors" @@ -8076,3 +8077,70 @@ func Benchmark_FileStoreCreateConsumerStores(b *testing.B) { }) } } + +func TestFileStoreWriteFullStateDetectCorruptState(t *testing.T) { + fs, err := newFileStore( + FileStoreConfig{StoreDir: t.TempDir()}, + StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage}) + require_NoError(t, err) + defer fs.Stop() + + msg := []byte("abc") + for i := 1; i <= 10; i++ { + _, _, err = fs.StoreMsg(fmt.Sprintf("foo.%d", i), nil, msg) + require_NoError(t, err) + } + + // Simulate a change in a message block not being reflected in the fs. + mb := fs.selectMsgBlock(2) + mb.mu.Lock() + mb.msgs-- + mb.mu.Unlock() + + var ss StreamState + fs.FastState(&ss) + require_Equal(t, ss.FirstSeq, 1) + require_Equal(t, ss.LastSeq, 10) + require_Equal(t, ss.Msgs, 10) + + // Make sure we detect the corrupt state and rebuild. + err = fs.writeFullState() + require_Error(t, err, errCorruptState) + + fs.FastState(&ss) + require_Equal(t, ss.FirstSeq, 1) + require_Equal(t, ss.LastSeq, 10) + require_Equal(t, ss.Msgs, 9) +} + +func TestFileStoreRecoverFullStateDetectCorruptState(t *testing.T) { + fs, err := newFileStore( + FileStoreConfig{StoreDir: t.TempDir()}, + StreamConfig{Name: "zzz", Subjects: []string{"foo.*"}, Storage: FileStorage}) + require_NoError(t, err) + defer fs.Stop() + + msg := []byte("abc") + for i := 1; i <= 10; i++ { + _, _, err = fs.StoreMsg(fmt.Sprintf("foo.%d", i), nil, msg) + require_NoError(t, err) + } + + err = fs.writeFullState() + require_NoError(t, err) + + sfile := filepath.Join(fs.fcfg.StoreDir, msgDir, streamStreamStateFile) + buf, err := os.ReadFile(sfile) + require_NoError(t, err) + // Update to an incorrect message count. + binary.PutUvarint(buf[2:], 0) + // Just append a corrected checksum to the end to make it pass the checks. + fs.hh.Reset() + fs.hh.Write(buf) + buf = fs.hh.Sum(buf) + err = os.WriteFile(sfile, buf, defaultFilePerms) + require_NoError(t, err) + + err = fs.recoverFullState() + require_Error(t, err, errCorruptState) +} diff --git a/server/jetstream_cluster.go b/server/jetstream_cluster.go index 9d7fc0550da..ca6763b9cbf 100644 --- a/server/jetstream_cluster.go +++ b/server/jetstream_cluster.go @@ -7655,54 +7655,46 @@ const compressThreshold = 8192 // 8k // If allowed and contents over the threshold we will compress. func encodeStreamMsgAllowCompress(subject, reply string, hdr, msg []byte, lseq uint64, ts int64, compressOK bool) []byte { - shouldCompress := compressOK && len(subject)+len(reply)+len(hdr)+len(msg) > compressThreshold - - elen := 1 + 8 + 8 + len(subject) + len(reply) + len(hdr) + len(msg) + // Clip the subject, reply, header and msgs down. Operate on + // uint64 lengths to avoid overflowing. + slen := min(uint64(len(subject)), math.MaxUint16) + rlen := min(uint64(len(reply)), math.MaxUint16) + hlen := min(uint64(len(hdr)), math.MaxUint16) + mlen := min(uint64(len(msg)), math.MaxUint32) + total := slen + rlen + hlen + mlen + + shouldCompress := compressOK && total > compressThreshold + elen := int(1 + 8 + 8 + total) elen += (2 + 2 + 2 + 4) // Encoded lengths, 4bytes - // TODO(dlc) - check sizes of subject, reply and hdr, make sure uint16 ok. - buf := make([]byte, elen) + + buf := make([]byte, 1, elen) buf[0] = byte(streamMsgOp) + var le = binary.LittleEndian - wi := 1 - le.PutUint64(buf[wi:], lseq) - wi += 8 - le.PutUint64(buf[wi:], uint64(ts)) - wi += 8 - le.PutUint16(buf[wi:], uint16(len(subject))) - wi += 2 - copy(buf[wi:], subject) - wi += len(subject) - le.PutUint16(buf[wi:], uint16(len(reply))) - wi += 2 - copy(buf[wi:], reply) - wi += len(reply) - le.PutUint16(buf[wi:], uint16(len(hdr))) - wi += 2 - if len(hdr) > 0 { - copy(buf[wi:], hdr) - wi += len(hdr) - } - le.PutUint32(buf[wi:], uint32(len(msg))) - wi += 4 - if len(msg) > 0 { - copy(buf[wi:], msg) - wi += len(msg) - } + buf = le.AppendUint64(buf, lseq) + buf = le.AppendUint64(buf, uint64(ts)) + buf = le.AppendUint16(buf, uint16(slen)) + buf = append(buf, subject[:slen]...) + buf = le.AppendUint16(buf, uint16(rlen)) + buf = append(buf, reply[:rlen]...) + buf = le.AppendUint16(buf, uint16(hlen)) + buf = append(buf, hdr[:hlen]...) + buf = le.AppendUint32(buf, uint32(mlen)) + buf = append(buf, msg[:mlen]...) // Check if we should compress. if shouldCompress { nbuf := make([]byte, s2.MaxEncodedLen(elen)) nbuf[0] = byte(compressedStreamMsgOp) - ebuf := s2.Encode(nbuf[1:], buf[1:wi]) - // Only pay cost of decode the other side if we compressed. + ebuf := s2.Encode(nbuf[1:], buf[1:]) + // Only pay the cost of decode on the other side if we compressed. // S2 will allow us to try without major penalty for non-compressable data. - if len(ebuf) < wi { - nbuf = nbuf[:len(ebuf)+1] - buf, wi = nbuf, len(nbuf) + if len(ebuf) < len(buf) { + buf = nbuf[:len(ebuf)+1] } } - return buf[:wi] + return buf } // Determine if all peers in our set support the binary snapshot. @@ -8290,7 +8282,16 @@ RETRY: releaseSyncOutSem() if n.GroupLeader() == _EMPTY_ { - return fmt.Errorf("%w for stream '%s > %s'", errCatchupAbortedNoLeader, mset.account(), mset.name()) + // Prevent us from spinning if we've installed a snapshot from a leader but there's no leader online. + // We wait a bit to check if a leader has come online in the meantime, if so we can continue. + var canContinue bool + if numRetries == 0 { + time.Sleep(startInterval) + canContinue = n.GroupLeader() != _EMPTY_ + } + if !canContinue { + return fmt.Errorf("%w for stream '%s > %s'", errCatchupAbortedNoLeader, mset.account(), mset.name()) + } } // If we have a sub clear that here. diff --git a/server/jetstream_cluster_4_test.go b/server/jetstream_cluster_4_test.go index c74ba05a3e6..6a4e24a02b0 100644 --- a/server/jetstream_cluster_4_test.go +++ b/server/jetstream_cluster_4_test.go @@ -23,6 +23,7 @@ import ( "fmt" "math/rand" "os" + "path" "path/filepath" "runtime" "slices" @@ -3814,3 +3815,79 @@ func TestJetStreamClusterDesyncAfterErrorDuringCatchup(t *testing.T) { }) } } + +func TestJetStreamClusterDesyncAfterRestartReplacesLeaderSnapshot(t *testing.T) { + c := createJetStreamClusterExplicit(t, "R3S", 3) + defer c.shutdown() + + nc, js := jsClientConnect(t, c.randomServer()) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo"}, + Replicas: 3, + }) + require_NoError(t, err) + + // Reconnect to the leader. + leader := c.streamLeader(globalAccountName, "TEST") + nc.Close() + + nc, js = jsClientConnect(t, leader) + defer nc.Close() + + lookupStream := func(s *Server) *stream { + t.Helper() + acc, err := s.lookupAccount(globalAccountName) + require_NoError(t, err) + mset, err := acc.lookupStream("TEST") + require_NoError(t, err) + return mset + } + + // Stop one follower so it lags behind. + rs := c.randomNonStreamLeader(globalAccountName, "TEST") + mset := lookupStream(rs) + n := mset.node.(*raft) + followerSnapshots := path.Join(n.sd, snapshotsDir) + rs.Shutdown() + rs.WaitForShutdown() + + // Move the stream forward so the follower requires a snapshot. + err = js.PurgeStream("TEST", &nats.StreamPurgeRequest{Sequence: 10}) + require_NoError(t, err) + _, err = js.Publish("foo", nil) + require_NoError(t, err) + + // Install a snapshot on the leader, ensuring RAFT entries are compacted and a snapshot remains. + mset = lookupStream(leader) + n = mset.node.(*raft) + err = n.InstallSnapshot(mset.stateSnapshot()) + require_NoError(t, err) + + c.stopAll() + + // Replace follower snapshot with the leader's. + // This simulates the follower coming online, getting a snapshot from the leader after which it goes offline. + leaderSnapshots := path.Join(n.sd, snapshotsDir) + err = os.RemoveAll(followerSnapshots) + require_NoError(t, err) + err = copyDir(t, followerSnapshots, leaderSnapshots) + require_NoError(t, err) + + // Start the follower, it will load the snapshot from the leader. + rs = c.restartServer(rs) + + // Shutting down must check that the leader's snapshot is not overwritten. + rs.Shutdown() + rs.WaitForShutdown() + + // Now start all servers back up. + c.restartAll() + c.waitOnAllCurrent() + + checkFor(t, 10*time.Second, 500*time.Millisecond, func() error { + return checkState(t, c, globalAccountName, "TEST") + }) +} diff --git a/server/jetstream_helpers_test.go b/server/jetstream_helpers_test.go index da2814e1452..eb21057a04c 100644 --- a/server/jetstream_helpers_test.go +++ b/server/jetstream_helpers_test.go @@ -19,11 +19,15 @@ package server import ( "context" "encoding/json" + "errors" "fmt" + "io" + "io/fs" "math/rand" "net" "net/url" "os" + "path" "strings" "sync" "testing" @@ -1867,3 +1871,109 @@ func (b *bitset) String() string { sb.WriteString("\n") return sb.String() } + +func copyDir(t *testing.T, dst, src string) error { + t.Helper() + srcFS := os.DirFS(src) + return fs.WalkDir(srcFS, ".", func(p string, d os.DirEntry, err error) error { + if err != nil { + return err + } + newPath := path.Join(dst, p) + if d.IsDir() { + return os.MkdirAll(newPath, defaultDirPerms) + } + r, err := srcFS.Open(p) + if err != nil { + return err + } + defer r.Close() + + w, err := os.OpenFile(newPath, os.O_CREATE|os.O_WRONLY, defaultFilePerms) + if err != nil { + return err + } + defer w.Close() + _, err = io.Copy(w, r) + return err + }) +} + +func getStreamDetails(t *testing.T, c *cluster, accountName, streamName string) *StreamDetail { + t.Helper() + srv := c.streamLeader(accountName, streamName) + if srv == nil { + return nil + } + jsz, err := srv.Jsz(&JSzOptions{Accounts: true, Streams: true, Consumer: true}) + require_NoError(t, err) + for _, acc := range jsz.AccountDetails { + if acc.Name == accountName { + for _, stream := range acc.Streams { + if stream.Name == streamName { + return &stream + } + } + } + } + t.Error("Could not find account details") + return nil +} + +func checkState(t *testing.T, c *cluster, accountName, streamName string) error { + t.Helper() + + leaderSrv := c.streamLeader(accountName, streamName) + if leaderSrv == nil { + return fmt.Errorf("no leader server found for stream %q", streamName) + } + streamLeader := getStreamDetails(t, c, accountName, streamName) + if streamLeader == nil { + return fmt.Errorf("no leader found for stream %q", streamName) + } + var errs []error + for _, srv := range c.servers { + if srv == leaderSrv { + // Skip self + continue + } + acc, err := srv.LookupAccount(accountName) + require_NoError(t, err) + stream, err := acc.lookupStream(streamName) + require_NoError(t, err) + state := stream.state() + + if state.Msgs != streamLeader.State.Msgs { + err := fmt.Errorf("[%s] Leader %v has %d messages, Follower %v has %d messages", + streamName, leaderSrv, streamLeader.State.Msgs, + srv, state.Msgs, + ) + errs = append(errs, err) + } + if state.FirstSeq != streamLeader.State.FirstSeq { + err := fmt.Errorf("[%s] Leader %v FirstSeq is %d, Follower %v is at %d", + streamName, leaderSrv, streamLeader.State.FirstSeq, + srv, state.FirstSeq, + ) + errs = append(errs, err) + } + if state.LastSeq != streamLeader.State.LastSeq { + err := fmt.Errorf("[%s] Leader %v LastSeq is %d, Follower %v is at %d", + streamName, leaderSrv, streamLeader.State.LastSeq, + srv, state.LastSeq, + ) + errs = append(errs, err) + } + if state.NumDeleted != streamLeader.State.NumDeleted { + err := fmt.Errorf("[%s] Leader %v NumDeleted is %d, Follower %v is at %d\nSTATE_A: %+v\nSTATE_B: %+v\n", + streamName, leaderSrv, streamLeader.State.NumDeleted, + srv, state.NumDeleted, streamLeader.State, state, + ) + errs = append(errs, err) + } + } + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} diff --git a/server/leafnode_test.go b/server/leafnode_test.go index 07577487857..3ca25b484d0 100644 --- a/server/leafnode_test.go +++ b/server/leafnode_test.go @@ -4117,6 +4117,122 @@ func TestLeafNodeQueueGroupDistribution(t *testing.T) { sendAndCheck(2) } +func TestLeafNodeQueueGroupDistributionVariant(t *testing.T) { + hc := createClusterWithName(t, "HUB", 3) + defer hc.shutdown() + + // Now have a cluster of leafnodes with LEAF1 and LEAF2 connecting to HUB1. + c1 := ` + server_name: LEAF1 + listen: 127.0.0.1:-1 + cluster { name: ln22, listen: 127.0.0.1:-1 } + leafnodes { remotes = [{ url: nats-leaf://127.0.0.1:%d }] } + ` + lconf1 := createConfFile(t, []byte(fmt.Sprintf(c1, hc.opts[0].LeafNode.Port))) + ln1, lopts1 := RunServerWithConfig(lconf1) + defer ln1.Shutdown() + + c2 := ` + server_name: LEAF2 + listen: 127.0.0.1:-1 + cluster { name: ln22, listen: 127.0.0.1:-1, routes = [ nats-route://127.0.0.1:%d] } + leafnodes { remotes = [{ url: nats-leaf://127.0.0.1:%d }] } + ` + lconf2 := createConfFile(t, []byte(fmt.Sprintf(c2, lopts1.Cluster.Port, hc.opts[0].LeafNode.Port))) + ln2, _ := RunServerWithConfig(lconf2) + defer ln2.Shutdown() + + // And LEAF3 to HUB3 + c3 := ` + server_name: LEAF3 + listen: 127.0.0.1:-1 + cluster { name: ln22, listen: 127.0.0.1:-1, routes = [ nats-route://127.0.0.1:%d] } + leafnodes { remotes = [{ url: nats-leaf://127.0.0.1:%d }] } + ` + lconf3 := createConfFile(t, []byte(fmt.Sprintf(c3, lopts1.Cluster.Port, hc.opts[2].LeafNode.Port))) + ln3, _ := RunServerWithConfig(lconf3) + defer ln3.Shutdown() + + // Check leaf cluster is formed and all connected to the HUB. + lnServers := []*Server{ln1, ln2, ln3} + checkClusterFormed(t, lnServers...) + for _, s := range lnServers { + checkLeafNodeConnected(t, s) + } + // Check that HUB1 has 2 leaf connections, HUB2 has 0 and HUB3 has 1. + checkLeafNodeConnectedCount(t, hc.servers[0], 2) + checkLeafNodeConnectedCount(t, hc.servers[1], 0) + checkLeafNodeConnectedCount(t, hc.servers[2], 1) + + // Create a client and qsub on LEAF1 and LEAF2. + nc1 := natsConnect(t, ln1.ClientURL()) + defer nc1.Close() + var qsub1Count atomic.Int32 + natsQueueSub(t, nc1, "foo", "queue1", func(_ *nats.Msg) { + qsub1Count.Add(1) + }) + natsFlush(t, nc1) + + nc2 := natsConnect(t, ln2.ClientURL()) + defer nc2.Close() + var qsub2Count atomic.Int32 + natsQueueSub(t, nc2, "foo", "queue1", func(_ *nats.Msg) { + qsub2Count.Add(1) + }) + natsFlush(t, nc2) + + // Make sure that the propagation interest is done before sending. + for i, s := range hc.servers { + gacc := s.GlobalAccount() + var ei int + switch i { + case 0: + ei = 2 + default: + ei = 1 + } + checkFor(t, time.Second, 15*time.Millisecond, func() error { + if n := gacc.Interest("foo"); n != ei { + return fmt.Errorf("Expected interest for %q to be %d, got %v", "foo", ei, n) + } + return nil + }) + } + + sendAndCheck := func(idx int) { + t.Helper() + nchub := natsConnect(t, hc.servers[idx].ClientURL()) + defer nchub.Close() + total := 1000 + for i := 0; i < total; i++ { + natsPub(t, nchub, "foo", []byte("from hub")) + } + checkFor(t, time.Second, 15*time.Millisecond, func() error { + if trecv := int(qsub1Count.Load() + qsub2Count.Load()); trecv != total { + return fmt.Errorf("Expected %v messages, got %v", total, trecv) + } + return nil + }) + // Now that we have made sure that all messages were received, + // check that qsub1 and qsub2 are getting at least some. + if n := int(qsub1Count.Load()); n <= total/10 { + t.Fatalf("Expected qsub1 to get some messages, but got %v (qsub2=%v)", n, qsub2Count.Load()) + } + if n := int(qsub2Count.Load()); n <= total/10 { + t.Fatalf("Expected qsub2 to get some messages, but got %v (qsub1=%v)", n, qsub1Count.Load()) + } + // Reset the counters. + qsub1Count.Store(0) + qsub2Count.Store(0) + } + // Send from HUB1 + sendAndCheck(0) + // Send from HUB2 + sendAndCheck(1) + // Send from HUB3 + sendAndCheck(2) +} + func TestLeafNodeQueueGroupWithLateLNJoin(t *testing.T) { /* diff --git a/server/opts.go b/server/opts.go index 0b4ed483dce..c73127e5309 100644 --- a/server/opts.go +++ b/server/opts.go @@ -1,4 +1,4 @@ -// Copyright 2012-2023 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -657,26 +657,28 @@ type authorization struct { // TLSConfigOpts holds the parsed tls config information, // used with flag parsing type TLSConfigOpts struct { - CertFile string - KeyFile string - CaFile string - Verify bool - Insecure bool - Map bool - TLSCheckKnownURLs bool - HandshakeFirst bool // Indicate that the TLS handshake should occur first, before sending the INFO protocol. - FallbackDelay time.Duration // Where supported, indicates how long to wait for the handshake before falling back to sending the INFO protocol first. - Timeout float64 - RateLimit int64 - Ciphers []uint16 - CurvePreferences []tls.CurveID - PinnedCerts PinnedCertSet - CertStore certstore.StoreType - CertMatchBy certstore.MatchByType - CertMatch string - OCSPPeerConfig *certidp.OCSPPeerConfig - Certificates []*TLSCertPairOpt - MinVersion uint16 + CertFile string + KeyFile string + CaFile string + Verify bool + Insecure bool + Map bool + TLSCheckKnownURLs bool + HandshakeFirst bool // Indicate that the TLS handshake should occur first, before sending the INFO protocol. + FallbackDelay time.Duration // Where supported, indicates how long to wait for the handshake before falling back to sending the INFO protocol first. + Timeout float64 + RateLimit int64 + Ciphers []uint16 + CurvePreferences []tls.CurveID + PinnedCerts PinnedCertSet + CertStore certstore.StoreType + CertMatchBy certstore.MatchByType + CertMatch string + CertMatchSkipInvalid bool + CaCertsMatch []string + OCSPPeerConfig *certidp.OCSPPeerConfig + Certificates []*TLSCertPairOpt + MinVersion uint16 } // TLSCertPairOpt are the paths to a certificate and private key. @@ -4419,6 +4421,28 @@ func parseTLS(v any, isClientCtx bool) (t *TLSConfigOpts, retErr error) { return nil, &configErr{tk, certstore.ErrBadCertMatchField.Error()} } tc.CertMatch = certMatch + case "ca_certs_match": + rv := []string{} + switch mv := mv.(type) { + case string: + rv = append(rv, mv) + case []string: + rv = append(rv, mv...) + case []interface{}: + for _, t := range mv { + if token, ok := t.(token); ok { + if ts, ok := token.Value().(string); ok { + rv = append(rv, ts) + continue + } else { + return nil, &configErr{tk, fmt.Sprintf("error parsing ca_cert_match: unsupported type %T where string is expected", token)} + } + } else { + return nil, &configErr{tk, fmt.Sprintf("error parsing ca_cert_match: unsupported type %T", t)} + } + } + } + tc.CaCertsMatch = rv case "handshake_first", "first", "immediate": switch mv := mv.(type) { case bool: @@ -4444,6 +4468,12 @@ func parseTLS(v any, isClientCtx bool) (t *TLSConfigOpts, retErr error) { default: return nil, &configErr{tk, fmt.Sprintf("field %q should be a boolean or a string, got %T", mk, mv)} } + case "cert_match_skip_invalid": + certMatchSkipInvalid, ok := mv.(bool) + if !ok { + return nil, &configErr{tk, certstore.ErrBadCertMatchSkipInvalidField.Error()} + } + tc.CertMatchSkipInvalid = certMatchSkipInvalid case "ocsp_peer": switch vv := mv.(type) { case bool: @@ -4819,7 +4849,7 @@ func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) { } config.Certificates = []tls.Certificate{cert} case tc.CertStore != certstore.STOREEMPTY: - err := certstore.TLSConfig(tc.CertStore, tc.CertMatchBy, tc.CertMatch, &config) + err := certstore.TLSConfig(tc.CertStore, tc.CertMatchBy, tc.CertMatch, tc.CaCertsMatch, tc.CertMatchSkipInvalid, &config) if err != nil { return nil, err } diff --git a/server/raft.go b/server/raft.go index cd8d2d11589..5397296d2f1 100644 --- a/server/raft.go +++ b/server/raft.go @@ -1026,7 +1026,7 @@ func (n *raft) InstallSnapshot(data []byte) error { // Check that a catchup isn't already taking place. If it is then we won't // allow installing snapshots until it is done. - if len(n.progress) > 0 { + if len(n.progress) > 0 || n.paused { return errCatchupsRunning } diff --git a/server/stream.go b/server/stream.go index bfc75b3c1c0..7aaf4e6a6fc 100644 --- a/server/stream.go +++ b/server/stream.go @@ -4819,6 +4819,9 @@ func newJSPubMsg(dsubj, subj, reply string, hdr, msg []byte, o *consumer, seq ui if pm != nil { m = pm.(*jsPubMsg) buf = m.buf[:0] + if hdr != nil { + hdr = append(m.hdr[:0], hdr...) + } } else { m = new(jsPubMsg) } @@ -4847,6 +4850,9 @@ func (pm *jsPubMsg) returnToPool() { if len(pm.buf) > 0 { pm.buf = pm.buf[:0] } + if len(pm.hdr) > 0 { + pm.hdr = pm.hdr[:0] + } jsPubMsgPool.Put(pm) } diff --git a/server/stree/stree.go b/server/stree/stree.go index a289a629742..828631888f9 100644 --- a/server/stree/stree.go +++ b/server/stree/stree.go @@ -283,7 +283,7 @@ func (t *SubjectTree[T]) delete(np *node, subject []byte, si int) (*T, bool) { func (t *SubjectTree[T]) match(n node, parts [][]byte, pre []byte, cb func(subject []byte, val *T)) { // Capture if we are sitting on a terminal fwc. var hasFWC bool - if lp := len(parts); lp > 0 && parts[lp-1][0] == fwc { + if lp := len(parts); lp > 0 && len(parts[lp-1]) > 0 && parts[lp-1][0] == fwc { hasFWC = true } diff --git a/server/stree/stree_test.go b/server/stree/stree_test.go index 2a84d3c7598..cf6d08512b9 100644 --- a/server/stree/stree_test.go +++ b/server/stree/stree_test.go @@ -842,3 +842,15 @@ func TestSubjectTreeInsertWithNoPivot(t *testing.T) { require_False(t, updated) require_Equal(t, st.Size(), 0) } + +// Make sure we don't panic when checking for fwc. +func TestSubjectTreeMatchHasFWCNoPanic(t *testing.T) { + defer func() { + p := recover() + require_True(t, p == nil) + }() + st := NewSubjectTree[int]() + subj := []byte("foo") + st.Insert(subj, 1) + st.Match([]byte("."), func(subject []byte, val *int) {}) +} diff --git a/test/configs/certs/tlsauth/certstore/ca.p12 b/test/configs/certs/tlsauth/certstore/ca.p12 new file mode 100644 index 00000000000..d9b67effe0f Binary files /dev/null and b/test/configs/certs/tlsauth/certstore/ca.p12 differ diff --git a/test/configs/certs/tlsauth/certstore/client.p12 b/test/configs/certs/tlsauth/certstore/client.p12 index 18ee5c32f01..3c2e93a93ac 100644 Binary files a/test/configs/certs/tlsauth/certstore/client.p12 and b/test/configs/certs/tlsauth/certstore/client.p12 differ diff --git a/test/configs/certs/tlsauth/certstore/delete-cert-from-store.ps1 b/test/configs/certs/tlsauth/certstore/delete-cert-from-store.ps1 index e3a05ae58fd..3f2eccd7327 100644 --- a/test/configs/certs/tlsauth/certstore/delete-cert-from-store.ps1 +++ b/test/configs/certs/tlsauth/certstore/delete-cert-from-store.ps1 @@ -1,2 +1,5 @@ -$issuer="Synadia Communications Inc." -Get-ChildItem Cert:\CurrentUser\My | Where-Object {$_.Issuer -match $issuer} | Remove-Item \ No newline at end of file +$issuer="NATS CA" +Get-ChildItem Cert:\CurrentUser\My | Where-Object {$_.Issuer -match $issuer} | Remove-Item +Get-ChildItem Cert:\CurrentUser\CA| Where-Object {$_.Issuer -match $issuer} | Remove-Item +Get-ChildItem Cert:\CurrentUser\AuthRoot | Where-Object {$_.Issuer -match $issuer} | Remove-Item +Get-ChildItem Cert:\CurrentUser\Root | Where-Object {$_.Issuer -match $issuer} | Remove-Item diff --git a/test/configs/certs/tlsauth/certstore/expired.p12 b/test/configs/certs/tlsauth/certstore/expired.p12 new file mode 100644 index 00000000000..5870bbc3d00 Binary files /dev/null and b/test/configs/certs/tlsauth/certstore/expired.p12 differ diff --git a/test/configs/certs/tlsauth/certstore/import-p12-ca.ps1 b/test/configs/certs/tlsauth/certstore/import-p12-ca.ps1 new file mode 100644 index 00000000000..888816f8a70 --- /dev/null +++ b/test/configs/certs/tlsauth/certstore/import-p12-ca.ps1 @@ -0,0 +1,7 @@ +$fileLocale = $PSScriptRoot + "\ca.p12" +$Pass = ConvertTo-SecureString -String 's3cr3t' -Force -AsPlainText +$User = "whatever" +$Cred = New-Object -TypeName "System.Management.Automation.PSCredential" -ArgumentList $User, $Pass +Import-PfxCertificate -FilePath $filelocale -CertStoreLocation Cert:\CurrentUser\My -Password $Cred.Password +#Import-PfxCertificate -FilePath $filelocale -CertStoreLocation Cert:\LocalMachine\Root -Password $Cred.Password +# TODO? Move to trusted enterprise? Requires some fingerprint parsing. \ No newline at end of file diff --git a/test/configs/certs/tlsauth/certstore/import-p12-server.ps1 b/test/configs/certs/tlsauth/certstore/import-p12-server.ps1 index 14a006d8249..006a1e4a513 100644 --- a/test/configs/certs/tlsauth/certstore/import-p12-server.ps1 +++ b/test/configs/certs/tlsauth/certstore/import-p12-server.ps1 @@ -1,4 +1,7 @@ -$fileLocale = $PSScriptRoot + "\server.p12" +$file=$args[0] +if (!$file) { $file="server.p12 "} +$fileLocale = $PSScriptRoot + "\" + $file +echo "Installing certificate $fileLocale" $Pass = ConvertTo-SecureString -String 's3cr3t' -Force -AsPlainText $User = "whatever" $Cred = New-Object -TypeName "System.Management.Automation.PSCredential" -ArgumentList $User, $Pass diff --git a/test/configs/certs/tlsauth/certstore/not-expired.p12 b/test/configs/certs/tlsauth/certstore/not-expired.p12 new file mode 100644 index 00000000000..d99cb55a946 Binary files /dev/null and b/test/configs/certs/tlsauth/certstore/not-expired.p12 differ diff --git a/test/configs/certs/tlsauth/certstore/pkcs12.md b/test/configs/certs/tlsauth/certstore/pkcs12.md index 569d88084af..b516ab113fe 100644 --- a/test/configs/certs/tlsauth/certstore/pkcs12.md +++ b/test/configs/certs/tlsauth/certstore/pkcs12.md @@ -6,6 +6,10 @@ Refresh PKCS12 files when test certificates and keys (PEM files) are refreshed ( `openssl pkcs12 -export -inkey ./client-key.pem -in ./client.pem -out client.p12` +To add the CA, use the following: + +`openssl pkcs12 -export -nokeys -in ..\ca.pem -out ca.p12` + > Note: set the PKCS12 bundle password to `s3cr3t` as required by provisioning scripts ## Cert Store Provisioning Scripts diff --git a/test/configs/certs/tlsauth/certstore/server.p12 b/test/configs/certs/tlsauth/certstore/server.p12 index 9325afbc15c..f7a1c4d0135 100644 Binary files a/test/configs/certs/tlsauth/certstore/server.p12 and b/test/configs/certs/tlsauth/certstore/server.p12 differ