diff --git a/lib/backend/backend.go b/lib/backend/backend.go index 1b243ffc48758..286550189584d 100644 --- a/lib/backend/backend.go +++ b/lib/backend/backend.go @@ -167,18 +167,21 @@ func IterateRange(ctx context.Context, bk Backend, startKey, endKey Key, limit i // 2. allow individual backends to expose custom streaming methods s.t. the most performant // impl for a given backend may be used. func StreamRange(ctx context.Context, bk Backend, startKey, endKey Key, pageSize int) stream.Stream[Item] { - return stream.PageFunc[Item](func() ([]Item, error) { - if startKey == nil { + var done bool + return stream.PageFunc(func() ([]Item, error) { + if done { return nil, io.EOF } - rslt, err := bk.GetRange(ctx, startKey, endKey, pageSize) + rslt, err := bk.GetRange(ctx, startKey, endKey, pageSize+1) if err != nil { return nil, trace.Wrap(err) } - if len(rslt.Items) < pageSize { - startKey = nil + if len(rslt.Items) > pageSize { + startKey = rslt.Items[pageSize].Key + clear(rslt.Items[pageSize:]) + rslt.Items = rslt.Items[:pageSize] } else { - startKey = nextKey(rslt.Items[pageSize-1].Key) + done = true } return rslt.Items, nil }) @@ -298,10 +301,10 @@ func (p Params) GetString(key string) string { // NoLimit specifies no limits const NoLimit = 0 -// nextKey returns the next possible key. -// If used with a key prefix, this will return -// the end of the range for that key prefix. -func nextKey(key Key) Key { +const noEnd = "\x00" + +// RangeEnd returns end of the range for given key. +func RangeEnd(key Key) Key { end := make([]byte, len(key)) copy(end, key) for i := len(end) - 1; i >= 0; i-- { @@ -312,14 +315,7 @@ func nextKey(key Key) Key { } } // next key does not exist (e.g., 0xffff); - return noEnd -} - -var noEnd = Key{0} - -// RangeEnd returns end of the range for given key. -func RangeEnd(key Key) Key { - return nextKey(key) + return Key(noEnd) } // HostID is a derivation of a KeyedItem that allows the host id @@ -339,7 +335,7 @@ type KeyedItem interface { // have the HostID part. func NextPaginationKey(ki KeyedItem) string { key := GetPaginationKey(ki) - return string(nextKey(Key(key))) + return string(RangeEnd(Key(key))) } // GetPaginationKey returns the pagination key given item. diff --git a/lib/backend/memory/memory_test.go b/lib/backend/memory/memory_test.go index 625baf9aadf2c..8103653b25b6f 100644 --- a/lib/backend/memory/memory_test.go +++ b/lib/backend/memory/memory_test.go @@ -22,6 +22,7 @@ import ( "context" "fmt" "os" + "slices" "strings" "testing" @@ -127,3 +128,31 @@ func TestIterateRange(t *testing.T) { require.NoError(t, err) require.Equal(t, 20, scount) } + +func TestStreamRange(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m, err := New(Config{}) + require.NoError(t, err) + defer m.Close() + + const N = 10 + for i := range 10 * N { + _, err := m.Put(ctx, backend.Item{ + Key: backend.NewKey("foo", strings.Repeat("a", i+1)), + Value: []byte("\x00"), + }) + require.NoError(t, err) + } + + var items []string + st := backend.StreamRange(ctx, m, backend.ExactKey("foo"), backend.RangeEnd(backend.ExactKey("foo")), N) + for st.Next() { + items = append(items, st.Item().Key.String()) + } + require.NoError(t, st.Done()) + + require.Len(t, items, 10*N) + require.True(t, slices.IsSorted(items)) +}