diff --git a/server/filestore.go b/server/filestore.go index 0a04a7c4913..3d14b9180ba 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -3332,11 +3332,11 @@ func (fs *fileStore) NumPendingMulti(sseq uint64, sl *Sublist, lastPerSubject bo var t uint64 var havePartial bool - mb.fss.Iter(func(bsubj []byte, ss *SimpleState) bool { + IntersectStree[SimpleState](mb.fss, sl, func(bsubj []byte, ss *SimpleState) { subj := bytesToString(bsubj) - if havePartial || !sl.HasInterest(subj) { + if havePartial { // If we already found a partial then don't do anything else. - return !havePartial + return } if ss.firstNeedsUpdate { mb.recalculateFirstForSubj(subj, ss.First, ss) @@ -3347,7 +3347,6 @@ func (fs *fileStore) NumPendingMulti(sseq uint64, sl *Sublist, lastPerSubject bo // We matched but its a partial. havePartial = true } - return !havePartial }) // See if we need to scan msgs here. @@ -3432,11 +3431,8 @@ func (fs *fileStore) NumPendingMulti(sseq uint64, sl *Sublist, lastPerSubject bo } // Mark fss activity. mb.lsts = time.Now().UnixNano() - mb.fss.Iter(func(bsubj []byte, ss *SimpleState) bool { - if sl.HasInterest(bytesToString(bsubj)) { - adjust += ss.Msgs - } - return true + IntersectStree(mb.fss, sl, func(bsubj []byte, ss *SimpleState) { + adjust += ss.Msgs }) } } else { diff --git a/server/memstore.go b/server/memstore.go index 55e98882195..ee953e1fa99 100644 --- a/server/memstore.go +++ b/server/memstore.go @@ -719,10 +719,7 @@ func (ms *memStore) NumPendingMulti(sseq uint64, sl *Sublist, lastPerSubject boo var havePartial bool var totalSkipped uint64 // We will track start and end sequences as we go. - ms.fss.Iter(func(subj []byte, fss *SimpleState) bool { - if !sl.HasInterest(bytesToString(subj)) { - return true - } + IntersectStree[SimpleState](ms.fss, sl, func(subj []byte, fss *SimpleState) { if fss.firstNeedsUpdate { ms.recalculateFirstForSubj(bytesToString(subj), fss.First, fss) } @@ -736,7 +733,6 @@ func (ms *memStore) NumPendingMulti(sseq uint64, sl *Sublist, lastPerSubject boo } else { totalSkipped += fss.Msgs } - return true }) // If we did not encounter any partials we can return here. diff --git a/server/sublist.go b/server/sublist.go index 657666f1afb..d34bcde1ad1 100644 --- a/server/sublist.go +++ b/server/sublist.go @@ -20,6 +20,8 @@ import ( "sync" "sync/atomic" "unicode/utf8" + + "github.com/nats-io/nats-server/v2/server/stree" ) // Sublist is a routing mechanism to handle subject distribution and @@ -1735,3 +1737,44 @@ func getAllNodes(l *level, results *SublistResult) { getAllNodes(n.next, results) } } + +// IntersectStree will match all items in the given subject tree that +// have interest expressed in the given sublist. The callback will only be called +// once for each subject, regardless of overlapping subscriptions in the sublist. +func IntersectStree[T any](st *stree.SubjectTree[T], sl *Sublist, cb func(subj []byte, entry *T)) { + var _subj [255]byte + intersectStree(st, sl.root, _subj[:0], cb) +} + +func intersectStree[T any](st *stree.SubjectTree[T], r *level, subj []byte, cb func(subj []byte, entry *T)) { + if r.numNodes() == 0 { + st.Match(subj, cb) + return + } + nsubj := subj + if len(nsubj) > 0 { + nsubj = append(subj, '.') + } + switch { + case r.fwc != nil: + // We've reached a full wildcard, do a FWC match on the stree at this point + // and don't keep iterating downward. + nsubj := append(nsubj, '>') + st.Match(nsubj, cb) + case r.pwc != nil: + // We've found a partial wildcard. We'll keep iterating downwards, but first + // check whether there's interest at this level (without triggering dupes) and + // match if so. + nsubj := append(nsubj, '*') + if len(r.pwc.psubs)+len(r.pwc.qsubs) > 0 && r.pwc.next != nil && r.pwc.next.numNodes() > 0 { + st.Match(nsubj, cb) + } + intersectStree(st, r.pwc.next, nsubj, cb) + case r.numNodes() > 0: + // Normal node with subject literals, keep iterating. + for t, n := range r.nodes { + nsubj := append(nsubj, t...) + intersectStree(st, n.next, nsubj, cb) + } + } +} diff --git a/server/sublist_test.go b/server/sublist_test.go index 2bc3a9958b8..5e8145c5ec7 100644 --- a/server/sublist_test.go +++ b/server/sublist_test.go @@ -26,6 +26,7 @@ import ( "testing" "time" + "github.com/nats-io/nats-server/v2/server/stree" "github.com/nats-io/nuid" ) @@ -1982,6 +1983,128 @@ func TestSublistNumInterest(t *testing.T) { sl.Remove(qsub) } +func TestSublistInterestBasedIntersection(t *testing.T) { + st := stree.NewSubjectTree[struct{}]() + st.Insert([]byte("one.two.three.four"), struct{}{}) + st.Insert([]byte("one.two.three.five"), struct{}{}) + st.Insert([]byte("one.two.six"), struct{}{}) + st.Insert([]byte("one.two.seven"), struct{}{}) + st.Insert([]byte("eight.nine"), struct{}{}) + + require_NoDuplicates := func(t *testing.T, got map[string]int) { + for _, c := range got { + require_Equal(t, c, 1) + } + } + + t.Run("Literals", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub("one.two.six")) + sl.Insert(newSub("eight.nine")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 2) + require_NoDuplicates(t, got) + }) + + t.Run("PWC", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub("one.two.*.*")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 2) + require_NoDuplicates(t, got) + }) + + t.Run("PWCOverlapping", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub("one.two.*.four")) + sl.Insert(newSub("one.two.*.*")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 2) + require_NoDuplicates(t, got) + }) + + t.Run("PWCAll", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub("*.*")) + sl.Insert(newSub("*.*.*")) + sl.Insert(newSub("*.*.*.*")) + require_True(t, sl.HasInterest("foo.bar")) + require_True(t, sl.HasInterest("foo.bar.baz")) + require_True(t, sl.HasInterest("foo.bar.baz.qux")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 5) + require_NoDuplicates(t, got) + }) + + t.Run("FWC", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub("one.>")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 4) + require_NoDuplicates(t, got) + }) + + t.Run("FWCOverlapping", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub("one.two.three.four")) + sl.Insert(newSub("one.>")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 4) + require_NoDuplicates(t, got) + }) + + t.Run("FWCAll", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub(">")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 5) + require_NoDuplicates(t, got) + }) + + t.Run("NoMatch", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub("one")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 0) + }) + + t.Run("NoMatches", func(t *testing.T) { + got := map[string]int{} + sl := NewSublistNoCache() + sl.Insert(newSub("one")) + sl.Insert(newSub("eight")) + sl.Insert(newSub("ten")) + IntersectStree(st, sl, func(subj []byte, entry *struct{}) { + got[string(subj)]++ + }) + require_Len(t, len(got), 0) + }) +} + func subsInit(pre string, toks []string) { var sub string for _, t := range toks {