diff --git a/db.go b/db.go index 5d3e26496..13e9fdb93 100644 --- a/db.go +++ b/db.go @@ -1251,7 +1251,7 @@ func (db *DB) freepages() []common.Pgid { } reachable := make(map[common.Pgid]*common.Page) - nofreed := make(map[common.Pgid]bool) + nofreed := make(common.PgidSet) ech := make(chan error) go func() { for e := range ech { diff --git a/internal/common/page.go b/internal/common/page.go index 4453160bb..32dfb5bf3 100644 --- a/internal/common/page.go +++ b/internal/common/page.go @@ -389,3 +389,23 @@ func Mergepgids(dst, a, b Pgids) { // Append what's left in follow. _ = append(merged, follow...) } + +type PgidSet map[Pgid]struct{} + +func (s *PgidSet) Add(key Pgid) { + if *s == nil { + *s = make(map[Pgid]struct{}) + } + + (*s)[key] = struct{}{} +} + +func (s *PgidSet) Has(key Pgid) bool { + if *s == nil { + return false + } + + _, ok := (*s)[key] + + return ok +} diff --git a/internal/common/page_test.go b/internal/common/page_test.go index 376ab6a6c..fd9d9618b 100644 --- a/internal/common/page_test.go +++ b/internal/common/page_test.go @@ -70,3 +70,43 @@ func TestPgids_merge_quick(t *testing.T) { t.Fatal(err) } } + +func TestPgidSet_initialization(t *testing.T) { + var s PgidSet + if s != nil { + t.Fatal("Set must be nil") + } + + s.Add(0) + if s == nil { + t.Fatal("Set must be initialized") + } +} + +func TestPgidSet_set_has_added_values(t *testing.T) { + var s PgidSet + s.Add(100) + if len(s) != 1 || !s.Has(100) { + t.Fatal("Set must contain exactly one element") + } + + s.Add(200) + if len(s) != 2 || !s.Has(200) { + t.Fatal("Set must contain exactly two elements") + } +} + +func TestPgidSet_duplicates(t *testing.T) { + var s PgidSet + s.Add(5) + s.Add(5) + if len(s) != 1 { + t.Fatal("Set must still contain exactly one element after adding duplicate") + } + + s.Add(15) + s.Add(15) + if len(s) != 2 { + t.Fatal("Set must still contain exactly two elements after adding duplicate") + } +} diff --git a/internal/freelist/shared.go b/internal/freelist/shared.go index f30a69f10..333d5a43a 100644 --- a/internal/freelist/shared.go +++ b/internal/freelist/shared.go @@ -220,10 +220,10 @@ func (t *shared) Reload(p *common.Page) { func (t *shared) NoSyncReload(pgIds common.Pgids) { // Build a cache of only pending pages. - pcache := make(map[common.Pgid]struct{}) + pcache := make(common.PgidSet) for _, txp := range t.pending { for _, pendingID := range txp.ids { - pcache[pendingID] = struct{}{} + pcache.Add(pendingID) } } @@ -231,7 +231,7 @@ func (t *shared) NoSyncReload(pgIds common.Pgids) { // with any pages not in the pending lists. a := []common.Pgid{} for _, id := range pgIds { - if _, ok := pcache[id]; !ok { + if !pcache.Has(id) { a = append(a, id) } } diff --git a/tx_check.go b/tx_check.go index 59edf3573..6fc652712 100644 --- a/tx_check.go +++ b/tx_check.go @@ -40,14 +40,14 @@ func (tx *Tx) check(cfg checkConfig, ch chan error) { tx.db.loadFreelist() // Check if any pages are double freed. - freed := make(map[common.Pgid]bool) + freed := make(common.PgidSet) all := make([]common.Pgid, tx.db.freelist.Count()) tx.db.freelist.Copyall(all) for _, id := range all { - if freed[id] { + if freed.Has(id) { ch <- fmt.Errorf("page %d: already freed", id) } - freed[id] = true + freed.Add(id) } // Track every reachable page. @@ -68,7 +68,7 @@ func (tx *Tx) check(cfg checkConfig, ch chan error) { // Ensure all pages below high water mark are either reachable or freed. for i := common.Pgid(0); i < tx.meta.Pgid(); i++ { _, isReachable := reachable[i] - if !isReachable && !freed[i] { + if !isReachable && !freed.Has(i) { ch <- fmt.Errorf("page %d: unreachable unfreed", int(i)) } } @@ -83,13 +83,13 @@ func (tx *Tx) check(cfg checkConfig, ch chan error) { } } -func (tx *Tx) recursivelyCheckPage(pageId common.Pgid, reachable map[common.Pgid]*common.Page, freed map[common.Pgid]bool, +func (tx *Tx) recursivelyCheckPage(pageId common.Pgid, reachable map[common.Pgid]*common.Page, freed common.PgidSet, kvStringer KVStringer, ch chan error) { tx.checkInvariantProperties(pageId, reachable, freed, kvStringer, ch) tx.recursivelyCheckBucketInPage(pageId, reachable, freed, kvStringer, ch) } -func (tx *Tx) recursivelyCheckBucketInPage(pageId common.Pgid, reachable map[common.Pgid]*common.Page, freed map[common.Pgid]bool, +func (tx *Tx) recursivelyCheckBucketInPage(pageId common.Pgid, reachable map[common.Pgid]*common.Page, freed common.PgidSet, kvStringer KVStringer, ch chan error) { p := tx.page(pageId) @@ -120,7 +120,7 @@ func (tx *Tx) recursivelyCheckBucketInPage(pageId common.Pgid, reachable map[com } } -func (tx *Tx) recursivelyCheckBucket(b *Bucket, reachable map[common.Pgid]*common.Page, freed map[common.Pgid]bool, +func (tx *Tx) recursivelyCheckBucket(b *Bucket, reachable map[common.Pgid]*common.Page, freed common.PgidSet, kvStringer KVStringer, ch chan error) { // Ignore inline buckets. if b.RootPage() == 0 { @@ -138,7 +138,7 @@ func (tx *Tx) recursivelyCheckBucket(b *Bucket, reachable map[common.Pgid]*commo }) } -func (tx *Tx) checkInvariantProperties(pageId common.Pgid, reachable map[common.Pgid]*common.Page, freed map[common.Pgid]bool, +func (tx *Tx) checkInvariantProperties(pageId common.Pgid, reachable map[common.Pgid]*common.Page, freed common.PgidSet, kvStringer KVStringer, ch chan error) { tx.forEachPage(pageId, func(p *common.Page, _ int, stack []common.Pgid) { verifyPageReachable(p, tx.meta.Pgid(), stack, reachable, freed, ch) @@ -147,7 +147,7 @@ func (tx *Tx) checkInvariantProperties(pageId common.Pgid, reachable map[common. tx.recursivelyCheckPageKeyOrder(pageId, kvStringer.KeyToString, ch) } -func verifyPageReachable(p *common.Page, hwm common.Pgid, stack []common.Pgid, reachable map[common.Pgid]*common.Page, freed map[common.Pgid]bool, ch chan error) { +func verifyPageReachable(p *common.Page, hwm common.Pgid, stack []common.Pgid, reachable map[common.Pgid]*common.Page, freed common.PgidSet, ch chan error) { if p.Id() > hwm { ch <- fmt.Errorf("page %d: out of bounds: %d (stack: %v)", int(p.Id()), int(hwm), stack) } @@ -162,7 +162,7 @@ func verifyPageReachable(p *common.Page, hwm common.Pgid, stack []common.Pgid, r } // We should only encounter un-freed leaf and branch pages. - if freed[p.Id()] { + if freed.Has(p.Id()) { ch <- fmt.Errorf("page %d: reachable freed", int(p.Id())) } else if !p.IsBranchPage() && !p.IsLeafPage() { ch <- fmt.Errorf("page %d: invalid type: %s (stack: %v)", int(p.Id()), p.Typ(), stack)