diff --git a/ledger/stateproofverificationtracker.go b/ledger/stateproofverificationtracker.go index 3e5da89ddc..ab071bbe21 100644 --- a/ledger/stateproofverificationtracker.go +++ b/ledger/stateproofverificationtracker.go @@ -66,6 +66,9 @@ type stateProofVerificationTracker struct { // log copied from ledger log logging.Logger + + // lastLookedUpVerificationData should store the last verification data that was looked up. + lastLookedUpVerificationData ledgercore.StateProofVerificationData } func (spt *stateProofVerificationTracker) loadFromDisk(l ledgerForTracker, _ basics.Round) error { @@ -81,6 +84,8 @@ func (spt *stateProofVerificationTracker) loadFromDisk(l ledgerForTracker, _ bas spt.stateProofVerificationMu.Lock() defer spt.stateProofVerificationMu.Unlock() + spt.lastLookedUpVerificationData = ledgercore.StateProofVerificationData{} + const initialDataArraySize = 10 spt.trackedCommitData = make([]verificationCommitData, 0, initialDataArraySize) spt.trackedDeleteData = make([]verificationDeleteData, 0, initialDataArraySize) @@ -139,7 +144,6 @@ func (spt *stateProofVerificationTracker) commitRound(ctx context.Context, tx *s } return err - } func (spt *stateProofVerificationTracker) postCommit(_ context.Context, dcc *deferredCommitContext) { @@ -165,10 +169,43 @@ func (spt *stateProofVerificationTracker) close() { } func (spt *stateProofVerificationTracker) LookupVerificationData(stateProofLastAttestedRound basics.Round) (*ledgercore.StateProofVerificationData, error) { + if lstlookup := spt.retrieveFromCache(stateProofLastAttestedRound); lstlookup != nil { + return lstlookup, nil + } + + verificationData, err := spt.lookUpVerificationData(stateProofLastAttestedRound) + if err != nil { + return nil, err + } + + spt.stateProofVerificationMu.Lock() + spt.lastLookedUpVerificationData = *verificationData + spt.stateProofVerificationMu.Unlock() + + return verificationData, nil +} + +func (spt *stateProofVerificationTracker) retrieveFromCache( + stateProofLastAttestedRound basics.Round) *ledgercore.StateProofVerificationData { + spt.stateProofVerificationMu.RLock() + defer spt.stateProofVerificationMu.RUnlock() + + if spt.lastLookedUpVerificationData.TargetStateProofRound == stateProofLastAttestedRound && + !spt.lastLookedUpVerificationData.MsgIsZero() { + cpy := spt.lastLookedUpVerificationData + + return &cpy + } + + return nil +} + +func (spt *stateProofVerificationTracker) lookUpVerificationData(stateProofLastAttestedRound basics.Round) (*ledgercore.StateProofVerificationData, error) { spt.stateProofVerificationMu.RLock() defer spt.stateProofVerificationMu.RUnlock() - if len(spt.trackedCommitData) > 0 && stateProofLastAttestedRound >= spt.trackedCommitData[0].verificationData.TargetStateProofRound && + if len(spt.trackedCommitData) > 0 && + stateProofLastAttestedRound >= spt.trackedCommitData[0].verificationData.TargetStateProofRound && stateProofLastAttestedRound <= spt.trackedCommitData[len(spt.trackedCommitData)-1].verificationData.TargetStateProofRound { return spt.lookupDataInTrackedMemory(stateProofLastAttestedRound) } @@ -208,11 +245,11 @@ func (spt *stateProofVerificationTracker) committedRoundToLatestCommitDataIndex( latestCommittedDataIndex := -1 for index, data := range spt.trackedCommitData { - if data.confirmedRound <= committedRound { - latestCommittedDataIndex = index - } else { + if data.confirmedRound > committedRound { break } + + latestCommittedDataIndex = index } return latestCommittedDataIndex @@ -222,16 +259,25 @@ func (spt *stateProofVerificationTracker) committedRoundToLatestDeleteDataIndex( latestCommittedDataIndex := -1 for index, data := range spt.trackedDeleteData { - if data.confirmedRound <= committedRound { - latestCommittedDataIndex = index - } else { + if data.confirmedRound > committedRound { break } + + latestCommittedDataIndex = index } return latestCommittedDataIndex } +func getVerificationData(blk *bookkeeping.Block) ledgercore.StateProofVerificationData { + return ledgercore.StateProofVerificationData{ + VotersCommitment: blk.StateProofTracking[protocol.StateProofBasic].StateProofVotersCommitment, + OnlineTotalWeight: blk.StateProofTracking[protocol.StateProofBasic].StateProofOnlineTotalWeight, + TargetStateProofRound: blk.Round() + basics.Round(blk.ConsensusProtocol().StateProofInterval), + Version: blk.CurrentProtocol, + } +} + func (spt *stateProofVerificationTracker) insertCommitData(blk *bookkeeping.Block) { spt.stateProofVerificationMu.Lock() defer spt.stateProofVerificationMu.Unlock() @@ -244,16 +290,9 @@ func (spt *stateProofVerificationTracker) insertCommitData(blk *bookkeeping.Bloc } } - verificationData := ledgercore.StateProofVerificationData{ - VotersCommitment: blk.StateProofTracking[protocol.StateProofBasic].StateProofVotersCommitment, - OnlineTotalWeight: blk.StateProofTracking[protocol.StateProofBasic].StateProofOnlineTotalWeight, - TargetStateProofRound: blk.Round() + basics.Round(blk.ConsensusProtocol().StateProofInterval), - Version: blk.CurrentProtocol, - } - commitData := verificationCommitData{ confirmedRound: blk.Round(), - verificationData: verificationData, + verificationData: getVerificationData(blk), } spt.trackedCommitData = append(spt.trackedCommitData, commitData) diff --git a/ledger/stateproofverificationtracker_test.go b/ledger/stateproofverificationtracker_test.go index d6bb6362ed..995899279c 100644 --- a/ledger/stateproofverificationtracker_test.go +++ b/ledger/stateproofverificationtracker_test.go @@ -32,6 +32,7 @@ import ( const defaultStateProofInterval = uint64(256) const defaultFirstStateProofDataRound = basics.Round(defaultStateProofInterval * 2) +const defaultFirstStateProofDataInterval = basics.Round(2) const unusedByStateProofTracker = basics.Round(0) type StateProofTrackingLocation uint64 @@ -230,6 +231,7 @@ func TestStateProofVerificationTracker_CommitFUllDbFlush(t *testing.T) { mockCommit(t, spt, ml, 0, lastBlock.block.Round()) + spt.lastLookedUpVerificationData = ledgercore.StateProofVerificationData{} verifyStateProofVerificationTracking(t, spt, defaultFirstStateProofDataRound, expectedDataNum, defaultStateProofInterval, false, trackerMemory) verifyStateProofVerificationTracking(t, spt, defaultFirstStateProofDataRound, expectedDataNum, defaultStateProofInterval, true, trackerDB) } @@ -443,6 +445,7 @@ func TestStateProofVerificationTracker_LookupVerificationData(t *testing.T) { // This error shouldn't happen in normal flow - we force it to happen for the test. spt.trackedCommitData[0].verificationData.TargetStateProofRound = 0 + spt.lastLookedUpVerificationData = ledgercore.StateProofVerificationData{} _, err = spt.LookupVerificationData(memoryDataRound) a.ErrorIs(err, errStateProofVerificationDataNotFound) a.ErrorContains(err, "memory lookup failed") @@ -463,3 +466,28 @@ func TestStateProofVerificationTracker_PanicInvalidBlockInsertion(t *testing.T) pastBlock := randomBlock(0) a.Panics(func() { spt.insertCommitData(&pastBlock.block) }) } + +func TestStateProofVerificationTracker_lastLookupDataUpdatedAfterLookup(t *testing.T) { + partitiontest.PartitionTest(t) + a := require.New(t) + + mockLedger, spt := initializeLedgerSpt(t) + defer mockLedger.Close() + defer spt.close() + + a.Empty(spt.lastLookedUpVerificationData) + + NumberOfVerificationDataToAdd := uint64(10) + _ = feedBlocksUpToRound(spt, genesisBlock(), basics.Round(NumberOfVerificationDataToAdd*defaultStateProofInterval), + defaultStateProofInterval, true) + + a.Empty(spt.lastLookedUpVerificationData) + + expectedDataInDbNum := NumberOfVerificationDataToAdd + for i := uint64(defaultFirstStateProofDataInterval); i < expectedDataInDbNum; i++ { + vf, err := spt.LookupVerificationData(basics.Round(defaultStateProofInterval * i)) + a.NoError(err) + + a.Equal(*vf, spt.lastLookedUpVerificationData) + } +}