diff --git a/server/v2/stf/branch/mergeiter.go b/server/v2/stf/branch/mergeiter.go index e71b88cffc42..108d19e7e041 100644 --- a/server/v2/stf/branch/mergeiter.go +++ b/server/v2/stf/branch/mergeiter.go @@ -7,229 +7,162 @@ import ( corestore "cosmossdk.io/core/store" ) +var ( + errInvalidIterator = errors.New("invalid iterator") +) + // mergedIterator merges a parent Iterator and a cache Iterator. -// The cache iterator may return nil keys to signal that an item -// had been deleted (but not deleted in the parent). -// If the cache iterator has the same key as the parent, the -// cache shadows (overrides) the parent. -type mergedIterator struct { - parent corestore.Iterator - cache corestore.Iterator - ascending bool - - valid bool +// The cache iterator may contain items that shadow or override items in the parent iterator. +// If the cache iterator has the same key as the parent, the cache's value takes precedence. +// Deleted items in the cache (indicated by nil values) are skipped. +type mergedIterator[Parent, Cache corestore.Iterator] struct { + parent Parent // Iterator for the parent store + cache Cache // Iterator for the cache store + ascending bool // Direction of iteration + valid bool // Indicates if the iterator is in a valid state + currKey []byte // Current key pointed by the iterator + currValue []byte // Current value corresponding to currKey + err error // Error encountered during iteration } -var _ corestore.Iterator = (*mergedIterator)(nil) +// Ensure mergedIterator implements the corestore.Iterator interface. +var _ corestore.Iterator = (*mergedIterator[corestore.Iterator, corestore.Iterator])(nil) -// mergeIterators merges two iterators. -func mergeIterators(parent, cache corestore.Iterator, ascending bool) corestore.Iterator { - iter := &mergedIterator{ +// mergeIterators creates a new merged iterator from parent and cache iterators. +// The 'ascending' parameter determines the direction of iteration. 
+func mergeIterators[Parent, Cache corestore.Iterator](parent Parent, cache Cache, ascending bool) *mergedIterator[Parent, Cache] { + iter := &mergedIterator[Parent, Cache]{ parent: parent, cache: cache, ascending: ascending, } - - iter.valid = iter.skipUntilExistsOrInvalid() + iter.advance() // Initialize the iterator by advancing to the first valid item return iter } -// Domain implements Iterator. -// Returns parent domain because cache and parent domains are the same. -func (iter *mergedIterator) Domain() (start, end []byte) { - return iter.parent.Domain() +// Domain returns the start and end range of the iterator. +// It delegates to the parent iterator as both iterators share the same domain. +func (i *mergedIterator[Parent, Cache]) Domain() (start, end []byte) { + return i.parent.Domain() } -// Valid implements Iterator. -func (iter *mergedIterator) Valid() bool { - return iter.valid +// Valid checks if the iterator is in a valid state. +// It returns true if the iterator has not reached the end. +func (i *mergedIterator[Parent, Cache]) Valid() bool { + return i.valid } -// Next implements Iterator -func (iter *mergedIterator) Next() { - iter.assertValid() - - switch { - case !iter.parent.Valid(): - // If parent is invalid, get the next cache item. - iter.cache.Next() - case !iter.cache.Valid(): - // If cache is invalid, get the next parent item. - iter.parent.Next() - default: - // Both are valid. Compare keys. - keyP, keyC := iter.parent.Key(), iter.cache.Key() - switch iter.compare(keyP, keyC) { - case -1: // parent < cache - iter.parent.Next() - case 0: // parent == cache - iter.parent.Next() - iter.cache.Next() - case 1: // parent > cache - iter.cache.Next() - } +// Next advances the iterator to the next valid item. +// It skips over deleted items (with nil values) and updates the current key and value. 
+func (i *mergedIterator[Parent, Cache]) Next() { + if !i.valid { + i.err = errInvalidIterator + return } - iter.valid = iter.skipUntilExistsOrInvalid() + i.advance() } -// Key implements Iterator -func (iter *mergedIterator) Key() []byte { - iter.assertValid() - - // If parent is invalid, get the cache key. - if !iter.parent.Valid() { - return iter.cache.Key() - } - - // If cache is invalid, get the parent key. - if !iter.cache.Valid() { - return iter.parent.Key() - } - - // Both are valid. Compare keys. - keyP, keyC := iter.parent.Key(), iter.cache.Key() - - cmp := iter.compare(keyP, keyC) - switch cmp { - case -1: // parent < cache - return keyP - case 0: // parent == cache - return keyP - case 1: // parent > cache - return keyC - default: - panic("invalid compare result") +// Key returns the current key pointed to by the iterator. +// It panics if the iterator is invalid. +func (i *mergedIterator[Parent, Cache]) Key() []byte { + if !i.valid { + panic("called key on invalid iterator") } + return i.currKey } -// Value implements Iterator -func (iter *mergedIterator) Value() []byte { - iter.assertValid() - - // If parent is invalid, get the cache value. - if !iter.parent.Valid() { - return iter.cache.Value() - } - - // If cache is invalid, get the parent value. - if !iter.cache.Valid() { - return iter.parent.Value() - } - - // Both are valid. Compare keys. - keyP, keyC := iter.parent.Key(), iter.cache.Key() - - cmp := iter.compare(keyP, keyC) - switch cmp { - case -1: // parent < cache - return iter.parent.Value() - case 0: // parent == cache - return iter.cache.Value() - case 1: // parent > cache - return iter.cache.Value() - default: - panic("invalid comparison result") +// Value returns the current value corresponding to the current key. +// It panics if the iterator is invalid. 
+func (i *mergedIterator[Parent, Cache]) Value() []byte { + if !i.valid { + panic("called value on invalid iterator") } + return i.currValue } -// Close implements Iterator -func (iter *mergedIterator) Close() error { - err1 := iter.cache.Close() - if err := iter.parent.Close(); err != nil { - return err - } - - return err1 +// Close closes both the parent and cache iterators. +// It returns any error encountered during the closing of the iterators. +func (i *mergedIterator[Parent, Cache]) Close() (err error) { + err = errors.Join(err, i.parent.Close()) + err = errors.Join(err, i.cache.Close()) + i.valid = false + return err } -var errInvalidIterator = errors.New("invalid merged iterator") - -// Error returns an error if the mergedIterator is invalid defined by the -// Valid method. -func (iter *mergedIterator) Error() error { - if !iter.Valid() { - return errInvalidIterator - } - - return nil +// Error returns any error that occurred during iteration. +// If the iterator is valid, it returns nil. +func (i *mergedIterator[Parent, Cache]) Error() error { + return i.err } -// If not valid, panics. -// NOTE: May have side-effect of iterating over cache. -func (iter *mergedIterator) assertValid() { - if err := iter.Error(); err != nil { - panic(err) - } -} - -// Like bytes.Compare but opposite if not ascending. -func (iter *mergedIterator) compare(a, b []byte) int { - if iter.ascending { - return bytes.Compare(a, b) - } - - return bytes.Compare(a, b) * -1 -} - -// Skip all delete-items from the cache w/ `key < until`. After this function, -// current cache item is a non-delete-item, or `until <= key`. -// If the current cache item is not a delete item, does nothing. -// If `until` is nil, there is no limit, and cache may end up invalid. -// CONTRACT: cache is valid. 
-func (iter *mergedIterator) skipCacheDeletes(until []byte) { - for iter.cache.Valid() && - iter.cache.Value() == nil && - (until == nil || iter.compare(iter.cache.Key(), until) < 0) { - iter.cache.Next() - } -} - -// Fast forwards cache (or parent+cache in case of deleted items) until current -// item exists, or until iterator becomes invalid. -// Returns whether the iterator is valid. -func (iter *mergedIterator) skipUntilExistsOrInvalid() bool { +// advance moves the iterator to the next valid (non-deleted) item. +// It handles merging logic between the parent and cache iterators. +func (i *mergedIterator[Parent, Cache]) advance() { for { - // If parent is invalid, fast-forward cache. - if !iter.parent.Valid() { - iter.skipCacheDeletes(nil) - return iter.cache.Valid() + // Check if both iterators have reached the end + if !i.parent.Valid() && !i.cache.Valid() { + i.valid = false + return } - // Parent is valid. - if !iter.cache.Valid() { - return true + var key, value []byte + + // If parent iterator is exhausted, use the cache iterator + if !i.parent.Valid() { + key = i.cache.Key() + value = i.cache.Value() + i.cache.Next() + } else if !i.cache.Valid() { + // If cache iterator is exhausted, use the parent iterator + key = i.parent.Key() + value = i.parent.Value() + i.parent.Next() + } else { + // Both iterators are valid; compare keys + keyP, keyC := i.parent.Key(), i.cache.Key() + switch cmp := i.compare(keyP, keyC); { + case cmp < 0: + // Parent key is less than cache key + key = keyP + value = i.parent.Value() + i.parent.Next() + case cmp == 0: + // Keys are equal; cache overrides parent + key = keyC + value = i.cache.Value() + i.parent.Next() + i.cache.Next() + case cmp > 0: + // Cache key is less than parent key + key = keyC + value = i.cache.Value() + i.cache.Next() + } } - // Parent is valid, cache is valid. - - // Compare parent and cache. 
- keyP := iter.parent.Key() - keyC := iter.cache.Key() - switch iter.compare(keyP, keyC) { - case -1: // parent < cache. - return true + // Skip deleted items (value is nil) + if value == nil { + continue + } - case 0: // parent == cache. - // Skip over if cache item is a delete. - valueC := iter.cache.Value() - if valueC == nil { - iter.parent.Next() - iter.cache.Next() + // Update the current key and value, and mark iterator as valid + i.currKey = key + i.currValue = value + i.valid = true + return + } +} - continue - } - // Cache is not a delete. - - return true // cache exists. - case 1: // cache < parent - // Skip over if cache item is a delete. - valueC := iter.cache.Value() - if valueC == nil { - iter.skipCacheDeletes(keyP) - continue - } - // Cache is not a delete. - return true // cache exists. - } +// compare compares two byte slices a and b. +// It returns an integer comparing a and b: +// - Negative if a < b +// - Zero if a == b +// - Positive if a > b +// +// The comparison respects the iterator's direction (ascending or descending). 
+func (i *mergedIterator[Parent, Cache]) compare(a, b []byte) int { + if i.ascending { + return bytes.Compare(a, b) } + return bytes.Compare(b, a) } diff --git a/server/v2/stf/branch/mergeiter_test.go b/server/v2/stf/branch/mergeiter_test.go index 1a9e13bbe65c..45f10d4c1b36 100644 --- a/server/v2/stf/branch/mergeiter_test.go +++ b/server/v2/stf/branch/mergeiter_test.go @@ -7,6 +7,52 @@ import ( corestore "cosmossdk.io/core/store" ) +func TestMergedIterator_Validity(t *testing.T) { + panics := func(f func()) { + defer func() { + r := recover() + if r == nil { + t.Error("panic expected") + } + }() + + f() + } + + t.Run("panics when calling key on invalid iter", func(t *testing.T) { + parent, err := newMemState().Iterator(nil, nil) + if err != nil { + t.Fatal(err) + } + cache, err := newMemState().Iterator(nil, nil) + if err != nil { + t.Fatal(err) + } + + it := mergeIterators(parent, cache, true) + panics(func() { + it.Key() + }) + }) + + t.Run("panics when calling value on invalid iter", func(t *testing.T) { + parent, err := newMemState().Iterator(nil, nil) + if err != nil { + t.Fatal(err) + } + cache, err := newMemState().Iterator(nil, nil) + if err != nil { + t.Fatal(err) + } + + it := mergeIterators(parent, cache, true) + + panics(func() { + it.Value() + }) + }) +} + func TestMergedIterator_Next(t *testing.T) { specs := map[string]struct { setup func() corestore.Iterator