diff --git a/ledger/accountdb.go b/ledger/accountdb.go index 840ab5e52f..49db9d3377 100644 --- a/ledger/accountdb.go +++ b/ledger/accountdb.go @@ -24,7 +24,6 @@ import ( "errors" "fmt" "math" - "strings" "time" "github.com/mattn/go-sqlite3" @@ -36,8 +35,6 @@ import ( "github.com/algorand/go-algorand/crypto/merklesignature" "github.com/algorand/go-algorand/crypto/merkletrie" "github.com/algorand/go-algorand/data/basics" - "github.com/algorand/go-algorand/data/bookkeeping" - "github.com/algorand/go-algorand/data/transactions" "github.com/algorand/go-algorand/ledger/ledgercore" "github.com/algorand/go-algorand/ledger/store" "github.com/algorand/go-algorand/logging" @@ -1190,7 +1187,8 @@ func accountsInit(tx *sql.Tx, initAccounts map[basics.Address]basics.AccountData return true, fmt.Errorf("overflow computing totals") } - err = accountsPutTotals(tx, totals, false) + arw := store.NewAccountsSQLReaderWriter(tx) + err = arw.AccountsPutTotals(totals, false) if err != nil { return true, err } @@ -1463,7 +1461,8 @@ func performResourceTableMigration(ctx context.Context, tx *sql.Tx, log func(pro var processedAccounts uint64 var totalBaseAccounts uint64 - totalBaseAccounts, err = totalAccounts(ctx, tx) + arw := store.NewAccountsSQLReaderWriter(tx) + totalBaseAccounts, err = arw.TotalAccounts(ctx) if err != nil { return err } @@ -1542,7 +1541,8 @@ func performTxTailTableMigration(ctx context.Context, tx *sql.Tx, blockDb db.Acc return nil } - dbRound, err := accountsRound(tx) + arw := store.NewAccountsSQLReaderWriter(tx) + dbRound, err := arw.AccountsRound() if err != nil { return fmt.Errorf("latest block number cannot be retrieved : %w", err) } @@ -1585,27 +1585,28 @@ func performTxTailTableMigration(ctx context.Context, tx *sql.Tx, blockDb db.Acc return fmt.Errorf("block for round %d ( %d - %d ) cannot be retrieved : %w", rnd, firstRound, dbRound, err) } - tail, err := txTailRoundFromBlock(blk) + tail, err := store.TxTailRoundFromBlock(blk) if err != nil { return err } - encodedTail, _ := tail.encode() + encodedTail, _ := tail.Encode() tailRounds = append(tailRounds, encodedTail) } - return txtailNewRound(ctx, tx, firstRound, tailRounds, firstRound) + return arw.TxtailNewRound(ctx, firstRound, tailRounds, firstRound) }) return err } func performOnlineRoundParamsTailMigration(ctx context.Context, tx *sql.Tx, blockDb db.Accessor, newDatabase bool, initProto protocol.ConsensusVersion) (err error) { - totals, err := accountsTotals(ctx, tx, false) + arw := store.NewAccountsSQLReaderWriter(tx) + totals, err := arw.AccountsTotals(ctx, false) if err != nil { return err } - rnd, err := accountsRound(tx) + rnd, err := arw.AccountsRound() if err != nil { return err } @@ -1664,7 +1665,8 @@ func performOnlineAccountsTableMigration(ctx context.Context, tx *sql.Tx, progre var processedAccounts uint64 var totalOnlineBaseAccounts uint64 - totalOnlineBaseAccounts, err = totalAccounts(ctx, tx) + arw := store.NewAccountsSQLReaderWriter(tx) + totalOnlineBaseAccounts, err = arw.TotalAccounts(ctx) var total uint64 err = tx.QueryRowContext(ctx, "SELECT count(1) FROM accountbase").Scan(&total) if err != nil { @@ -1888,159 +1890,6 @@ func accountsReset(ctx context.Context, tx *sql.Tx) error { return err } -// accountsRound returns the tracker balances round number -func accountsRound(q db.Queryable) (rnd basics.Round, err error) { - err = q.QueryRow("SELECT rnd FROM acctrounds WHERE id='acctbase'").Scan(&rnd) - if err != nil { - return - } - return -} - -// accountsHashRound returns the round of the hash tree -// if the hash of the tree doesn't exists, it returns zero. -func accountsHashRound(ctx context.Context, tx *sql.Tx) (hashrnd basics.Round, err error) { - err = tx.QueryRowContext(ctx, "SELECT rnd FROM acctrounds WHERE id='hashbase'").Scan(&hashrnd) - if err == sql.ErrNoRows { - hashrnd = basics.Round(0) - err = nil - } - return -} - -// accountsOnlineTop returns the top n online accounts starting at position offset -// (that is, the top offset'th account through the top offset+n-1'th account). -// -// The accounts are sorted by their normalized balance and address. The normalized -// balance has to do with the reward parts of online account balances. See the -// normalization procedure in AccountData.NormalizedOnlineBalance(). -// -// Note that this does not check if the accounts have a vote key valid for any -// particular round (past, present, or future). -func accountsOnlineTop(tx *sql.Tx, rnd basics.Round, offset uint64, n uint64, proto config.ConsensusParams) (map[basics.Address]*ledgercore.OnlineAccount, error) { - // onlineaccounts has historical data ordered by updround for both online and offline accounts. - // This means some account A might have norm balance != 0 at round N and norm balance == 0 at some round K > N. - // For online top query one needs to find entries not fresher than X with norm balance != 0. - // To do that the query groups row by address and takes the latest updround, and then filters out rows with zero nor balance. - rows, err := tx.Query(`SELECT address, normalizedonlinebalance, data, max(updround) FROM onlineaccounts -WHERE updround <= ? -GROUP BY address HAVING normalizedonlinebalance > 0 -ORDER BY normalizedonlinebalance DESC, address DESC LIMIT ? OFFSET ?`, rnd, n, offset) - - if err != nil { - return nil, err - } - defer rows.Close() - - res := make(map[basics.Address]*ledgercore.OnlineAccount, n) - for rows.Next() { - var addrbuf []byte - var buf []byte - var normBal sql.NullInt64 - var updround sql.NullInt64 - err = rows.Scan(&addrbuf, &normBal, &buf, &updround) - if err != nil { - return nil, err - } - - var data store.BaseOnlineAccountData - err = protocol.Decode(buf, &data) - if err != nil { - return nil, err - } - - var addr basics.Address - if len(addrbuf) != len(addr) { - err = fmt.Errorf("account DB address length mismatch: %d != %d", len(addrbuf), len(addr)) - return nil, err - } - - if !normBal.Valid { - return nil, fmt.Errorf("non valid norm balance for online account %s", addr.String()) - } - - copy(addr[:], addrbuf) - // TODO: figure out protocol to use for rewards - // The original implementation uses current proto to recalculate norm balance - // In the same time, in accountsNewRound genesis protocol is used to fill norm balance value - // In order to be consistent with the original implementation recalculate the balance with current proto - normBalance := basics.NormalizedOnlineAccountBalance(basics.Online, data.RewardsBase, data.MicroAlgos, proto) - oa := data.GetOnlineAccount(addr, normBalance) - res[addr] = &oa - } - - return res, rows.Err() -} - -func onlineAccountsAll(tx *sql.Tx, maxAccounts uint64) ([]store.PersistedOnlineAccountData, error) { - rows, err := tx.Query("SELECT rowid, address, updround, data FROM onlineaccounts ORDER BY address, updround ASC") - if err != nil { - return nil, err - } - defer rows.Close() - - result := make([]store.PersistedOnlineAccountData, 0, maxAccounts) - var numAccounts uint64 - seenAddr := make([]byte, len(basics.Address{})) - for rows.Next() { - var addrbuf []byte - var buf []byte - data := store.PersistedOnlineAccountData{} - err := rows.Scan(&data.Rowid, &addrbuf, &data.UpdRound, &buf) - if err != nil { - return nil, err - } - if len(addrbuf) != len(data.Addr) { - err = fmt.Errorf("account DB address length mismatch: %d != %d", len(addrbuf), len(data.Addr)) - return nil, err - } - if maxAccounts > 0 { - if !bytes.Equal(seenAddr, addrbuf) { - numAccounts++ - if numAccounts > maxAccounts { - break - } - copy(seenAddr, addrbuf) - } - } - copy(data.Addr[:], addrbuf) - err = protocol.Decode(buf, &data.AccountData) - if err != nil { - return nil, err - } - result = append(result, data) - } - return result, nil -} - -func accountsTotals(ctx context.Context, q db.Queryable, catchpointStaging bool) (totals ledgercore.AccountTotals, err error) { - id := "" - if catchpointStaging { - id = "catchpointStaging" - } - row := q.QueryRowContext(ctx, "SELECT online, onlinerewardunits, offline, offlinerewardunits, notparticipating, notparticipatingrewardunits, rewardslevel FROM accounttotals WHERE id=?", id) - err = row.Scan(&totals.Online.Money.Raw, &totals.Online.RewardUnits, - &totals.Offline.Money.Raw, &totals.Offline.RewardUnits, - &totals.NotParticipating.Money.Raw, &totals.NotParticipating.RewardUnits, - &totals.RewardsLevel) - - return -} - -func accountsPutTotals(tx *sql.Tx, totals ledgercore.AccountTotals, catchpointStaging bool) error { - id := "" - if catchpointStaging { - id = "catchpointStaging" - } - _, err := tx.Exec("REPLACE INTO accounttotals (id, online, onlinerewardunits, offline, offlinerewardunits, notparticipating, notparticipatingrewardunits, rewardslevel) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - id, - totals.Online.Money.Raw, totals.Online.RewardUnits, - totals.Offline.Money.Raw, totals.Offline.RewardUnits, - totals.NotParticipating.Money.Raw, totals.NotParticipating.RewardUnits, - totals.RewardsLevel) - return err -} - func accountsOnlineRoundParams(tx *sql.Tx) (onlineRoundParamsData []ledgercore.OnlineRoundParamsData, endRound basics.Round, err error) { rows, err := tx.Query("SELECT rnd, data FROM onlineroundparamstail ORDER BY rnd ASC") if err != nil { @@ -2458,144 +2307,6 @@ func onlineAccountsNewRoundImpl( return } -func rowidsToChunkedArgs(rowids []int64) [][]interface{} { - const sqliteMaxVariableNumber = 999 - - numChunks := len(rowids)/sqliteMaxVariableNumber + 1 - if len(rowids)%sqliteMaxVariableNumber == 0 { - numChunks-- - } - chunks := make([][]interface{}, numChunks) - if numChunks == 1 { - // optimize memory consumption for the most common case - chunks[0] = make([]interface{}, len(rowids)) - for i, rowid := range rowids { - chunks[0][i] = interface{}(rowid) - } - } else { - for i := 0; i < numChunks; i++ { - chunkSize := sqliteMaxVariableNumber - if i == numChunks-1 { - chunkSize = len(rowids) - (numChunks-1)*sqliteMaxVariableNumber - } - chunks[i] = make([]interface{}, chunkSize) - } - for i, rowid := range rowids { - chunkIndex := i / sqliteMaxVariableNumber - chunks[chunkIndex][i%sqliteMaxVariableNumber] = interface{}(rowid) - } - } - return chunks -} - -func onlineAccountsDeleteByRowIDs(tx *sql.Tx, rowids []int64) (err error) { - if len(rowids) == 0 { - return - } - - // sqlite3 < 3.32.0 allows SQLITE_MAX_VARIABLE_NUMBER = 999 bindings - // see https://www.sqlite.org/limits.html - // rowids might be larger => split to chunks are remove - chunks := rowidsToChunkedArgs(rowids) - for _, chunk := range chunks { - _, err = tx.Exec("DELETE FROM onlineaccounts WHERE rowid IN (?"+strings.Repeat(",?", len(chunk)-1)+")", chunk...) - if err != nil { - return - } - } - return -} - -// onlineAccountsDelete deleted entries with updRound <= expRound -func onlineAccountsDelete(tx *sql.Tx, forgetBefore basics.Round) (err error) { - rows, err := tx.Query("SELECT rowid, address, updRound, data FROM onlineaccounts WHERE updRound < ? ORDER BY address, updRound DESC", forgetBefore) - if err != nil { - return err - } - defer rows.Close() - - var rowids []int64 - var rowid sql.NullInt64 - var updRound sql.NullInt64 - var buf []byte - var addrbuf []byte - - var prevAddr []byte - - for rows.Next() { - err = rows.Scan(&rowid, &addrbuf, &updRound, &buf) - if err != nil { - return err - } - if !rowid.Valid || !updRound.Valid { - return fmt.Errorf("onlineAccountsDelete: invalid rowid or updRound") - } - if len(addrbuf) != len(basics.Address{}) { - err = fmt.Errorf("account DB address length mismatch: %d != %d", len(addrbuf), len(basics.Address{})) - return - } - - if !bytes.Equal(addrbuf, prevAddr) { - // new address - // if the first (latest) entry is - // - offline then delete all - // - online then safe to delete all previous except this first (latest) - - // reset the state - prevAddr = addrbuf - - var oad store.BaseOnlineAccountData - err = protocol.Decode(buf, &oad) - if err != nil { - return - } - if oad.IsVotingEmpty() { - // delete this and all subsequent - rowids = append(rowids, rowid.Int64) - } - - // restart the loop - // if there are some subsequent entries, they will deleted on the next iteration - // if no subsequent entries, the loop will reset the state and the latest entry does not get deleted - continue - } - // delete all subsequent entries - rowids = append(rowids, rowid.Int64) - } - - return onlineAccountsDeleteByRowIDs(tx, rowids) -} - -// updates the round number associated with the current account data. -func updateAccountsRound(tx *sql.Tx, rnd basics.Round) (err error) { - res, err := tx.Exec("UPDATE acctrounds SET rnd=? WHERE id='acctbase' AND rnd rnd { - err = fmt.Errorf("newRound %d is not after base %d", rnd, base) - return - } else if base != rnd { - err = fmt.Errorf("updateAccountsRound(acctbase, %d): expected to update 1 row but got %d", rnd, aff) - return - } - } - return -} - // updates the round number associated with the hash of current account data. func updateAccountsHashRound(ctx context.Context, tx *sql.Tx, hashRound basics.Round) (err error) { res, err := tx.ExecContext(ctx, "INSERT OR REPLACE INTO acctrounds(id,rnd) VALUES('hashbase',?)", hashRound) @@ -2615,27 +2326,6 @@ func updateAccountsHashRound(ctx context.Context, tx *sql.Tx, hashRound basics.R return } -// totalAccounts returns the total number of accounts -func totalAccounts(ctx context.Context, tx *sql.Tx) (total uint64, err error) { - err = tx.QueryRowContext(ctx, "SELECT count(1) FROM accountbase").Scan(&total) - if err == sql.ErrNoRows { - total = 0 - err = nil - return - } - return -} - -func totalKVs(ctx context.Context, tx *sql.Tx) (total uint64, err error) { - err = tx.QueryRowContext(ctx, "SELECT count(1) FROM kvstore").Scan(&total) - if err == sql.ErrNoRows { - total = 0 - err = nil - return - } - return -} - // reencodeAccounts reads all the accounts in the accountbase table, decode and reencode the account data. // if the account data is found to have a different encoding, it would update the encoded account on disk. // on return, it returns the number of modified accounts as well as an error ( if we had any ) @@ -3525,117 +3215,6 @@ func (iterator *catchpointPendingHashesIterator) Close() { } } -// txTailRoundLease is used as part of txTailRound for storing -// a single lease. -type txTailRoundLease struct { - _struct struct{} `codec:",omitempty,omitemptyarray"` - - Sender basics.Address `codec:"s"` - Lease [32]byte `codec:"l,allocbound=-"` - TxnIdx uint64 `code:"i"` //!-- index of the entry in TxnIDs/LastValid -} - -// TxTailRound contains the information about a single round of transactions. -// The TxnIDs and LastValid would both be of the same length, and are stored -// in that way for efficient message=pack encoding. The Leases would point to the -// respective transaction index. Note that this isn’t optimized for storing -// leases, as leases are extremely rare. -type txTailRound struct { - _struct struct{} `codec:",omitempty,omitemptyarray"` - - TxnIDs []transactions.Txid `codec:"i,allocbound=-"` - LastValid []basics.Round `codec:"v,allocbound=-"` - Leases []txTailRoundLease `codec:"l,allocbound=-"` - Hdr bookkeeping.BlockHeader `codec:"h,allocbound=-"` -} - -// encode the transaction tail data into a serialized form, and return the serialized data -// as well as the hash of the data. -func (t *txTailRound) encode() ([]byte, crypto.Digest) { - tailData := protocol.Encode(t) - hash := crypto.Hash(tailData) - return tailData, hash -} - -func txTailRoundFromBlock(blk bookkeeping.Block) (*txTailRound, error) { - payset, err := blk.DecodePaysetFlat() - if err != nil { - return nil, err - } - - tail := &txTailRound{} - - tail.TxnIDs = make([]transactions.Txid, len(payset)) - tail.LastValid = make([]basics.Round, len(payset)) - tail.Hdr = blk.BlockHeader - - for txIdxtxid, txn := range payset { - tail.TxnIDs[txIdxtxid] = txn.ID() - tail.LastValid[txIdxtxid] = txn.Txn.LastValid - if txn.Txn.Lease != [32]byte{} { - tail.Leases = append(tail.Leases, txTailRoundLease{ - Sender: txn.Txn.Sender, - Lease: txn.Txn.Lease, - TxnIdx: uint64(txIdxtxid), - }) - } - } - return tail, nil -} - -func txtailNewRound(ctx context.Context, tx *sql.Tx, baseRound basics.Round, roundData [][]byte, forgetBeforeRound basics.Round) error { - insertStmt, err := tx.PrepareContext(ctx, "INSERT INTO txtail(rnd, data) VALUES(?, ?)") - if err != nil { - return err - } - defer insertStmt.Close() - - for i, data := range roundData { - _, err = insertStmt.ExecContext(ctx, int(baseRound)+i, data[:]) - if err != nil { - return err - } - } - - _, err = tx.ExecContext(ctx, "DELETE FROM txtail WHERE rnd < ?", forgetBeforeRound) - return err -} - -func loadTxTail(ctx context.Context, tx *sql.Tx, dbRound basics.Round) (roundData []*txTailRound, roundHash []crypto.Digest, baseRound basics.Round, err error) { - rows, err := tx.QueryContext(ctx, "SELECT rnd, data FROM txtail ORDER BY rnd DESC") - if err != nil { - return nil, nil, 0, err - } - defer rows.Close() - - expectedRound := dbRound - for rows.Next() { - var round basics.Round - var data []byte - err = rows.Scan(&round, &data) - if err != nil { - return nil, nil, 0, err - } - if round != expectedRound { - return nil, nil, 0, fmt.Errorf("txtail table contain unexpected round %d; round %d was expected", round, expectedRound) - } - tail := &txTailRound{} - err = protocol.Decode(data, tail) - if err != nil { - return nil, nil, 0, err - } - roundData = append(roundData, tail) - roundHash = append(roundHash, crypto.Hash(data)) - expectedRound-- - } - // reverse the array ordering in-place so that it would be incremental order. - for i := 0; i < len(roundData)/2; i++ { - roundData[i], roundData[len(roundData)-i-1] = roundData[len(roundData)-i-1], roundData[i] - roundHash[i], roundHash[len(roundHash)-i-1] = roundHash[len(roundHash)-i-1], roundHash[i] - } - return roundData, roundHash, expectedRound + 1, nil -} - // For the `catchpointfirststageinfo` table. type catchpointFirstStageInfo struct { _struct struct{} `codec:",omitempty,omitemptyarray"` diff --git a/ledger/accountdb_test.go b/ledger/accountdb_test.go index 6b2098e3a9..eca434a6c4 100644 --- a/ledger/accountdb_test.go +++ b/ledger/accountdb_test.go @@ -93,7 +93,9 @@ func accountsInitTest(tb testing.TB, tx *sql.Tx, initAccounts map[basics.Address } func checkAccounts(t *testing.T, tx *sql.Tx, rnd basics.Round, accts map[basics.Address]basics.AccountData) { - r, err := accountsRound(tx) + arw := store.NewAccountsSQLReaderWriter(tx) + + r, err := arw.AccountsRound() require.NoError(t, err) require.Equal(t, r, rnd) @@ -126,7 +128,7 @@ func checkAccounts(t *testing.T, tx *sql.Tx, rnd basics.Round, accts map[basics. require.NoError(t, err) require.Equal(t, all, accts) - totals, err := accountsTotals(context.Background(), tx, false) + totals, err := arw.AccountsTotals(context.Background(), false) require.NoError(t, err) require.Equal(t, totalOnline, totals.Online.Money.Raw, "mismatching total online money") require.Equal(t, totalOffline, totals.Offline.Money.Raw) @@ -168,7 +170,7 @@ func checkAccounts(t *testing.T, tx *sql.Tx, rnd basics.Round, accts map[basics. }) for i := 0; i < len(onlineAccounts); i++ { - dbtop, err := accountsOnlineTop(tx, rnd, 0, uint64(i), proto) + dbtop, err := arw.AccountsOnlineTop(rnd, 0, uint64(i), proto) require.NoError(t, err) require.Equal(t, i, len(dbtop)) @@ -178,7 +180,7 @@ func checkAccounts(t *testing.T, tx *sql.Tx, rnd basics.Round, accts map[basics. } } - top, err := accountsOnlineTop(tx, rnd, 0, uint64(len(onlineAccounts)+1), proto) + top, err := arw.AccountsOnlineTop(rnd, 0, uint64(len(onlineAccounts)+1), proto) require.NoError(t, err) require.Equal(t, len(top), len(onlineAccounts)) } @@ -257,10 +259,12 @@ func TestAccountDBRound(t *testing.T) { require.NoError(t, err) defer tx.Rollback() + arw := store.NewAccountsSQLReaderWriter(tx) + accts := ledgertesting.RandomAccounts(20, true) accountsInitTest(t, tx, accts, protocol.ConsensusCurrentVersion) checkAccounts(t, tx, 0, accts) - totals, err := accountsTotals(context.Background(), tx, false) + totals, err := arw.AccountsTotals(context.Background(), false) require.NoError(t, err) expectedOnlineRoundParams, endRound, err := accountsOnlineRoundParams(tx) require.NoError(t, err) @@ -308,7 +312,7 @@ func TestAccountDBRound(t *testing.T) { err = resourceUpdatesCnt.resourcesLoadOld(tx, knownAddresses) require.NoError(t, err) - err = accountsPutTotals(tx, totals, false) + err = arw.AccountsPutTotals(totals, false) require.NoError(t, err) onlineRoundParams := ledgercore.OnlineRoundParamsData{RewardsLevel: totals.RewardsLevel, OnlineSupply: totals.Online.Money.Raw, CurrentProtocol: protocol.ConsensusCurrentVersion} err = accountsPutOnlineRoundParams(tx, []ledgercore.OnlineRoundParamsData{onlineRoundParams}, basics.Round(i)) @@ -328,7 +332,7 @@ func TestAccountDBRound(t *testing.T) { updatedOnlineAccts, err := onlineAccountsNewRound(tx, updatesOnlineCnt, proto, basics.Round(i)) require.NoError(t, err) - err = updateAccountsRound(tx, basics.Round(i)) + err = arw.UpdateAccountsRound(basics.Round(i)) require.NoError(t, err) // TODO: calculate exact number of updates? @@ -346,7 +350,7 @@ func TestAccountDBRound(t *testing.T) { } expectedTotals := ledgertesting.CalculateNewRoundAccountTotals(t, updates, 0, proto, nil, ledgercore.AccountTotals{}) - actualTotals, err := accountsTotals(context.Background(), tx, false) + actualTotals, err := arw.AccountsTotals(context.Background(), false) require.NoError(t, err) require.Equal(t, expectedTotals, actualTotals) @@ -2348,9 +2352,11 @@ func TestAccountOnlineQueries(t *testing.T) { require.NoError(t, err) defer tx.Rollback() + arw := store.NewAccountsSQLReaderWriter(tx) + var accts map[basics.Address]basics.AccountData accountsInitTest(t, tx, accts, protocol.ConsensusCurrentVersion) - totals, err := accountsTotals(context.Background(), tx, false) + totals, err := arw.AccountsTotals(context.Background(), false) require.NoError(t, err) var baseAccounts lruAccounts @@ -2434,7 +2440,7 @@ func TestAccountOnlineQueries(t *testing.T) { err = updatesOnlineCnt.accountsLoadOld(tx) require.NoError(t, err) - err = accountsPutTotals(tx, totals, false) + err = arw.AccountsPutTotals(totals, false) require.NoError(t, err) updatedAccts, _, _, err := accountsNewRound(tx, updatesCnt, compactResourcesDeltas{}, nil, nil, proto, rnd) require.NoError(t, err) @@ -2444,7 +2450,7 @@ func TestAccountOnlineQueries(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, updatedOnlineAccts) - err = updateAccountsRound(tx, rnd) + err = arw.UpdateAccountsRound(rnd) require.NoError(t, err) } @@ -2457,7 +2463,7 @@ func TestAccountOnlineQueries(t *testing.T) { // check round 1 rnd := basics.Round(1) - online, err := accountsOnlineTop(tx, rnd, 0, 10, proto) + online, err := arw.AccountsOnlineTop(rnd, 0, 10, proto) require.NoError(t, err) require.Equal(t, 2, len(online)) require.NotContains(t, online, addrC) @@ -2496,7 +2502,7 @@ func TestAccountOnlineQueries(t *testing.T) { // check round 2 rnd = basics.Round(2) - online, err = accountsOnlineTop(tx, rnd, 0, 10, proto) + online, err = arw.AccountsOnlineTop(rnd, 0, 10, proto) require.NoError(t, err) require.Equal(t, 1, len(online)) require.NotContains(t, online, addrA) @@ -2529,7 +2535,7 @@ func TestAccountOnlineQueries(t *testing.T) { // check round 3 rnd = basics.Round(3) - online, err = accountsOnlineTop(tx, rnd, 0, 10, proto) + online, err = arw.AccountsOnlineTop(rnd, 0, 10, proto) require.NoError(t, err) require.Equal(t, 1, len(online)) require.NotContains(t, online, addrA) @@ -2560,7 +2566,7 @@ func TestAccountOnlineQueries(t *testing.T) { require.Equal(t, dataC3.AccountBaseData.MicroAlgos, paod.AccountData.MicroAlgos) require.Equal(t, voteIDC, paod.AccountData.VoteID) - paods, err := onlineAccountsAll(tx, 0) + paods, err := arw.OnlineAccountsAll(0) require.NoError(t, err) require.Equal(t, 5, len(paods)) @@ -2602,20 +2608,20 @@ func TestAccountOnlineQueries(t *testing.T) { checkAddrC() checkAddrA() - paods, err = onlineAccountsAll(tx, 3) + paods, err = arw.OnlineAccountsAll(3) require.NoError(t, err) require.Equal(t, 5, len(paods)) checkAddrB() checkAddrC() checkAddrA() - paods, err = onlineAccountsAll(tx, 2) + paods, err = arw.OnlineAccountsAll(2) require.NoError(t, err) require.Equal(t, 3, len(paods)) checkAddrB() checkAddrC() - paods, err = onlineAccountsAll(tx, 1) + paods, err = arw.OnlineAccountsAll(1) require.NoError(t, err) require.Equal(t, 2, len(paods)) checkAddrB() @@ -2880,70 +2886,6 @@ func TestAccountOnlineRoundParams(t *testing.T) { require.Equal(t, maxRounds, int(endRound)) } -func TestRowidsToChunkedArgs(t *testing.T) { - partitiontest.PartitionTest(t) - - res := rowidsToChunkedArgs([]int64{1}) - require.Equal(t, 1, cap(res)) - require.Equal(t, 1, len(res)) - require.Equal(t, 1, cap(res[0])) - require.Equal(t, 1, len(res[0])) - require.Equal(t, []interface{}{int64(1)}, res[0]) - - input := make([]int64, 999) - for i := 0; i < len(input); i++ { - input[i] = int64(i) - } - res = rowidsToChunkedArgs(input) - require.Equal(t, 1, cap(res)) - require.Equal(t, 1, len(res)) - require.Equal(t, 999, cap(res[0])) - require.Equal(t, 999, len(res[0])) - for i := 0; i < len(input); i++ { - require.Equal(t, interface{}(int64(i)), res[0][i]) - } - - input = make([]int64, 1001) - for i := 0; i < len(input); i++ { - input[i] = int64(i) - } - res = rowidsToChunkedArgs(input) - require.Equal(t, 2, cap(res)) - require.Equal(t, 2, len(res)) - require.Equal(t, 999, cap(res[0])) - require.Equal(t, 999, len(res[0])) - require.Equal(t, 2, cap(res[1])) - require.Equal(t, 2, len(res[1])) - for i := 0; i < 999; i++ { - require.Equal(t, interface{}(int64(i)), res[0][i]) - } - j := 0 - for i := 999; i < len(input); i++ { - require.Equal(t, interface{}(int64(i)), res[1][j]) - j++ - } - - input = make([]int64, 2*999) - for i := 0; i < len(input); i++ { - input[i] = int64(i) - } - res = rowidsToChunkedArgs(input) - require.Equal(t, 2, cap(res)) - require.Equal(t, 2, len(res)) - require.Equal(t, 999, cap(res[0])) - require.Equal(t, 999, len(res[0])) - require.Equal(t, 999, cap(res[1])) - require.Equal(t, 999, len(res[1])) - for i := 0; i < 999; i++ { - require.Equal(t, interface{}(int64(i)), res[0][i]) - } - j = 0 - for i := 999; i < len(input); i++ { - require.Equal(t, interface{}(int64(i)), res[1][j]) - j++ - } -} - // TestAccountDBTxTailLoad checks txtailNewRound and loadTxTail delete and load right data func TestAccountDBTxTailLoad(t *testing.T) { partitiontest.PartitionTest(t) @@ -2957,6 +2899,8 @@ func TestAccountDBTxTailLoad(t *testing.T) { require.NoError(t, err) defer tx.Rollback() + arw := store.NewAccountsSQLReaderWriter(tx) + err = accountsCreateTxTailTable(context.Background(), tx) require.NoError(t, err) @@ -2966,14 +2910,14 @@ func TestAccountDBTxTailLoad(t *testing.T) { roundData := make([][]byte, 1500) const retainSize = 1001 for i := startRound; i <= endRound; i++ { - data := txTailRound{Hdr: bookkeeping.BlockHeader{TimeStamp: int64(i)}} + data := store.TxTailRound{Hdr: bookkeeping.BlockHeader{TimeStamp: int64(i)}} roundData[i-1] = protocol.Encode(&data) } forgetBefore := (endRound + 1).SubSaturate(retainSize) - err = txtailNewRound(context.Background(), tx, startRound, roundData, forgetBefore) + err = arw.TxtailNewRound(context.Background(), startRound, roundData, forgetBefore) require.NoError(t, err) - data, _, baseRound, err := loadTxTail(context.Background(), tx, endRound) + data, _, baseRound, err := arw.LoadTxTail(context.Background(), endRound) require.NoError(t, err) require.Len(t, data, retainSize) require.Equal(t, basics.Round(endRound-retainSize+1), baseRound) // 500...1500 @@ -3007,6 +2951,8 @@ func TestOnlineAccountsDeletion(t *testing.T) { var accts map[basics.Address]basics.AccountData accountsInitTest(t, tx, accts, protocol.ConsensusCurrentVersion) + arw := store.NewAccountsSQLReaderWriter(tx) + updates := compactOnlineAccountDeltas{} addrA := ledgertesting.RandomAddress() addrB := ledgertesting.RandomAddress() @@ -3066,7 +3012,7 @@ func TestOnlineAccountsDeletion(t *testing.T) { var history []store.PersistedOnlineAccountData var validThrough basics.Round for _, rnd := range []basics.Round{1, 2, 3} { - err = onlineAccountsDelete(tx, rnd) + err = arw.OnlineAccountsDelete(rnd) require.NoError(t, err) err = tx.QueryRow("SELECT COUNT(1) FROM onlineaccounts").Scan(&count) @@ -3084,7 +3030,7 @@ func TestOnlineAccountsDeletion(t *testing.T) { } for _, rnd := range []basics.Round{4, 5, 6, 7} { - err = onlineAccountsDelete(tx, rnd) + err = arw.OnlineAccountsDelete(rnd) require.NoError(t, err) err = tx.QueryRow("SELECT COUNT(1) FROM onlineaccounts").Scan(&count) @@ -3102,7 +3048,7 @@ func TestOnlineAccountsDeletion(t *testing.T) { } for _, rnd := range []basics.Round{8, 9} { - err = onlineAccountsDelete(tx, rnd) + err = arw.OnlineAccountsDelete(rnd) require.NoError(t, err) err = tx.QueryRow("SELECT COUNT(1) FROM onlineaccounts").Scan(&count) diff --git a/ledger/acctonline.go b/ledger/acctonline.go index c0cea58516..cd7e81edd7 100644 --- a/ledger/acctonline.go +++ b/ledger/acctonline.go @@ -152,6 +152,7 @@ func (ao *onlineAccounts) initializeFromDisk(l ledgerForTracker, lastBalancesRou ao.log = l.trackerLog() err = ao.dbs.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) error { + arw := store.NewAccountsSQLReaderWriter(tx) var err0 error var endRound basics.Round ao.onlineRoundParamsData, endRound, err0 = accountsOnlineRoundParams(tx) @@ -162,7 +163,7 @@ func (ao *onlineAccounts) initializeFromDisk(l ledgerForTracker, lastBalancesRou return fmt.Errorf("last onlineroundparams round %d does not match dbround %d", endRound, ao.cachedDBRoundOnline) } - onlineAccounts, err0 := onlineAccountsAll(tx, onlineAccountsCacheMaxSize) + onlineAccounts, err0 := arw.OnlineAccountsAll(onlineAccountsCacheMaxSize) if err0 != nil { return err0 } @@ -421,7 +422,9 @@ func (ao *onlineAccounts) commitRound(ctx context.Context, tx *sql.Tx, dcc *defe return err } - err = onlineAccountsDelete(tx, dcc.onlineAccountsForgetBefore) + arw := store.NewAccountsSQLReaderWriter(tx) + + err = arw.OnlineAccountsDelete(dcc.onlineAccountsForgetBefore) if err != nil { return err } @@ -817,11 +820,12 @@ func (ao *onlineAccounts) TopOnlineAccounts(rnd basics.Round, voteRnd basics.Rou start := time.Now() ledgerAccountsonlinetopCount.Inc(nil) err = ao.dbs.Rdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - accts, err = accountsOnlineTop(tx, rnd, batchOffset, batchSize, genesisProto) + arw := store.NewAccountsSQLReaderWriter(tx) + accts, err = arw.AccountsOnlineTop(rnd, batchOffset, batchSize, genesisProto) if err != nil { return } - dbRound, err = accountsRound(tx) + dbRound, err = arw.AccountsRound() return }) ledgerAccountsonlinetopMicros.AddMicrosecondsSince(start, nil) diff --git a/ledger/acctonline_test.go b/ledger/acctonline_test.go index c1e0c2fbde..b017ed2cf7 100644 --- a/ledger/acctonline_test.go +++ b/ledger/acctonline_test.go @@ -29,6 +29,7 @@ import ( "github.com/algorand/go-algorand/data/basics" "github.com/algorand/go-algorand/data/bookkeeping" "github.com/algorand/go-algorand/ledger/ledgercore" + "github.com/algorand/go-algorand/ledger/store" ledgertesting "github.com/algorand/go-algorand/ledger/testing" "github.com/algorand/go-algorand/protocol" "github.com/algorand/go-algorand/test/partitiontest" @@ -81,6 +82,7 @@ func commitSyncPartial(t *testing.T, oa *onlineAccounts, ml *mockLedgerForTracke require.NoError(t, err) } err := ml.trackers.dbs.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { + arw := store.NewAccountsSQLReaderWriter(tx) for _, lt := range ml.trackers.trackers { err0 := lt.commitRound(ctx, tx, dcc) if err0 != nil { @@ -88,7 +90,7 @@ func commitSyncPartial(t *testing.T, oa *onlineAccounts, ml *mockLedgerForTracke } } - return updateAccountsRound(tx, newBase) + return arw.UpdateAccountsRound(newBase) }) require.NoError(t, err) }() diff --git a/ledger/acctupdates.go b/ledger/acctupdates.go index 9e1d294ef9..1ce835274c 100644 --- a/ledger/acctupdates.go +++ b/ledger/acctupdates.go @@ -935,7 +935,8 @@ func (au *accountUpdates) initializeFromDisk(l ledgerForTracker, lastBalancesRou start := time.Now() ledgerAccountsinitCount.Inc(nil) err = au.dbs.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) error { - totals, err0 := accountsTotals(ctx, tx, false) + arw := store.NewAccountsSQLReaderWriter(tx) + totals, err0 := arw.AccountsTotals(ctx, false) if err0 != nil { return err0 } @@ -1688,7 +1689,9 @@ func (au *accountUpdates) commitRound(ctx context.Context, tx *sql.Tx, dcc *defe dcc.stats.OldAccountPreloadDuration = time.Duration(time.Now().UnixNano()) - dcc.stats.OldAccountPreloadDuration } - err = accountsPutTotals(tx, dcc.roundTotals, false) + arw := store.NewAccountsSQLReaderWriter(tx) + + err = arw.AccountsPutTotals(dcc.roundTotals, false) if err != nil { return err } diff --git a/ledger/acctupdates_test.go b/ledger/acctupdates_test.go index f6e954d7e0..a55676b494 100644 --- a/ledger/acctupdates_test.go +++ b/ledger/acctupdates_test.go @@ -563,7 +563,8 @@ func TestAcctUpdates(t *testing.T) { // check the account totals. var dbRound basics.Round err := ml.dbs.Rdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - dbRound, err = accountsRound(tx) + arw := store.NewAccountsSQLReaderWriter(tx) + dbRound, err = arw.AccountsRound() return }) require.NoError(t, err) @@ -576,7 +577,8 @@ func TestAcctUpdates(t *testing.T) { expectedTotals := ledgertesting.CalculateNewRoundAccountTotals(t, updates, rewardsLevels[dbRound], proto, nil, ledgercore.AccountTotals{}) var actualTotals ledgercore.AccountTotals err = ml.dbs.Rdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - actualTotals, err = accountsTotals(ctx, tx, false) + arw := store.NewAccountsSQLReaderWriter(tx) + actualTotals, err = arw.AccountsTotals(ctx, false) return }) require.NoError(t, err) @@ -2439,11 +2441,12 @@ func TestAcctUpdatesResources(t *testing.T) { err := au.prepareCommit(dcc) require.NoError(t, err) err = ml.trackers.dbs.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { + arw := store.NewAccountsSQLReaderWriter(tx) err = au.commitRound(ctx, tx, dcc) if err != nil { return err } - err = updateAccountsRound(tx, newBase) + err = arw.UpdateAccountsRound(newBase) return err }) require.NoError(t, err) @@ -2732,11 +2735,12 @@ func auCommitSync(t *testing.T, rnd basics.Round, au *accountUpdates, ml *mockLe err := au.prepareCommit(dcc) require.NoError(t, err) err = ml.trackers.dbs.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { + arw := store.NewAccountsSQLReaderWriter(tx) err = au.commitRound(ctx, tx, dcc) if err != nil { return err } - err = updateAccountsRound(tx, newBase) + err = arw.UpdateAccountsRound(newBase) return err }) require.NoError(t, err) diff --git a/ledger/catchpointtracker.go b/ledger/catchpointtracker.go index d07285677c..c8f1ddbf66 100644 --- a/ledger/catchpointtracker.go +++ b/ledger/catchpointtracker.go @@ -237,14 +237,14 @@ func (ct *catchpointTracker) finishFirstStage(ctx context.Context, dbRound basic } f := func(ctx context.Context, tx *sql.Tx) error { - cps := store.NewCatchpointSQLReaderWriter(tx) + crw := store.NewCatchpointSQLReaderWriter(tx) err := ct.recordFirstStageInfo(ctx, tx, dbRound, totalKVs, totalAccounts, totalChunks, biggestChunkLen) if err != nil { return err } // Clear the db record. - return cps.WriteCatchpointStateUint64(ctx, catchpointStateWritingFirstStageInfo, 0) + return crw.WriteCatchpointStateUint64(ctx, catchpointStateWritingFirstStageInfo, 0) } return ct.dbs.Wdb.Atomic(f) } @@ -517,7 +517,7 @@ func (ct *catchpointTracker) commitRound(ctx context.Context, tx *sql.Tx, dcc *d } }() - cps := store.NewCatchpointSQLReaderWriter(tx) + crw := store.NewCatchpointSQLReaderWriter(tx) if ct.catchpointEnabled() { var mc *MerkleCommitter @@ -560,19 +560,19 @@ func (ct *catchpointTracker) commitRound(ctx context.Context, tx *sql.Tx, dcc *d } if dcc.catchpointFirstStage { - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateWritingFirstStageInfo, 1) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateWritingFirstStageInfo, 1) if err != nil { return err } } - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchpointLookback, dcc.catchpointLookback) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchpointLookback, dcc.catchpointLookback) if err != nil { return err } for _, round := range ct.calculateCatchpointRounds(dcc) { - err = cps.InsertUnfinishedCatchpoint(ctx, round, dcc.committedRoundDigests[round-dcc.oldBase-1]) + err = crw.InsertUnfinishedCatchpoint(ctx, round, dcc.committedRoundDigests[round-dcc.oldBase-1]) if err != nil { return err } @@ -797,12 +797,12 @@ func (ct *catchpointTracker) createCatchpoint(ctx context.Context, accountsRound } err = ct.dbs.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - cps := store.NewCatchpointSQLReaderWriter(tx) + crw := store.NewCatchpointSQLReaderWriter(tx) err = ct.recordCatchpointFile(ctx, tx, round, relCatchpointFilePath, fileInfo.Size()) if err != nil { return err } - return cps.DeleteUnfinishedCatchpoint(ctx, round) + return crw.DeleteUnfinishedCatchpoint(ctx, round) }) if err != nil { return err @@ -1188,7 +1188,8 @@ func (ct *catchpointTracker) generateCatchpointData(ctx context.Context, account } func (ct *catchpointTracker) recordFirstStageInfo(ctx context.Context, tx *sql.Tx, accountsRound basics.Round, totalKVs uint64, totalAccounts uint64, totalChunks uint64, biggestChunkLen uint64) error { - accountTotals, err := accountsTotals(ctx, tx, false) + arw := store.NewAccountsSQLReaderWriter(tx) + accountTotals, err := arw.AccountsTotals(ctx, false) if err != nil { return err } @@ -1244,9 +1245,9 @@ func makeCatchpointFilePath(round basics.Round) string { // deleting 2 entries while inserting single entry allow us to adjust the size of the backing storage and have the // database and storage realign. func (ct *catchpointTracker) recordCatchpointFile(ctx context.Context, e db.Executable, round basics.Round, relCatchpointFilePath string, fileSize int64) (err error) { - cps := store.NewCatchpointSQLReaderWriter(e) + crw := store.NewCatchpointSQLReaderWriter(e) if ct.catchpointFileHistoryLength != 0 { - err = cps.StoreCatchpoint(ctx, round, relCatchpointFilePath, "", fileSize) + err = crw.StoreCatchpoint(ctx, round, relCatchpointFilePath, "", fileSize) if err != nil { ct.log.Warnf("catchpointTracker.recordCatchpointFile() unable to save catchpoint: %v", err) return @@ -1262,7 +1263,7 @@ func (ct *catchpointTracker) recordCatchpointFile(ctx context.Context, e db.Exec return } var filesToDelete map[basics.Round]string - filesToDelete, err = cps.GetOldestCatchpointFiles(ctx, 2, ct.catchpointFileHistoryLength) + filesToDelete, err = crw.GetOldestCatchpointFiles(ctx, 2, ct.catchpointFileHistoryLength) if err != nil { return fmt.Errorf("unable to delete catchpoint file, getOldestCatchpointFiles failed : %v", err) } @@ -1271,7 +1272,7 @@ func (ct *catchpointTracker) recordCatchpointFile(ctx context.Context, e db.Exec if err != nil { return err } - err = cps.StoreCatchpoint(ctx, round, "", "", 0) + err = crw.StoreCatchpoint(ctx, round, "", "", 0) if err != nil { return fmt.Errorf("unable to delete old catchpoint entry '%s' : %v", fileToDelete, err) } @@ -1288,8 +1289,8 @@ func (ct *catchpointTracker) GetCatchpointStream(round basics.Round) (ReadCloseS // TODO: we need to generalize this, check @cce PoC PR, he has something // somewhat broken for some KVs.. err := ct.dbs.Rdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - cps := store.NewCatchpointSQLReaderWriter(tx) - dbFileName, _, fileSize, err = cps.GetCatchpoint(ctx, round) + crw := store.NewCatchpointSQLReaderWriter(tx) + dbFileName, _, fileSize, err = crw.GetCatchpoint(ctx, round) return }) ledgerGetcatchpointMicros.AddMicrosecondsSince(start, nil) @@ -1347,10 +1348,10 @@ func (ct *catchpointTracker) GetCatchpointStream(round basics.Round) (ReadCloseS // deleteStoredCatchpoints iterates over the storedcatchpoints table and deletes all the files stored on disk. // once all the files have been deleted, it would go ahead and remove the entries from the table. func deleteStoredCatchpoints(ctx context.Context, e db.Executable, dbDirectory string) (err error) { - cps := store.NewCatchpointSQLReaderWriter(e) + crw := store.NewCatchpointSQLReaderWriter(e) catchpointsFilesChunkSize := 50 for { - fileNames, err := cps.GetOldestCatchpointFiles(ctx, catchpointsFilesChunkSize, 0) + fileNames, err := crw.GetOldestCatchpointFiles(ctx, catchpointsFilesChunkSize, 0) if err != nil { return err } @@ -1364,7 +1365,7 @@ func deleteStoredCatchpoints(ctx context.Context, e db.Executable, dbDirectory s return err } // clear the entry from the database - err = cps.StoreCatchpoint(ctx, round, "", "", 0) + err = crw.StoreCatchpoint(ctx, round, "", "", 0) if err != nil { return err } @@ -1536,7 +1537,8 @@ func (ct *catchpointTracker) catchpointEnabled() bool { // initializeHashes initializes account/resource/kv hashes. // as part of the initialization, it tests if a hash table matches to account base and updates the former. func (ct *catchpointTracker) initializeHashes(ctx context.Context, tx *sql.Tx, rnd basics.Round) error { - hashRound, err := accountsHashRound(ctx, tx) + arw := store.NewAccountsSQLReaderWriter(tx) + hashRound, err := arw.AccountsHashRound(ctx) if err != nil { return err } diff --git a/ledger/catchpointwriter.go b/ledger/catchpointwriter.go index c7f87961b2..e204a8ae74 100644 --- a/ledger/catchpointwriter.go +++ b/ledger/catchpointwriter.go @@ -28,6 +28,7 @@ import ( "github.com/algorand/msgp/msgp" "github.com/algorand/go-algorand/data/basics" + "github.com/algorand/go-algorand/ledger/store" "github.com/algorand/go-algorand/protocol" ) @@ -130,12 +131,14 @@ func (chunk catchpointFileChunkV6) empty() bool { } func makeCatchpointWriter(ctx context.Context, filePath string, tx *sql.Tx, maxResourcesPerChunk int) (*catchpointWriter, error) { - totalAccounts, err := totalAccounts(ctx, tx) + arw := store.NewAccountsSQLReaderWriter(tx) + + totalAccounts, err := arw.TotalAccounts(ctx) if err != nil { return nil, err } - totalKVs, err := totalKVs(ctx, tx) + totalKVs, err := arw.TotalKVs(ctx) if err != nil { return nil, err } diff --git a/ledger/catchpointwriter_test.go b/ledger/catchpointwriter_test.go index 5b7563b6e7..a74e907684 100644 --- a/ledger/catchpointwriter_test.go +++ b/ledger/catchpointwriter_test.go @@ -40,6 +40,7 @@ import ( "github.com/algorand/go-algorand/data/transactions/logic" "github.com/algorand/go-algorand/data/txntest" "github.com/algorand/go-algorand/ledger/ledgercore" + "github.com/algorand/go-algorand/ledger/store" ledgertesting "github.com/algorand/go-algorand/ledger/testing" "github.com/algorand/go-algorand/logging" "github.com/algorand/go-algorand/protocol" @@ -195,6 +196,8 @@ func testWriteCatchpoint(t *testing.T, rdb db.Accessor, datapath string, filepat err := rdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { writer, err := makeCatchpointWriter(context.Background(), datapath, tx, maxResourcesPerChunk) + arw := store.NewAccountsSQLReaderWriter(tx) + if err != nil { return err } @@ -208,11 +211,11 @@ func testWriteCatchpoint(t *testing.T, rdb db.Accessor, datapath string, filepat totalAccounts = writer.totalAccounts totalChunks = writer.chunkNum biggestChunkLen = writer.biggestChunkLen - accountsRnd, err = accountsRound(tx) + accountsRnd, err = arw.AccountsRound() if err != nil { return } - totals, err = accountsTotals(ctx, tx, false) + totals, err = arw.AccountsTotals(ctx, false) return }) require.NoError(t, err) @@ -371,7 +374,8 @@ func TestCatchpointReadDatabaseOverflowAccounts(t *testing.T) { readDb := ml.trackerDB().Rdb err = readDb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - expectedTotalAccounts, err := totalAccounts(ctx, tx) + arw := store.NewAccountsSQLReaderWriter(tx) + expectedTotalAccounts, err := arw.TotalAccounts(ctx) if err != nil { return err } diff --git a/ledger/catchupaccessor.go b/ledger/catchupaccessor.go index 64860d4411..9e5371104b 100644 --- a/ledger/catchupaccessor.go +++ b/ledger/catchupaccessor.go @@ -243,27 +243,27 @@ func (c *catchpointCatchupAccessorImpl) ResetStagingBalances(ctx context.Context start := time.Now() ledgerResetstagingbalancesCount.Inc(nil) err = wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - cps := store.NewCatchpointSQLReaderWriter(tx) + crw := store.NewCatchpointSQLReaderWriter(tx) err = resetCatchpointStagingBalances(ctx, tx, newCatchup) if err != nil { return fmt.Errorf("unable to reset catchpoint catchup balances : %v", err) } if !newCatchup { - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBalancesRound, 0) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBalancesRound, 0) if err != nil { return err } - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBlockRound, 0) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBlockRound, 0) if err != nil { return err } - err = cps.WriteCatchpointStateString(ctx, catchpointStateCatchupLabel, "") + err = crw.WriteCatchpointStateString(ctx, catchpointStateCatchupLabel, "") if err != nil { return err } - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupState, 0) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupState, 0) if err != nil { return fmt.Errorf("unable to write catchpoint catchup state '%s': %v", catchpointStateCatchupState, err) } @@ -333,18 +333,20 @@ func (c *catchpointCatchupAccessorImpl) processStagingContent(ctx context.Contex start := time.Now() ledgerProcessstagingcontentCount.Inc(nil) err = wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - cps := store.NewCatchpointSQLReaderWriter(tx) - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBlockRound, uint64(fileHeader.BlocksRound)) + crw := store.NewCatchpointSQLReaderWriter(tx) + arw := store.NewAccountsSQLReaderWriter(tx) + + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBlockRound, uint64(fileHeader.BlocksRound)) if err != nil { return fmt.Errorf("CatchpointCatchupAccessorImpl::processStagingContent: unable to write catchpoint catchup state '%s': %v", catchpointStateCatchupBlockRound, err) } if fileHeader.Version == CatchpointFileVersionV6 { - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupHashRound, uint64(fileHeader.BlocksRound)) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupHashRound, uint64(fileHeader.BlocksRound)) if err != nil { return fmt.Errorf("CatchpointCatchupAccessorImpl::processStagingContent: unable to write catchpoint catchup state '%s': %v", catchpointStateCatchupHashRound, err) } } - err = accountsPutTotals(tx, fileHeader.Totals, true) + err = arw.AccountsPutTotals(fileHeader.Totals, true) return }) ledgerProcessstagingcontentMicros.AddMicrosecondsSince(start, nil) @@ -829,6 +831,7 @@ func (c *catchpointCatchupAccessorImpl) VerifyCatchpoint(ctx context.Context, bl start := time.Now() ledgerVerifycatchpointCount.Inc(nil) err = rdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { + arw := store.NewAccountsSQLReaderWriter(tx) // create the merkle trie for the balances mc, err0 := MakeMerkleCommitter(tx, true) if err0 != nil { @@ -845,7 +848,7 @@ func (c *catchpointCatchupAccessorImpl) VerifyCatchpoint(ctx context.Context, bl return fmt.Errorf("unable to get trie root hash: %v", err) } - totals, err = accountsTotals(ctx, tx, true) + totals, err = arw.AccountsTotals(ctx, true) if err != nil { return fmt.Errorf("unable to get accounts totals: %v", err) } @@ -881,8 +884,8 @@ func (c *catchpointCatchupAccessorImpl) StoreBalancesRound(ctx context.Context, start := time.Now() ledgerStorebalancesroundCount.Inc(nil) err = wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - cps := store.NewCatchpointSQLReaderWriter(tx) - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBalancesRound, uint64(balancesRound)) + crw := store.NewCatchpointSQLReaderWriter(tx) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBalancesRound, uint64(balancesRound)) if err != nil { return fmt.Errorf("CatchpointCatchupAccessorImpl::StoreBalancesRound: unable to write catchpoint catchup state '%s': %v", catchpointStateCatchupBalancesRound, err) } @@ -978,22 +981,23 @@ func (c *catchpointCatchupAccessorImpl) finishBalances(ctx context.Context) (err start := time.Now() ledgerCatchpointFinishBalsCount.Inc(nil) err = wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - cps := store.NewCatchpointSQLReaderWriter(tx) + crw := store.NewCatchpointSQLReaderWriter(tx) + arw := store.NewAccountsSQLReaderWriter(tx) var balancesRound, hashRound uint64 var totals ledgercore.AccountTotals - balancesRound, err = cps.ReadCatchpointStateUint64(ctx, catchpointStateCatchupBalancesRound) + balancesRound, err = crw.ReadCatchpointStateUint64(ctx, catchpointStateCatchupBalancesRound) if err != nil { return err } - hashRound, err = cps.ReadCatchpointStateUint64(ctx, catchpointStateCatchupHashRound) + hashRound, err = crw.ReadCatchpointStateUint64(ctx, catchpointStateCatchupHashRound) if err != nil { return err } - totals, err = accountsTotals(ctx, tx, true) + totals, err = arw.AccountsTotals(ctx, true) if err != nil { return err } @@ -1036,7 +1040,7 @@ func (c *catchpointCatchupAccessorImpl) finishBalances(ctx context.Context) (err return err } - err = accountsPutTotals(tx, totals, false) + err = arw.AccountsPutTotals(totals, false) if err != nil { return err } @@ -1046,29 +1050,29 @@ func (c *catchpointCatchupAccessorImpl) finishBalances(ctx context.Context) (err return err } - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBalancesRound, 0) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBalancesRound, 0) if err != nil { return err } - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBlockRound, 0) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupBlockRound, 0) if err != nil { return err } - err = cps.WriteCatchpointStateString(ctx, catchpointStateCatchupLabel, "") + err = crw.WriteCatchpointStateString(ctx, catchpointStateCatchupLabel, "") if err != nil { return err } if hashRound != 0 { - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupHashRound, 0) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupHashRound, 0) if err != nil { return err } } - err = cps.WriteCatchpointStateUint64(ctx, catchpointStateCatchupState, 0) + err = crw.WriteCatchpointStateUint64(ctx, catchpointStateCatchupState, 0) if err != nil { return fmt.Errorf("unable to write catchpoint catchup state '%s': %v", catchpointStateCatchupState, err) } diff --git a/ledger/msgp_gen.go b/ledger/msgp_gen.go index 5ad3f23d9a..3526e36622 100644 --- a/ledger/msgp_gen.go +++ b/ledger/msgp_gen.go @@ -6,9 +6,6 @@ import ( "sort" "github.com/algorand/msgp/msgp" - - "github.com/algorand/go-algorand/data/basics" - "github.com/algorand/go-algorand/data/transactions" ) // The following msgp objects are implemented in this file: @@ -84,22 +81,6 @@ import ( // |-----> Msgsize // |-----> MsgIsZero // -// txTailRound -// |-----> (*) MarshalMsg -// |-----> (*) CanMarshalMsg -// |-----> (*) UnmarshalMsg -// |-----> (*) CanUnmarshalMsg -// |-----> (*) Msgsize -// |-----> (*) MsgIsZero -// -// txTailRoundLease -// |-----> (*) MarshalMsg -// |-----> (*) CanMarshalMsg -// |-----> (*) UnmarshalMsg -// |-----> (*) CanUnmarshalMsg -// |-----> (*) Msgsize -// |-----> (*) MsgIsZero -// // MarshalMsg implements msgp.Marshaler func (z CatchpointCatchupState) MarshalMsg(b []byte) (o []byte) { @@ -2027,459 +2008,3 @@ func (z hashKind) Msgsize() (s int) { func (z hashKind) MsgIsZero() bool { return z == 0 } - -// MarshalMsg implements msgp.Marshaler -func (z *txTailRound) MarshalMsg(b []byte) (o []byte) { - o = msgp.Require(b, z.Msgsize()) - // omitempty: check for empty values - zb0004Len := uint32(4) - var zb0004Mask uint8 /* 5 bits */ - if (*z).Hdr.MsgIsZero() { - zb0004Len-- - zb0004Mask |= 0x2 - } - if len((*z).TxnIDs) == 0 { - zb0004Len-- - zb0004Mask |= 0x4 - } - if len((*z).Leases) == 0 { - zb0004Len-- - zb0004Mask |= 0x8 - } - if len((*z).LastValid) == 0 { - zb0004Len-- - zb0004Mask |= 0x10 - } - // variable map header, size zb0004Len - o = append(o, 0x80|uint8(zb0004Len)) - if zb0004Len != 0 { - if (zb0004Mask & 0x2) == 0 { // if not empty - // string "h" - o = append(o, 0xa1, 0x68) - o = (*z).Hdr.MarshalMsg(o) - } - if (zb0004Mask & 0x4) == 0 { // if not empty - // string "i" - o = append(o, 0xa1, 0x69) - if (*z).TxnIDs == nil { - o = msgp.AppendNil(o) - } else { - o = msgp.AppendArrayHeader(o, uint32(len((*z).TxnIDs))) - } - for zb0001 := range (*z).TxnIDs { - o = (*z).TxnIDs[zb0001].MarshalMsg(o) - } - } - if (zb0004Mask & 0x8) == 0 { // if not empty - // string "l" - o = append(o, 0xa1, 0x6c) - if (*z).Leases == nil { - o = msgp.AppendNil(o) - } else { - o = msgp.AppendArrayHeader(o, uint32(len((*z).Leases))) - } - for zb0003 := range (*z).Leases { - o = (*z).Leases[zb0003].MarshalMsg(o) - } - } - if (zb0004Mask & 0x10) == 0 { // if not empty - // string "v" - o = append(o, 0xa1, 0x76) - if (*z).LastValid == nil { - o = msgp.AppendNil(o) - } else { - o = msgp.AppendArrayHeader(o, uint32(len((*z).LastValid))) - } - for zb0002 := range (*z).LastValid { - o = (*z).LastValid[zb0002].MarshalMsg(o) - } - } - } - return -} - -func (_ *txTailRound) CanMarshalMsg(z interface{}) bool { - _, ok := (z).(*txTailRound) - return ok -} - -// UnmarshalMsg implements msgp.Unmarshaler -func (z *txTailRound) UnmarshalMsg(bts []byte) (o []byte, err error) { - var field []byte - _ = field - var zb0004 int - var zb0005 bool - zb0004, zb0005, bts, err = msgp.ReadMapHeaderBytes(bts) - if _, ok := err.(msgp.TypeError); ok { - zb0004, zb0005, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - if zb0004 > 0 { - zb0004-- - var zb0006 int - var zb0007 bool - zb0006, zb0007, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "TxnIDs") - return - } - if zb0007 { - (*z).TxnIDs = nil - } else if (*z).TxnIDs != nil && cap((*z).TxnIDs) >= zb0006 { - (*z).TxnIDs = ((*z).TxnIDs)[:zb0006] - } else { - (*z).TxnIDs = make([]transactions.Txid, zb0006) - } - for zb0001 := range (*z).TxnIDs { - bts, err = (*z).TxnIDs[zb0001].UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "TxnIDs", zb0001) - return - } - } - } - if zb0004 > 0 { - zb0004-- - var zb0008 int - var zb0009 bool - zb0008, zb0009, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "LastValid") - return - } - if zb0009 { - (*z).LastValid = nil - } else if (*z).LastValid != nil && cap((*z).LastValid) >= zb0008 { - (*z).LastValid = ((*z).LastValid)[:zb0008] - } else { - (*z).LastValid = make([]basics.Round, zb0008) - } - for zb0002 := range (*z).LastValid { - bts, err = (*z).LastValid[zb0002].UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "LastValid", zb0002) - return - } - } - } - if zb0004 > 0 { - zb0004-- - var zb0010 int - var zb0011 bool - zb0010, zb0011, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "Leases") - return - } - if zb0011 { - (*z).Leases = nil - } else if (*z).Leases != nil && cap((*z).Leases) >= zb0010 { - (*z).Leases = ((*z).Leases)[:zb0010] - } else { - (*z).Leases = make([]txTailRoundLease, zb0010) - } - for zb0003 := range (*z).Leases { - bts, err = (*z).Leases[zb0003].UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "Leases", zb0003) - return - } - } - } - if zb0004 > 0 { - zb0004-- - bts, err = (*z).Hdr.UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "Hdr") - return - } - } - if zb0004 > 0 { - err = msgp.ErrTooManyArrayFields(zb0004) - if err != nil { - err = msgp.WrapError(err, "struct-from-array") - return - } - } - } else { - if err != nil { - err = msgp.WrapError(err) - return - } - if zb0005 { - (*z) = txTailRound{} - } - for zb0004 > 0 { - zb0004-- - field, bts, err = msgp.ReadMapKeyZC(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - switch string(field) { - case "i": - var zb0012 int - var zb0013 bool - zb0012, zb0013, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "TxnIDs") - return - } - if zb0013 { - (*z).TxnIDs = nil - } else if (*z).TxnIDs != nil && cap((*z).TxnIDs) >= zb0012 { - (*z).TxnIDs = ((*z).TxnIDs)[:zb0012] - } else { - (*z).TxnIDs = make([]transactions.Txid, zb0012) - } - for zb0001 := range (*z).TxnIDs { - bts, err = (*z).TxnIDs[zb0001].UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "TxnIDs", zb0001) - return - } - } - case "v": - var zb0014 int - var zb0015 bool - zb0014, zb0015, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "LastValid") - return - } - if zb0015 { - (*z).LastValid = nil - } else if (*z).LastValid != nil && cap((*z).LastValid) >= zb0014 { - (*z).LastValid = ((*z).LastValid)[:zb0014] - } else { - (*z).LastValid = make([]basics.Round, zb0014) - } - for zb0002 := range (*z).LastValid { - bts, err = (*z).LastValid[zb0002].UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "LastValid", zb0002) - return - } - } - case "l": - var zb0016 int - var zb0017 bool - zb0016, zb0017, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err, "Leases") - return - } - if zb0017 { - (*z).Leases = nil - } else if (*z).Leases != nil && cap((*z).Leases) >= zb0016 { - (*z).Leases = ((*z).Leases)[:zb0016] - } else { - (*z).Leases = make([]txTailRoundLease, zb0016) - } - for zb0003 := range (*z).Leases { - bts, err = (*z).Leases[zb0003].UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "Leases", zb0003) - return - } - } - case "h": - bts, err = (*z).Hdr.UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "Hdr") - return - } - default: - err = msgp.ErrNoField(string(field)) - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - } - o = bts - return -} - -func (_ *txTailRound) CanUnmarshalMsg(z interface{}) bool { - _, ok := (z).(*txTailRound) - return ok -} - -// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message -func (z *txTailRound) Msgsize() (s int) { - s = 1 + 2 + msgp.ArrayHeaderSize - for zb0001 := range (*z).TxnIDs { - s += (*z).TxnIDs[zb0001].Msgsize() - } - s += 2 + msgp.ArrayHeaderSize - for zb0002 := range (*z).LastValid { - s += (*z).LastValid[zb0002].Msgsize() - } - s += 2 + msgp.ArrayHeaderSize - for zb0003 := range (*z).Leases { - s += (*z).Leases[zb0003].Msgsize() - } - s += 2 + (*z).Hdr.Msgsize() - return -} - -// MsgIsZero returns whether this is a zero value -func (z *txTailRound) MsgIsZero() bool { - return (len((*z).TxnIDs) == 0) && (len((*z).LastValid) == 0) && (len((*z).Leases) == 0) && ((*z).Hdr.MsgIsZero()) -} - -// MarshalMsg implements msgp.Marshaler -func (z *txTailRoundLease) MarshalMsg(b []byte) (o []byte) { - o = msgp.Require(b, z.Msgsize()) - // omitempty: check for empty values - zb0002Len := uint32(3) - var zb0002Mask uint8 /* 4 bits */ - if (*z).TxnIdx == 0 { - zb0002Len-- - zb0002Mask |= 0x1 - } - if (*z).Lease == ([32]byte{}) { - zb0002Len-- - zb0002Mask |= 0x4 - } - if (*z).Sender.MsgIsZero() { - zb0002Len-- - zb0002Mask |= 0x8 - } - // variable map header, size zb0002Len - o = append(o, 0x80|uint8(zb0002Len)) - if zb0002Len != 0 { - if (zb0002Mask & 0x1) == 0 { // if not empty - // string "TxnIdx" - o = append(o, 0xa6, 0x54, 0x78, 0x6e, 0x49, 0x64, 0x78) - o = msgp.AppendUint64(o, (*z).TxnIdx) - } - if (zb0002Mask & 0x4) == 0 { // if not empty - // string "l" - o = append(o, 0xa1, 0x6c) - o = msgp.AppendBytes(o, ((*z).Lease)[:]) - } - if (zb0002Mask & 0x8) == 0 { // if not empty - // string "s" - o = append(o, 0xa1, 0x73) - o = (*z).Sender.MarshalMsg(o) - } - } - return -} - -func (_ *txTailRoundLease) CanMarshalMsg(z interface{}) bool { - _, ok := (z).(*txTailRoundLease) - return ok -} - -// UnmarshalMsg implements msgp.Unmarshaler -func (z *txTailRoundLease) UnmarshalMsg(bts []byte) (o []byte, err error) { - var field []byte - _ = field - var zb0002 int - var zb0003 bool - zb0002, zb0003, bts, err = msgp.ReadMapHeaderBytes(bts) - if _, ok := err.(msgp.TypeError); ok { - zb0002, zb0003, bts, err = msgp.ReadArrayHeaderBytes(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - if zb0002 > 0 { - zb0002-- - bts, err = (*z).Sender.UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "Sender") - return - } - } - if zb0002 > 0 { - zb0002-- - bts, err = msgp.ReadExactBytes(bts, ((*z).Lease)[:]) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "Lease") - return - } - } - if zb0002 > 0 { - zb0002-- - (*z).TxnIdx, bts, err = msgp.ReadUint64Bytes(bts) - if err != nil { - err = msgp.WrapError(err, "struct-from-array", "TxnIdx") - return - } - } - if zb0002 > 0 { - err = msgp.ErrTooManyArrayFields(zb0002) - if err != nil { - err = msgp.WrapError(err, "struct-from-array") - return - } - } - } else { - if err != nil { - err = msgp.WrapError(err) - return - } - if zb0003 { - (*z) = txTailRoundLease{} - } - for zb0002 > 0 { - zb0002-- - field, bts, err = msgp.ReadMapKeyZC(bts) - if err != nil { - err = msgp.WrapError(err) - return - } - switch string(field) { - case "s": - bts, err = (*z).Sender.UnmarshalMsg(bts) - if err != nil { - err = msgp.WrapError(err, "Sender") - return - } - case "l": - bts, err = msgp.ReadExactBytes(bts, ((*z).Lease)[:]) - if err != nil { - err = msgp.WrapError(err, "Lease") - return - } - case "TxnIdx": - (*z).TxnIdx, bts, err = msgp.ReadUint64Bytes(bts) - if err != nil { - err = msgp.WrapError(err, "TxnIdx") - return - } - default: - err = msgp.ErrNoField(string(field)) - if err != nil { - err = msgp.WrapError(err) - return - } - } - } - } - o = bts - return -} - -func (_ *txTailRoundLease) CanUnmarshalMsg(z interface{}) bool { - _, ok := (z).(*txTailRoundLease) - return ok -} - -// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message -func (z *txTailRoundLease) Msgsize() (s int) { - s = 1 + 2 + (*z).Sender.Msgsize() + 2 + msgp.ArrayHeaderSize + (32 * (msgp.ByteSize)) + 7 + msgp.Uint64Size - return -} - -// MsgIsZero returns whether this is a zero value -func (z *txTailRoundLease) MsgIsZero() bool { - return ((*z).Sender.MsgIsZero()) && ((*z).Lease == ([32]byte{})) && ((*z).TxnIdx == 0) -} diff --git a/ledger/msgp_gen_test.go b/ledger/msgp_gen_test.go index 2a03d80bc4..1a78cd429a 100644 --- a/ledger/msgp_gen_test.go +++ b/ledger/msgp_gen_test.go @@ -433,123 +433,3 @@ func BenchmarkUnmarshalencodedKVRecordV6(b *testing.B) { } } } - -func TestMarshalUnmarshaltxTailRound(t *testing.T) { - partitiontest.PartitionTest(t) - v := txTailRound{} - bts := v.MarshalMsg(nil) - left, err := v.UnmarshalMsg(bts) - if err != nil { - t.Fatal(err) - } - if len(left) > 0 { - t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) - } - - left, err = msgp.Skip(bts) - if err != nil { - t.Fatal(err) - } - if len(left) > 0 { - t.Errorf("%d bytes left over after Skip(): %q", len(left), left) - } -} - -func TestRandomizedEncodingtxTailRound(t *testing.T) { - protocol.RunEncodingTest(t, &txTailRound{}) -} - -func BenchmarkMarshalMsgtxTailRound(b *testing.B) { - v := txTailRound{} - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - v.MarshalMsg(nil) - } -} - -func BenchmarkAppendMsgtxTailRound(b *testing.B) { - v := txTailRound{} - bts := make([]byte, 0, v.Msgsize()) - bts = v.MarshalMsg(bts[0:0]) - b.SetBytes(int64(len(bts))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - bts = v.MarshalMsg(bts[0:0]) - } -} - -func BenchmarkUnmarshaltxTailRound(b *testing.B) { - v := txTailRound{} - bts := v.MarshalMsg(nil) - b.ReportAllocs() - b.SetBytes(int64(len(bts))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := v.UnmarshalMsg(bts) - if err != nil { - b.Fatal(err) - } - } -} - -func TestMarshalUnmarshaltxTailRoundLease(t *testing.T) { - partitiontest.PartitionTest(t) - v := txTailRoundLease{} - bts := v.MarshalMsg(nil) - left, err := v.UnmarshalMsg(bts) - if err != nil { - t.Fatal(err) - } - if len(left) > 0 { - t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) - } - - left, err = msgp.Skip(bts) - if err != nil { - t.Fatal(err) - } - if len(left) > 0 { - t.Errorf("%d bytes left over after Skip(): %q", len(left), left) - } -} - -func TestRandomizedEncodingtxTailRoundLease(t *testing.T) { - protocol.RunEncodingTest(t, &txTailRoundLease{}) -} - -func BenchmarkMarshalMsgtxTailRoundLease(b *testing.B) { - v := txTailRoundLease{} - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - v.MarshalMsg(nil) - } -} - -func BenchmarkAppendMsgtxTailRoundLease(b *testing.B) { - v := txTailRoundLease{} - bts := make([]byte, 0, v.Msgsize()) - bts = v.MarshalMsg(bts[0:0]) - b.SetBytes(int64(len(bts))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - bts = v.MarshalMsg(bts[0:0]) - } -} - -func BenchmarkUnmarshaltxTailRoundLease(b *testing.B) { - v := txTailRoundLease{} - bts := v.MarshalMsg(nil) - b.ReportAllocs() - b.SetBytes(int64(len(bts))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := v.UnmarshalMsg(bts) - if err != nil { - b.Fatal(err) - } - } -} diff --git a/ledger/store/accountsV2.go b/ledger/store/accountsV2.go new file mode 100644 index 0000000000..45b96eca66 --- /dev/null +++ b/ledger/store/accountsV2.go @@ -0,0 +1,423 @@ +// Copyright (C) 2019-2022 Algorand, Inc. +// This file is part of go-algorand +// +// go-algorand is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// go-algorand is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with go-algorand. If not, see . + +package store + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "strings" + + "github.com/algorand/go-algorand/config" + "github.com/algorand/go-algorand/crypto" + "github.com/algorand/go-algorand/data/basics" + "github.com/algorand/go-algorand/ledger/ledgercore" + "github.com/algorand/go-algorand/protocol" + "github.com/algorand/go-algorand/util/db" +) + +type accountsV2Reader struct { + q db.Queryable +} + +type accountsV2Writer struct { + e db.Executable +} + +type accountsV2ReaderWriter struct { + accountsV2Reader + accountsV2Writer +} + +// NewAccountsSQLReaderWriter creates a Catchpoint SQL reader+writer +func NewAccountsSQLReaderWriter(e db.Executable) *accountsV2ReaderWriter { + return &accountsV2ReaderWriter{ + accountsV2Reader{q: e}, + accountsV2Writer{e: e}, + } +} + +// AccountsTotals returns account totals +func (r *accountsV2Reader) AccountsTotals(ctx context.Context, catchpointStaging bool) (totals ledgercore.AccountTotals, err error) { + id := "" + if catchpointStaging { + id = "catchpointStaging" + } + row := r.q.QueryRowContext(ctx, "SELECT online, onlinerewardunits, offline, offlinerewardunits, notparticipating, notparticipatingrewardunits, rewardslevel FROM accounttotals WHERE id=?", id) + err = row.Scan(&totals.Online.Money.Raw, &totals.Online.RewardUnits, + &totals.Offline.Money.Raw, &totals.Offline.RewardUnits, + &totals.NotParticipating.Money.Raw, &totals.NotParticipating.RewardUnits, + &totals.RewardsLevel) + + return +} + +// AccountsRound returns the tracker balances round number +func (r *accountsV2Reader) AccountsRound() (rnd basics.Round, err error) { + err = r.q.QueryRow("SELECT rnd FROM acctrounds WHERE id='acctbase'").Scan(&rnd) + if err != nil { + return + } + return +} + +// AccountsHashRound returns the round of the hash tree +// if the hash of the tree doesn't exists, it returns zero. +func (r *accountsV2Reader) AccountsHashRound(ctx context.Context) (hashrnd basics.Round, err error) { + err = r.q.QueryRowContext(ctx, "SELECT rnd FROM acctrounds WHERE id='hashbase'").Scan(&hashrnd) + if err == sql.ErrNoRows { + hashrnd = basics.Round(0) + err = nil + } + return +} + +// AccountsOnlineTop returns the top n online accounts starting at position offset +// (that is, the top offset'th account through the top offset+n-1'th account). +// +// The accounts are sorted by their normalized balance and address. The normalized +// balance has to do with the reward parts of online account balances. See the +// normalization procedure in AccountData.NormalizedOnlineBalance(). +// +// Note that this does not check if the accounts have a vote key valid for any +// particular round (past, present, or future). +func (r *accountsV2Reader) AccountsOnlineTop(rnd basics.Round, offset uint64, n uint64, proto config.ConsensusParams) (map[basics.Address]*ledgercore.OnlineAccount, error) { + // onlineaccounts has historical data ordered by updround for both online and offline accounts. + // This means some account A might have norm balance != 0 at round N and norm balance == 0 at some round K > N. + // For online top query one needs to find entries not fresher than X with norm balance != 0. + // To do that the query groups row by address and takes the latest updround, and then filters out rows with zero nor balance. + rows, err := r.q.Query(`SELECT address, normalizedonlinebalance, data, max(updround) FROM onlineaccounts +WHERE updround <= ? +GROUP BY address HAVING normalizedonlinebalance > 0 +ORDER BY normalizedonlinebalance DESC, address DESC LIMIT ? OFFSET ?`, rnd, n, offset) + + if err != nil { + return nil, err + } + defer rows.Close() + + res := make(map[basics.Address]*ledgercore.OnlineAccount, n) + for rows.Next() { + var addrbuf []byte + var buf []byte + var normBal sql.NullInt64 + var updround sql.NullInt64 + err = rows.Scan(&addrbuf, &normBal, &buf, &updround) + if err != nil { + return nil, err + } + + var data BaseOnlineAccountData + err = protocol.Decode(buf, &data) + if err != nil { + return nil, err + } + + var addr basics.Address + if len(addrbuf) != len(addr) { + err = fmt.Errorf("account DB address length mismatch: %d != %d", len(addrbuf), len(addr)) + return nil, err + } + + if !normBal.Valid { + return nil, fmt.Errorf("non valid norm balance for online account %s", addr.String()) + } + + copy(addr[:], addrbuf) + // TODO: figure out protocol to use for rewards + // The original implementation uses current proto to recalculate norm balance + // In the same time, in accountsNewRound genesis protocol is used to fill norm balance value + // In order to be consistent with the original implementation recalculate the balance with current proto + normBalance := basics.NormalizedOnlineAccountBalance(basics.Online, data.RewardsBase, data.MicroAlgos, proto) + oa := data.GetOnlineAccount(addr, normBalance) + res[addr] = &oa + } + + return res, rows.Err() +} + +// OnlineAccountsAll returns all online accounts +func (r *accountsV2Reader) OnlineAccountsAll(maxAccounts uint64) ([]PersistedOnlineAccountData, error) { + rows, err := r.q.Query("SELECT rowid, address, updround, data FROM onlineaccounts ORDER BY address, updround ASC") + if err != nil { + return nil, err + } + defer rows.Close() + + result := make([]PersistedOnlineAccountData, 0, maxAccounts) + var numAccounts uint64 + seenAddr := make([]byte, len(basics.Address{})) + for rows.Next() { + var addrbuf []byte + var buf []byte + data := PersistedOnlineAccountData{} + err := rows.Scan(&data.Rowid, &addrbuf, &data.UpdRound, &buf) + if err != nil { + return nil, err + } + if len(addrbuf) != len(data.Addr) { + err = fmt.Errorf("account DB address length mismatch: %d != %d", len(addrbuf), len(data.Addr)) + return nil, err + } + if maxAccounts > 0 { + if !bytes.Equal(seenAddr, addrbuf) { + numAccounts++ + if numAccounts > maxAccounts { + break + } + copy(seenAddr, addrbuf) + } + } + copy(data.Addr[:], addrbuf) + err = protocol.Decode(buf, &data.AccountData) + if err != nil { + return nil, err + } + result = append(result, data) + } + return result, nil +} + +// TotalAccounts returns the total number of accounts +func (r *accountsV2Reader) TotalAccounts(ctx context.Context) (total uint64, err error) { + err = r.q.QueryRowContext(ctx, "SELECT count(1) FROM accountbase").Scan(&total) + if err == sql.ErrNoRows { + total = 0 + err = nil + return + } + return +} + +// TotalKVs returns the total number of kv items +func (r *accountsV2Reader) TotalKVs(ctx context.Context) (total uint64, err error) { + err = r.q.QueryRowContext(ctx, "SELECT count(1) FROM kvstore").Scan(&total) + if err == sql.ErrNoRows { + total = 0 + err = nil + return + } + return +} + +// LoadTxTail returns the tx tails +func (r *accountsV2Reader) LoadTxTail(ctx context.Context, dbRound basics.Round) (roundData []*TxTailRound, roundHash []crypto.Digest, baseRound basics.Round, err error) { + rows, err := r.q.QueryContext(ctx, "SELECT rnd, data FROM txtail ORDER BY rnd DESC") + if err != nil { + return nil, nil, 0, err + } + defer rows.Close() + + expectedRound := dbRound + for rows.Next() { + var round basics.Round + var data []byte + err = rows.Scan(&round, &data) + if err != nil { + return nil, nil, 0, err + } + if round != expectedRound { + return nil, nil, 0, fmt.Errorf("txtail table contain unexpected round %d; round %d was expected", round, expectedRound) + } + tail := &TxTailRound{} + err = protocol.Decode(data, tail) + if err != nil { + return nil, nil, 0, err + } + roundData = append(roundData, tail) + roundHash = append(roundHash, crypto.Hash(data)) + expectedRound-- + } + // reverse the array ordering in-place so that it would be incremental order. + for i := 0; i < len(roundData)/2; i++ { + roundData[i], roundData[len(roundData)-i-1] = roundData[len(roundData)-i-1], roundData[i] + roundHash[i], roundHash[len(roundHash)-i-1] = roundHash[len(roundHash)-i-1], roundHash[i] + } + return roundData, roundHash, expectedRound + 1, nil +} + +// AccountsPutTotals updates account totals +func (w *accountsV2Writer) AccountsPutTotals(totals ledgercore.AccountTotals, catchpointStaging bool) error { + id := "" + if catchpointStaging { + id = "catchpointStaging" + } + _, err := w.e.Exec("REPLACE INTO accounttotals (id, online, onlinerewardunits, offline, offlinerewardunits, notparticipating, notparticipatingrewardunits, rewardslevel) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + id, + totals.Online.Money.Raw, totals.Online.RewardUnits, + totals.Offline.Money.Raw, totals.Offline.RewardUnits, + totals.NotParticipating.Money.Raw, totals.NotParticipating.RewardUnits, + totals.RewardsLevel) + return err +} + +func (w *accountsV2Writer) TxtailNewRound(ctx context.Context, baseRound basics.Round, roundData [][]byte, forgetBeforeRound basics.Round) error { + insertStmt, err := w.e.PrepareContext(ctx, "INSERT INTO txtail(rnd, data) VALUES(?, ?)") + if err != nil { + return err + } + defer insertStmt.Close() + + for i, data := range roundData { + _, err = insertStmt.ExecContext(ctx, int(baseRound)+i, data[:]) + if err != nil { + return err + } + } + + _, err = w.e.ExecContext(ctx, "DELETE FROM txtail WHERE rnd < ?", forgetBeforeRound) + return err +} + +// OnlineAccountsDelete deleted entries with updRound <= expRound +func (w *accountsV2Writer) OnlineAccountsDelete(forgetBefore basics.Round) (err error) { + rows, err := w.e.Query("SELECT rowid, address, updRound, data FROM onlineaccounts WHERE updRound < ? ORDER BY address, updRound DESC", forgetBefore) + if err != nil { + return err + } + defer rows.Close() + + var rowids []int64 + var rowid sql.NullInt64 + var updRound sql.NullInt64 + var buf []byte + var addrbuf []byte + + var prevAddr []byte + + for rows.Next() { + err = rows.Scan(&rowid, &addrbuf, &updRound, &buf) + if err != nil { + return err + } + if !rowid.Valid || !updRound.Valid { + return fmt.Errorf("onlineAccountsDelete: invalid rowid or updRound") + } + if len(addrbuf) != len(basics.Address{}) { + err = fmt.Errorf("account DB address length mismatch: %d != %d", len(addrbuf), len(basics.Address{})) + return + } + + if !bytes.Equal(addrbuf, prevAddr) { + // new address + // if the first (latest) entry is + // - offline then delete all + // - online then safe to delete all previous except this first (latest) + + // reset the state + prevAddr = addrbuf + + var oad BaseOnlineAccountData + err = protocol.Decode(buf, &oad) + if err != nil { + return + } + if oad.IsVotingEmpty() { + // delete this and all subsequent + rowids = append(rowids, rowid.Int64) + } + + // restart the loop + // if there are some subsequent entries, they will deleted on the next iteration + // if no subsequent entries, the loop will reset the state and the latest entry does not get deleted + continue + } + // delete all subsequent entries + rowids = append(rowids, rowid.Int64) + } + + return onlineAccountsDeleteByRowIDs(w.e, rowids) +} + +func onlineAccountsDeleteByRowIDs(e db.Executable, rowids []int64) (err error) { + if len(rowids) == 0 { + return + } + + // sqlite3 < 3.32.0 allows SQLITE_MAX_VARIABLE_NUMBER = 999 bindings + // see https://www.sqlite.org/limits.html + // rowids might be larger => split to chunks are remove + chunks := rowidsToChunkedArgs(rowids) + for _, chunk := range chunks { + _, err = e.Exec("DELETE FROM onlineaccounts WHERE rowid IN (?"+strings.Repeat(",?", len(chunk)-1)+")", chunk...) + if err != nil { + return + } + } + return +} + +func rowidsToChunkedArgs(rowids []int64) [][]interface{} { + const sqliteMaxVariableNumber = 999 + + numChunks := len(rowids)/sqliteMaxVariableNumber + 1 + if len(rowids)%sqliteMaxVariableNumber == 0 { + numChunks-- + } + chunks := make([][]interface{}, numChunks) + if numChunks == 1 { + // optimize memory consumption for the most common case + chunks[0] = make([]interface{}, len(rowids)) + for i, rowid := range rowids { + chunks[0][i] = interface{}(rowid) + } + } else { + for i := 0; i < numChunks; i++ { + chunkSize := sqliteMaxVariableNumber + if i == numChunks-1 { + chunkSize = len(rowids) - (numChunks-1)*sqliteMaxVariableNumber + } + chunks[i] = make([]interface{}, chunkSize) + } + for i, rowid := range rowids { + chunkIndex := i / sqliteMaxVariableNumber + chunks[chunkIndex][i%sqliteMaxVariableNumber] = interface{}(rowid) + } + } + return chunks +} + +// UpdateAccountsRound updates the round number associated with the current account data. +func (w *accountsV2Writer) UpdateAccountsRound(rnd basics.Round) (err error) { + res, err := w.e.Exec("UPDATE acctrounds SET rnd=? WHERE id='acctbase' AND rnd rnd { + err = fmt.Errorf("newRound %d is not after base %d", rnd, base) + return + } else if base != rnd { + err = fmt.Errorf("updateAccountsRound(acctbase, %d): expected to update 1 row but got %d", rnd, aff) + return + } + } + return +} diff --git a/ledger/store/accountsV2_test.go b/ledger/store/accountsV2_test.go new file mode 100644 index 0000000000..a9a5bf9864 --- /dev/null +++ b/ledger/store/accountsV2_test.go @@ -0,0 +1,88 @@ +// Copyright (C) 2019-2022 Algorand, Inc. +// This file is part of go-algorand +// +// go-algorand is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// go-algorand is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with go-algorand. If not, see . + +package store + +import ( + "testing" + + "github.com/algorand/go-algorand/test/partitiontest" + "github.com/stretchr/testify/require" +) + +func TestRowidsToChunkedArgs(t *testing.T) { + partitiontest.PartitionTest(t) + + res := rowidsToChunkedArgs([]int64{1}) + require.Equal(t, 1, cap(res)) + require.Equal(t, 1, len(res)) + require.Equal(t, 1, cap(res[0])) + require.Equal(t, 1, len(res[0])) + require.Equal(t, []interface{}{int64(1)}, res[0]) + + input := make([]int64, 999) + for i := 0; i < len(input); i++ { + input[i] = int64(i) + } + res = rowidsToChunkedArgs(input) + require.Equal(t, 1, cap(res)) + require.Equal(t, 1, len(res)) + require.Equal(t, 999, cap(res[0])) + require.Equal(t, 999, len(res[0])) + for i := 0; i < len(input); i++ { + require.Equal(t, interface{}(int64(i)), res[0][i]) + } + + input = make([]int64, 1001) + for i := 0; i < len(input); i++ { + input[i] = int64(i) + } + res = rowidsToChunkedArgs(input) + require.Equal(t, 2, cap(res)) + require.Equal(t, 2, len(res)) + require.Equal(t, 999, cap(res[0])) + require.Equal(t, 999, len(res[0])) + require.Equal(t, 2, cap(res[1])) + require.Equal(t, 2, len(res[1])) + for i := 0; i < 999; i++ { + require.Equal(t, interface{}(int64(i)), res[0][i]) + } + j := 0 + for i := 999; i < len(input); i++ { + require.Equal(t, interface{}(int64(i)), res[1][j]) + j++ + } + + input = make([]int64, 2*999) + for i := 0; i < len(input); i++ { + input[i] = int64(i) + } + res = rowidsToChunkedArgs(input) + require.Equal(t, 2, cap(res)) + require.Equal(t, 2, len(res)) + require.Equal(t, 999, cap(res[0])) + require.Equal(t, 999, len(res[0])) + require.Equal(t, 999, cap(res[1])) + require.Equal(t, 999, len(res[1])) + for i := 0; i < 999; i++ { + require.Equal(t, interface{}(int64(i)), res[0][i]) + } + j = 0 + for i := 999; i < len(input); i++ { + require.Equal(t, interface{}(int64(i)), res[1][j]) + j++ + } +} diff --git a/ledger/store/data.go b/ledger/store/data.go index 006f43bead..6215b91c10 100644 --- a/ledger/store/data.go +++ b/ledger/store/data.go @@ -21,7 +21,10 @@ import ( "github.com/algorand/go-algorand/crypto" "github.com/algorand/go-algorand/crypto/merklesignature" "github.com/algorand/go-algorand/data/basics" + "github.com/algorand/go-algorand/data/bookkeeping" + "github.com/algorand/go-algorand/data/transactions" "github.com/algorand/go-algorand/ledger/ledgercore" + "github.com/algorand/go-algorand/protocol" ) // BaseAccountData is the base struct used to store account data @@ -210,6 +213,30 @@ type PersistedOnlineAccountData struct { UpdRound basics.Round } +// TxTailRound contains the information about a single round of transactions. +// The TxnIDs and LastValid would both be of the same length, and are stored +// in that way for efficient message=pack encoding. The Leases would point to the +// respective transaction index. Note that this isn’t optimized for storing +// leases, as leases are extremely rare. +type TxTailRound struct { + _struct struct{} `codec:",omitempty,omitemptyarray"` + + TxnIDs []transactions.Txid `codec:"i,allocbound=-"` + LastValid []basics.Round `codec:"v,allocbound=-"` + Leases []TxTailRoundLease `codec:"l,allocbound=-"` + Hdr bookkeeping.BlockHeader `codec:"h,allocbound=-"` +} + +// TxTailRoundLease is used as part of txTailRound for storing +// a single lease. +type TxTailRoundLease struct { + _struct struct{} `codec:",omitempty,omitemptyarray"` + + Sender basics.Address `codec:"s"` + Lease [32]byte `codec:"l,allocbound=-"` + TxnIdx uint64 `code:"i"` //!-- index of the entry in TxnIDs/LastValid +} + // AccountResource returns the corresponding account resource data based on the type of resource. func (prd *PersistedResourcesData) AccountResource() ledgercore.AccountResource { var ret ledgercore.AccountResource @@ -758,3 +785,38 @@ func (prd PersistedKVData) Before(other *PersistedKVData) bool { func (pac *PersistedOnlineAccountData) Before(other *PersistedOnlineAccountData) bool { return pac.UpdRound < other.UpdRound } + +// Encode the transaction tail data into a serialized form, and return the serialized data +// as well as the hash of the data. +func (t *TxTailRound) Encode() ([]byte, crypto.Digest) { + tailData := protocol.Encode(t) + hash := crypto.Hash(tailData) + return tailData, hash +} + +// TxTailRoundFromBlock creates a TxTailRound for the given block +func TxTailRoundFromBlock(blk bookkeeping.Block) (*TxTailRound, error) { + payset, err := blk.DecodePaysetFlat() + if err != nil { + return nil, err + } + + tail := &TxTailRound{} + + tail.TxnIDs = make([]transactions.Txid, len(payset)) + tail.LastValid = make([]basics.Round, len(payset)) + tail.Hdr = blk.BlockHeader + + for txIdxtxid, txn := range payset { + tail.TxnIDs[txIdxtxid] = txn.ID() + tail.LastValid[txIdxtxid] = txn.Txn.LastValid + if txn.Txn.Lease != [32]byte{} { + tail.Leases = append(tail.Leases, TxTailRoundLease{ + Sender: txn.Txn.Sender, + Lease: txn.Txn.Lease, + TxnIdx: uint64(txIdxtxid), + }) + } + } + return tail, nil +} diff --git a/ledger/store/msgp_gen.go b/ledger/store/msgp_gen.go index 0fcd1fb9d2..6597a5dffd 100644 --- a/ledger/store/msgp_gen.go +++ b/ledger/store/msgp_gen.go @@ -6,6 +6,8 @@ import ( "github.com/algorand/msgp/msgp" "github.com/algorand/go-algorand/config" + "github.com/algorand/go-algorand/data/basics" + "github.com/algorand/go-algorand/data/transactions" ) // The following msgp objects are implemented in this file: @@ -49,6 +51,22 @@ import ( // |-----> (*) Msgsize // |-----> (*) MsgIsZero // +// TxTailRound +// |-----> (*) MarshalMsg +// |-----> (*) CanMarshalMsg +// |-----> (*) UnmarshalMsg +// |-----> (*) CanUnmarshalMsg +// |-----> (*) Msgsize +// |-----> (*) MsgIsZero +// +// TxTailRoundLease +// |-----> (*) MarshalMsg +// |-----> (*) CanMarshalMsg +// |-----> (*) UnmarshalMsg +// |-----> (*) CanUnmarshalMsg +// |-----> (*) Msgsize +// |-----> (*) MsgIsZero +// // MarshalMsg implements msgp.Marshaler func (z *BaseAccountData) MarshalMsg(b []byte) (o []byte) { @@ -1878,3 +1896,459 @@ func (z *ResourcesData) Msgsize() (s int) { func (z *ResourcesData) MsgIsZero() bool { return ((*z).Total == 0) && ((*z).Decimals == 0) && ((*z).DefaultFrozen == false) && ((*z).UnitName == "") && ((*z).AssetName == "") && ((*z).URL == "") && ((*z).MetadataHash == ([32]byte{})) && ((*z).Manager.MsgIsZero()) && ((*z).Reserve.MsgIsZero()) && ((*z).Freeze.MsgIsZero()) && ((*z).Clawback.MsgIsZero()) && ((*z).Amount == 0) && ((*z).Frozen == false) && ((*z).SchemaNumUint == 0) && ((*z).SchemaNumByteSlice == 0) && ((*z).KeyValue.MsgIsZero()) && (len((*z).ApprovalProgram) == 0) && (len((*z).ClearStateProgram) == 0) && ((*z).GlobalState.MsgIsZero()) && ((*z).LocalStateSchemaNumUint == 0) && ((*z).LocalStateSchemaNumByteSlice == 0) && ((*z).GlobalStateSchemaNumUint == 0) && ((*z).GlobalStateSchemaNumByteSlice == 0) && ((*z).ExtraProgramPages == 0) && ((*z).ResourceFlags == 0) && ((*z).UpdateRound == 0) } + +// MarshalMsg implements msgp.Marshaler +func (z *TxTailRound) MarshalMsg(b []byte) (o []byte) { + o = msgp.Require(b, z.Msgsize()) + // omitempty: check for empty values + zb0004Len := uint32(4) + var zb0004Mask uint8 /* 5 bits */ + if (*z).Hdr.MsgIsZero() { + zb0004Len-- + zb0004Mask |= 0x2 + } + if len((*z).TxnIDs) == 0 { + zb0004Len-- + zb0004Mask |= 0x4 + } + if len((*z).Leases) == 0 { + zb0004Len-- + zb0004Mask |= 0x8 + } + if len((*z).LastValid) == 0 { + zb0004Len-- + zb0004Mask |= 0x10 + } + // variable map header, size zb0004Len + o = append(o, 0x80|uint8(zb0004Len)) + if zb0004Len != 0 { + if (zb0004Mask & 0x2) == 0 { // if not empty + // string "h" + o = append(o, 0xa1, 0x68) + o = (*z).Hdr.MarshalMsg(o) + } + if (zb0004Mask & 0x4) == 0 { // if not empty + // string "i" + o = append(o, 0xa1, 0x69) + if (*z).TxnIDs == nil { + o = msgp.AppendNil(o) + } else { + o = msgp.AppendArrayHeader(o, uint32(len((*z).TxnIDs))) + } + for zb0001 := range (*z).TxnIDs { + o = (*z).TxnIDs[zb0001].MarshalMsg(o) + } + } + if (zb0004Mask & 0x8) == 0 { // if not empty + // string "l" + o = append(o, 0xa1, 0x6c) + if (*z).Leases == nil { + o = msgp.AppendNil(o) + } else { + o = msgp.AppendArrayHeader(o, uint32(len((*z).Leases))) + } + for zb0003 := range (*z).Leases { + o = (*z).Leases[zb0003].MarshalMsg(o) + } + } + if (zb0004Mask & 0x10) == 0 { // if not empty + // string "v" + o = append(o, 0xa1, 0x76) + if (*z).LastValid == nil { + o = msgp.AppendNil(o) + } else { + o = msgp.AppendArrayHeader(o, uint32(len((*z).LastValid))) + } + for zb0002 := range (*z).LastValid { + o = (*z).LastValid[zb0002].MarshalMsg(o) + } + } + } + return +} + +func (_ *TxTailRound) CanMarshalMsg(z interface{}) bool { + _, ok := (z).(*TxTailRound) + return ok +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *TxTailRound) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0004 int + var zb0005 bool + zb0004, zb0005, bts, err = msgp.ReadMapHeaderBytes(bts) + if _, ok := err.(msgp.TypeError); ok { + zb0004, zb0005, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + if zb0004 > 0 { + zb0004-- + var zb0006 int + var zb0007 bool + zb0006, zb0007, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "TxnIDs") + return + } + if zb0007 { + (*z).TxnIDs = nil + } else if (*z).TxnIDs != nil && cap((*z).TxnIDs) >= zb0006 { + (*z).TxnIDs = ((*z).TxnIDs)[:zb0006] + } else { + (*z).TxnIDs = make([]transactions.Txid, zb0006) + } + for zb0001 := range (*z).TxnIDs { + bts, err = (*z).TxnIDs[zb0001].UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "TxnIDs", zb0001) + return + } + } + } + if zb0004 > 0 { + zb0004-- + var zb0008 int + var zb0009 bool + zb0008, zb0009, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "LastValid") + return + } + if zb0009 { + (*z).LastValid = nil + } else if (*z).LastValid != nil && cap((*z).LastValid) >= zb0008 { + (*z).LastValid = ((*z).LastValid)[:zb0008] + } else { + (*z).LastValid = make([]basics.Round, zb0008) + } + for zb0002 := range (*z).LastValid { + bts, err = (*z).LastValid[zb0002].UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "LastValid", zb0002) + return + } + } + } + if zb0004 > 0 { + zb0004-- + var zb0010 int + var zb0011 bool + zb0010, zb0011, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "Leases") + return + } + if zb0011 { + (*z).Leases = nil + } else if (*z).Leases != nil && cap((*z).Leases) >= zb0010 { + (*z).Leases = ((*z).Leases)[:zb0010] + } else { + (*z).Leases = make([]TxTailRoundLease, zb0010) + } + for zb0003 := range (*z).Leases { + bts, err = (*z).Leases[zb0003].UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "Leases", zb0003) + return + } + } + } + if zb0004 > 0 { + zb0004-- + bts, err = (*z).Hdr.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "Hdr") + return + } + } + if zb0004 > 0 { + err = msgp.ErrTooManyArrayFields(zb0004) + if err != nil { + err = msgp.WrapError(err, "struct-from-array") + return + } + } + } else { + if err != nil { + err = msgp.WrapError(err) + return + } + if zb0005 { + (*z) = TxTailRound{} + } + for zb0004 > 0 { + zb0004-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch string(field) { + case "i": + var zb0012 int + var zb0013 bool + zb0012, zb0013, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "TxnIDs") + return + } + if zb0013 { + (*z).TxnIDs = nil + } else if (*z).TxnIDs != nil && cap((*z).TxnIDs) >= zb0012 { + (*z).TxnIDs = ((*z).TxnIDs)[:zb0012] + } else { + (*z).TxnIDs = make([]transactions.Txid, zb0012) + } + for zb0001 := range (*z).TxnIDs { + bts, err = (*z).TxnIDs[zb0001].UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "TxnIDs", zb0001) + return + } + } + case "v": + var zb0014 int + var zb0015 bool + zb0014, zb0015, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "LastValid") + return + } + if zb0015 { + (*z).LastValid = nil + } else if (*z).LastValid != nil && cap((*z).LastValid) >= zb0014 { + (*z).LastValid = ((*z).LastValid)[:zb0014] + } else { + (*z).LastValid = make([]basics.Round, zb0014) + } + for zb0002 := range (*z).LastValid { + bts, err = (*z).LastValid[zb0002].UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "LastValid", zb0002) + return + } + } + case "l": + var zb0016 int + var zb0017 bool + zb0016, zb0017, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Leases") + return + } + if zb0017 { + (*z).Leases = nil + } else if (*z).Leases != nil && cap((*z).Leases) >= zb0016 { + (*z).Leases = ((*z).Leases)[:zb0016] + } else { + (*z).Leases = make([]TxTailRoundLease, zb0016) + } + for zb0003 := range (*z).Leases { + bts, err = (*z).Leases[zb0003].UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "Leases", zb0003) + return + } + } + case "h": + bts, err = (*z).Hdr.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "Hdr") + return + } + default: + err = msgp.ErrNoField(string(field)) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + } + o = bts + return +} + +func (_ *TxTailRound) CanUnmarshalMsg(z interface{}) bool { + _, ok := (z).(*TxTailRound) + return ok +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *TxTailRound) Msgsize() (s int) { + s = 1 + 2 + msgp.ArrayHeaderSize + for zb0001 := range (*z).TxnIDs { + s += (*z).TxnIDs[zb0001].Msgsize() + } + s += 2 + msgp.ArrayHeaderSize + for zb0002 := range (*z).LastValid { + s += (*z).LastValid[zb0002].Msgsize() + } + s += 2 + msgp.ArrayHeaderSize + for zb0003 := range (*z).Leases { + s += (*z).Leases[zb0003].Msgsize() + } + s += 2 + (*z).Hdr.Msgsize() + return +} + +// MsgIsZero returns whether this is a zero value +func (z *TxTailRound) MsgIsZero() bool { + return (len((*z).TxnIDs) == 0) && (len((*z).LastValid) == 0) && (len((*z).Leases) == 0) && ((*z).Hdr.MsgIsZero()) +} + +// MarshalMsg implements msgp.Marshaler +func (z *TxTailRoundLease) MarshalMsg(b []byte) (o []byte) { + o = msgp.Require(b, z.Msgsize()) + // omitempty: check for empty values + zb0002Len := uint32(3) + var zb0002Mask uint8 /* 4 bits */ + if (*z).TxnIdx == 0 { + zb0002Len-- + zb0002Mask |= 0x1 + } + if (*z).Lease == ([32]byte{}) { + zb0002Len-- + zb0002Mask |= 0x4 + } + if (*z).Sender.MsgIsZero() { + zb0002Len-- + zb0002Mask |= 0x8 + } + // variable map header, size zb0002Len + o = append(o, 0x80|uint8(zb0002Len)) + if zb0002Len != 0 { + if (zb0002Mask & 0x1) == 0 { // if not empty + // string "TxnIdx" + o = append(o, 0xa6, 0x54, 0x78, 0x6e, 0x49, 0x64, 0x78) + o = msgp.AppendUint64(o, (*z).TxnIdx) + } + if (zb0002Mask & 0x4) == 0 { // if not empty + // string "l" + o = append(o, 0xa1, 0x6c) + o = msgp.AppendBytes(o, ((*z).Lease)[:]) + } + if (zb0002Mask & 0x8) == 0 { // if not empty + // string "s" + o = append(o, 0xa1, 0x73) + o = (*z).Sender.MarshalMsg(o) + } + } + return +} + +func (_ *TxTailRoundLease) CanMarshalMsg(z interface{}) bool { + _, ok := (z).(*TxTailRoundLease) + return ok +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *TxTailRoundLease) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0002 int + var zb0003 bool + zb0002, zb0003, bts, err = msgp.ReadMapHeaderBytes(bts) + if _, ok := err.(msgp.TypeError); ok { + zb0002, zb0003, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + if zb0002 > 0 { + zb0002-- + bts, err = (*z).Sender.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "Sender") + return + } + } + if zb0002 > 0 { + zb0002-- + bts, err = msgp.ReadExactBytes(bts, ((*z).Lease)[:]) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "Lease") + return + } + } + if zb0002 > 0 { + zb0002-- + (*z).TxnIdx, bts, err = msgp.ReadUint64Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "struct-from-array", "TxnIdx") + return + } + } + if zb0002 > 0 { + err = msgp.ErrTooManyArrayFields(zb0002) + if err != nil { + err = msgp.WrapError(err, "struct-from-array") + return + } + } + } else { + if err != nil { + err = msgp.WrapError(err) + return + } + if zb0003 { + (*z) = TxTailRoundLease{} + } + for zb0002 > 0 { + zb0002-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch string(field) { + case "s": + bts, err = (*z).Sender.UnmarshalMsg(bts) + if err != nil { + err = msgp.WrapError(err, "Sender") + return + } + case "l": + bts, err = msgp.ReadExactBytes(bts, ((*z).Lease)[:]) + if err != nil { + err = msgp.WrapError(err, "Lease") + return + } + case "TxnIdx": + (*z).TxnIdx, bts, err = msgp.ReadUint64Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "TxnIdx") + return + } + default: + err = msgp.ErrNoField(string(field)) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + } + o = bts + return +} + +func (_ *TxTailRoundLease) CanUnmarshalMsg(z interface{}) bool { + _, ok := (z).(*TxTailRoundLease) + return ok +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *TxTailRoundLease) Msgsize() (s int) { + s = 1 + 2 + (*z).Sender.Msgsize() + 2 + msgp.ArrayHeaderSize + (32 * (msgp.ByteSize)) + 7 + msgp.Uint64Size + return +} + +// MsgIsZero returns whether this is a zero value +func (z *TxTailRoundLease) MsgIsZero() bool { + return ((*z).Sender.MsgIsZero()) && ((*z).Lease == ([32]byte{})) && ((*z).TxnIdx == 0) +} diff --git a/ledger/store/msgp_gen_test.go b/ledger/store/msgp_gen_test.go index a6d2ec2fb4..ddb530a95c 100644 --- a/ledger/store/msgp_gen_test.go +++ b/ledger/store/msgp_gen_test.go @@ -253,3 +253,123 @@ func BenchmarkUnmarshalResourcesData(b *testing.B) { } } } + +func TestMarshalUnmarshalTxTailRound(t *testing.T) { + partitiontest.PartitionTest(t) + v := TxTailRound{} + bts := v.MarshalMsg(nil) + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func TestRandomizedEncodingTxTailRound(t *testing.T) { + protocol.RunEncodingTest(t, &TxTailRound{}) +} + +func BenchmarkMarshalMsgTxTailRound(b *testing.B) { + v := TxTailRound{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgTxTailRound(b *testing.B) { + v := TxTailRound{} + bts := make([]byte, 0, v.Msgsize()) + bts = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalTxTailRound(b *testing.B) { + v := TxTailRound{} + bts := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalTxTailRoundLease(t *testing.T) { + partitiontest.PartitionTest(t) + v := TxTailRoundLease{} + bts := v.MarshalMsg(nil) + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func TestRandomizedEncodingTxTailRoundLease(t *testing.T) { + protocol.RunEncodingTest(t, &TxTailRoundLease{}) +} + +func BenchmarkMarshalMsgTxTailRoundLease(b *testing.B) { + v := TxTailRoundLease{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgTxTailRoundLease(b *testing.B) { + v := TxTailRoundLease{} + bts := make([]byte, 0, v.Msgsize()) + bts = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalTxTailRoundLease(b *testing.B) { + v := TxTailRoundLease{} + bts := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/ledger/tracker.go b/ledger/tracker.go index 94d2dbd9c3..1612856444 100644 --- a/ledger/tracker.go +++ b/ledger/tracker.go @@ -280,7 +280,8 @@ func (tr *trackerRegistry) initialize(l ledgerForTracker, trackers []ledgerTrack tr.log = l.trackerLog() err = tr.dbs.Rdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - tr.dbRound, err = accountsRound(tx) + arw := store.NewAccountsSQLReaderWriter(tx) + tr.dbRound, err = arw.AccountsRound() return err }) @@ -510,6 +511,7 @@ func (tr *trackerRegistry) commitRound(dcc *deferredCommitContext) error { start := time.Now() ledgerCommitroundCount.Inc(nil) err := tr.dbs.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { + arw := store.NewAccountsSQLReaderWriter(tx) for _, lt := range tr.trackers { err0 := lt.commitRound(ctx, tx, dcc) if err0 != nil { @@ -517,7 +519,7 @@ func (tr *trackerRegistry) commitRound(dcc *deferredCommitContext) error { } } - return updateAccountsRound(tx, dbRound+basics.Round(offset)) + return arw.UpdateAccountsRound(dbRound + basics.Round(offset)) }) ledgerCommitroundMicros.AddMicrosecondsSince(start, nil) diff --git a/ledger/trackerdb.go b/ledger/trackerdb.go index bb9bf9a8ce..f9190a8b44 100644 --- a/ledger/trackerdb.go +++ b/ledger/trackerdb.go @@ -80,6 +80,8 @@ func trackerDBInitialize(l ledgerForTracker, catchpointEnabled bool, dbPathPrefi } err = dbs.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) error { + arw := store.NewAccountsSQLReaderWriter(tx) + tp := trackerDBParams{ initAccounts: l.GenesisAccounts(), initProto: l.GenesisProtoVersion(), @@ -94,7 +96,7 @@ func trackerDBInitialize(l ledgerForTracker, catchpointEnabled bool, dbPathPrefi if err0 != nil { return err0 } - lastBalancesRound, err := accountsRound(tx) + lastBalancesRound, err := arw.AccountsRound() if err != nil { return err } diff --git a/ledger/txtail.go b/ledger/txtail.go index 4fa2ab0a4b..4b045ef91f 100644 --- a/ledger/txtail.go +++ b/ledger/txtail.go @@ -29,6 +29,7 @@ import ( "github.com/algorand/go-algorand/data/bookkeeping" "github.com/algorand/go-algorand/data/transactions" "github.com/algorand/go-algorand/ledger/ledgercore" + "github.com/algorand/go-algorand/ledger/store" "github.com/algorand/go-algorand/logging" ) @@ -93,12 +94,13 @@ func (t *txTail) loadFromDisk(l ledgerForTracker, dbRound basics.Round) error { rdb := l.trackerDB().Rdb t.log = l.trackerLog() - var roundData []*txTailRound + var roundData []*store.TxTailRound var roundTailHashes []crypto.Digest var baseRound basics.Round if dbRound > 0 { err := rdb.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { - roundData, roundTailHashes, baseRound, err = loadTxTail(ctx, tx, dbRound) + arw := store.NewAccountsSQLReaderWriter(tx) + roundData, roundTailHashes, baseRound, err = arw.LoadTxTail(ctx, dbRound) return }) if err != nil { @@ -192,7 +194,7 @@ func (t *txTail) newBlock(blk bookkeeping.Block, delta ledgercore.StateDelta) { return } - var tail txTailRound + var tail store.TxTailRound tail.TxnIDs = make([]transactions.Txid, len(delta.Txids)) tail.LastValid = make([]basics.Round, len(delta.Txids)) tail.Hdr = blk.BlockHeader @@ -202,14 +204,14 @@ func (t *txTail) newBlock(blk bookkeeping.Block, delta ledgercore.StateDelta) { tail.TxnIDs[txnInc.Intra] = txid tail.LastValid[txnInc.Intra] = txnInc.LastValid if blk.Payset[txnInc.Intra].Txn.Lease != [32]byte{} { - tail.Leases = append(tail.Leases, txTailRoundLease{ + tail.Leases = append(tail.Leases, store.TxTailRoundLease{ Sender: blk.Payset[txnInc.Intra].Txn.Sender, Lease: blk.Payset[txnInc.Intra].Txn.Lease, TxnIdx: txnInc.Intra, }) } } - encodedTail, tailHash := tail.encode() + encodedTail, tailHash := tail.Encode() t.tailMu.Lock() defer t.tailMu.Unlock() @@ -269,11 +271,13 @@ func (t *txTail) prepareCommit(dcc *deferredCommitContext) (err error) { } func (t *txTail) commitRound(ctx context.Context, tx *sql.Tx, dcc *deferredCommitContext) error { + arw := store.NewAccountsSQLReaderWriter(tx) + // determine the round to remove data // the formula is similar to the committedUpTo: rnd + 1 - retain size forgetBeforeRound := (dcc.newBase + 1).SubSaturate(basics.Round(dcc.txTailRetainSize)) baseRound := dcc.oldBase + 1 - if err := txtailNewRound(ctx, tx, baseRound, dcc.txTailDeltas, forgetBeforeRound); err != nil { + if err := arw.TxtailNewRound(ctx, baseRound, dcc.txTailDeltas, forgetBeforeRound); err != nil { return fmt.Errorf("txTail: unable to persist new round %d : %w", baseRound, err) } return nil diff --git a/ledger/txtail_test.go b/ledger/txtail_test.go index d32ee7568f..079fff2835 100644 --- a/ledger/txtail_test.go +++ b/ledger/txtail_test.go @@ -30,6 +30,7 @@ import ( "github.com/algorand/go-algorand/data/bookkeeping" "github.com/algorand/go-algorand/data/transactions" "github.com/algorand/go-algorand/ledger/ledgercore" + "github.com/algorand/go-algorand/ledger/store" ledgertesting "github.com/algorand/go-algorand/ledger/testing" "github.com/algorand/go-algorand/protocol" "github.com/algorand/go-algorand/test/partitiontest" @@ -154,6 +155,8 @@ func (t *txTailTestLedger) initialize(ts *testing.T, protoVersion protocol.Conse tx, err := t.trackerDBs.Wdb.Handle.Begin() require.NoError(ts, err) + arw := store.NewAccountsSQLReaderWriter(tx) + accts := ledgertesting.RandomAccounts(20, true) proto := config.Consensus[protoVersion] newDB := accountsInitTest(ts, tx, accts, protoVersion) @@ -166,12 +169,12 @@ func (t *txTailTestLedger) initialize(ts *testing.T, protoVersion protocol.Conse for i := startRound; i <= t.Latest(); i++ { blk, err := t.Block(i) require.NoError(ts, err) - tail, err := txTailRoundFromBlock(blk) + tail, err := store.TxTailRoundFromBlock(blk) require.NoError(ts, err) - encoded, _ := tail.encode() + encoded, _ := tail.Encode() roundData = append(roundData, encoded) } - err = txtailNewRound(context.Background(), tx, startRound, roundData, 0) + err = arw.TxtailNewRound(context.Background(), startRound, roundData, 0) require.NoError(ts, err) tx.Commit() return nil diff --git a/util/db/interfaces.go b/util/db/interfaces.go index d607c08d68..5d4abed852 100644 --- a/util/db/interfaces.go +++ b/util/db/interfaces.go @@ -31,6 +31,7 @@ import ( // be added here as needed. type Queryable interface { Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) Query(query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row