Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Use Custom Priority in Priority Nonce Mempool #15328

Merged
merged 12 commits into from
Mar 15, 2023
138 changes: 88 additions & 50 deletions types/mempool/priority_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package mempool
import (
"context"
"fmt"
"math"

"github.com/huandu/skiplist"

Expand All @@ -16,6 +15,40 @@ var (
_ Iterator = (*PriorityNonceIterator)(nil)
)

type TxPriority struct {
// GetTxPriority returns the priority of the transaction. A priority must be
// comparable via CompareTxPriority.
GetTxPriority func(ctx context.Context, tx sdk.Tx) any
// CompareTxPriority compares two transaction priorities. The result should be
// 0 if a == b, -1 if a < b, and +1 if a > b.
CompareTxPriority func(a, b any) int
}

// NewDefaultTxPriority returns a TxPriority comparator using ctx.Priority as
// the defining transaction priority.
func NewDefaultTxPriority() TxPriority {
return TxPriority{
GetTxPriority: func(goCtx context.Context, tx sdk.Tx) any {
return sdk.UnwrapSDKContext(goCtx).Priority()
},
CompareTxPriority: func(a, b any) int {
switch {
case a == nil && b == nil:
alexanderbez marked this conversation as resolved.
Show resolved Hide resolved
return 0
case a == nil:
alexanderbez marked this conversation as resolved.
Show resolved Hide resolved
return -1
case b == nil:
return 1
default:
aPriority := a.(int64)
bPriority := b.(int64)

return skiplist.Int64.Compare(aPriority, bPriority)
}
},
}
}

// PriorityNonceMempool is a mempool implementation that stores txs
// in a partially ordered set by 2 dimensions: priority, and sender-nonce
// (sequence number). Internally it uses one priority ordered skip list and one
Expand All @@ -25,65 +58,70 @@ var (
// and priority.
type PriorityNonceMempool struct {
priorityIndex *skiplist.SkipList
priorityCounts map[int64]int
priorityCounts map[any]int
senderIndices map[string]*skiplist.SkipList
scores map[txMeta]txMeta
onRead func(tx sdk.Tx)
txReplacement func(op, np int64, oTx, nTx sdk.Tx) bool
txReplacement func(op, np any, oTx, nTx sdk.Tx) bool
maxTx int
txPriority TxPriority
}

type PriorityNonceIterator struct {
mempool *PriorityNonceMempool
priorityNode *skiplist.Element
senderCursors map[string]*skiplist.Element
nextPriority int64
sender string
priorityNode *skiplist.Element
mempool *PriorityNonceMempool
nextPriority any
}

// txMeta stores transaction metadata used in indices
type txMeta struct {
// nonce is the sender's sequence number
nonce uint64
// priority is the transaction's priority
priority int64
priority any
// sender is the transaction's sender
sender string
// weight is the transaction's weight, used as a tiebreaker for transactions with the same priority
weight int64
// weight is the transaction's weight, used as a tiebreaker for transactions
// with the same priority
weight any
// senderElement is a pointer to the transaction's element in the sender index
senderElement *skiplist.Element
}

// txMetaLess is a comparator for txKeys that first compares priority, then weight,
// then sender, then nonce, uniquely identifying a transaction.
// skiplistComparable is a comparator for txKeys that first compares priority,
// then weight, then sender, then nonce, uniquely identifying a transaction.
//
// Note, txMetaLess is used as the comparator in the priority index.
func txMetaLess(a, b any) int {
keyA := a.(txMeta)
keyB := b.(txMeta)
res := skiplist.Int64.Compare(keyA.priority, keyB.priority)
if res != 0 {
return res
}
// Note, skiplistComparable is used as the comparator in the priority index.
func skiplistComparable(txPriority TxPriority) skiplist.Comparable {
return skiplist.LessThanFunc(func(a, b any) int {
keyA := a.(txMeta)
keyB := b.(txMeta)

res := txPriority.CompareTxPriority(keyA.priority, keyB.priority)
if res != 0 {
return res
}

// Weight is used as a tiebreaker for transactions with the same priority.
// Weight is calculated in a single pass in .Select(...) and so will be 0
// on .Insert(...).
res = skiplist.Int64.Compare(keyA.weight, keyB.weight)
if res != 0 {
return res
}
// Weight is used as a tiebreaker for transactions with the same priority.
// Weight is calculated in a single pass in .Select(...) and so will be 0
// on .Insert(...).
res = txPriority.CompareTxPriority(keyA.weight, keyB.weight)
if res != 0 {
return res
}

// Because weight will be 0 on .Insert(...), we must also compare sender and
// nonce to resolve priority collisions. If we didn't then transactions with
// the same priority would overwrite each other in the priority index.
res = skiplist.String.Compare(keyA.sender, keyB.sender)
if res != 0 {
return res
}
// Because weight will be 0 on .Insert(...), we must also compare sender and
// nonce to resolve priority collisions. If we didn't then transactions with
// the same priority would overwrite each other in the priority index.
res = skiplist.String.Compare(keyA.sender, keyB.sender)
if res != 0 {
return res
}

return skiplist.Uint64.Compare(keyA.nonce, keyB.nonce)
return skiplist.Uint64.Compare(keyA.nonce, keyB.nonce)
})
}

type PriorityNonceMempoolOption func(*PriorityNonceMempool)
Expand All @@ -99,7 +137,7 @@ func PriorityNonceWithOnRead(onRead func(tx sdk.Tx)) PriorityNonceMempoolOption
// PriorityNonceWithTxReplacement sets a callback to be called when duplicated
// transaction nonce detected during mempool insert. An application can define a
// transaction replacement rule based on tx priority or certain transaction fields.
func PriorityNonceWithTxReplacement(txReplacementRule func(op, np int64, oTx, nTx sdk.Tx) bool) PriorityNonceMempoolOption {
func PriorityNonceWithTxReplacement(txReplacementRule func(op, np any, oTx, nTx sdk.Tx) bool) PriorityNonceMempoolOption {
return func(mp *PriorityNonceMempool) {
mp.txReplacement = txReplacementRule
}
Expand All @@ -118,18 +156,19 @@ func PriorityNonceWithMaxTx(maxTx int) PriorityNonceMempoolOption {
}

// DefaultPriorityMempool returns a priorityNonceMempool with no options.
func DefaultPriorityMempool() Mempool {
return NewPriorityMempool()
func DefaultPriorityMempool(txPriority TxPriority) Mempool {
return NewPriorityMempool(txPriority)
}

// NewPriorityMempool returns the SDK's default mempool implementation which
// returns txs in a partial order by 2 dimensions; priority, and sender-nonce.
func NewPriorityMempool(opts ...PriorityNonceMempoolOption) *PriorityNonceMempool {
func NewPriorityMempool(txPriority TxPriority, opts ...PriorityNonceMempoolOption) *PriorityNonceMempool {
mp := &PriorityNonceMempool{
priorityIndex: skiplist.New(skiplist.LessThanFunc(txMetaLess)),
priorityCounts: make(map[int64]int),
priorityIndex: skiplist.New(skiplistComparable(txPriority)),
priorityCounts: make(map[any]int),
senderIndices: make(map[string]*skiplist.SkipList),
scores: make(map[txMeta]txMeta),
txPriority: txPriority,
}

for _, opt := range opts {
Expand Down Expand Up @@ -176,10 +215,9 @@ func (mp *PriorityNonceMempool) Insert(ctx context.Context, tx sdk.Tx) error {
return fmt.Errorf("tx must have at least one signer")
}

sdkContext := sdk.UnwrapSDKContext(ctx)
priority := sdkContext.Priority()
sig := sigs[0]
sender := sdk.AccAddress(sig.PubKey.Address()).String()
priority := mp.txPriority.GetTxPriority(ctx, tx)
nonce := sig.Sequence
key := txMeta{nonce: nonce, priority: priority, sender: sender}

Expand Down Expand Up @@ -252,7 +290,7 @@ func (i *PriorityNonceIterator) iteratePriority() Iterator {
if nextPriorityNode != nil {
i.nextPriority = nextPriorityNode.Key().(txMeta).priority
} else {
i.nextPriority = math.MinInt64
i.nextPriority = nil
}

return i.Next()
Expand Down Expand Up @@ -281,13 +319,13 @@ func (i *PriorityNonceIterator) Next() Iterator {

// We've reached a transaction with a priority lower than the next highest
// priority in the pool.
if key.priority < i.nextPriority {
if i.mempool.txPriority.CompareTxPriority(key.priority, i.nextPriority) < 0 {
return i.iteratePriority()
} else if key.priority == i.nextPriority {
} else if i.mempool.txPriority.CompareTxPriority(key.priority, i.nextPriority) == 0 {
// Weight is incorporated into the priority index key only (not sender index)
// so we must fetch it here from the scores map.
weight := i.mempool.scores[txMeta{nonce: key.nonce, sender: key.sender}].weight
if weight < i.priorityNode.Next().Key().(txMeta).weight {
if i.mempool.txPriority.CompareTxPriority(weight, i.priorityNode.Next().Key().(txMeta).weight) < 0 {
return i.iteratePriority()
}
}
Expand Down Expand Up @@ -335,7 +373,7 @@ func (mp *PriorityNonceMempool) reorderPriorityTies() {
key := node.Key().(txMeta)
if mp.priorityCounts[key.priority] > 1 {
newKey := key
newKey.weight = senderWeight(key.senderElement)
newKey.weight = senderWeight(mp.txPriority, key.senderElement)
reordering = append(reordering, reorderKey{deleteKey: key, insertKey: newKey, tx: node.Value.(sdk.Tx)})
}

Expand All @@ -354,7 +392,7 @@ func (mp *PriorityNonceMempool) reorderPriorityTies() {
// defined as the first (nonce-wise) same sender tx with a priority not equal to
// t. It is used to resolve priority collisions, that is when 2 or more txs from
// different senders have the same priority.
func senderWeight(senderCursor *skiplist.Element) int64 {
func senderWeight(txPriority TxPriority, senderCursor *skiplist.Element) any {
if senderCursor == nil {
return 0
}
Expand All @@ -363,7 +401,7 @@ func senderWeight(senderCursor *skiplist.Element) int64 {
senderCursor = senderCursor.Next()
for senderCursor != nil {
p := senderCursor.Key().(txMeta).priority
if p != weight {
if txPriority.CompareTxPriority(p, weight) != 0 {
weight = p
}

Expand Down Expand Up @@ -419,7 +457,7 @@ func IsEmpty(mempool Mempool) error {
return fmt.Errorf("priorityIndex not empty")
}

var countKeys []int64
var countKeys []any
for k := range mp.priorityCounts {
countKeys = append(countKeys, k)
}
Expand Down
40 changes: 23 additions & 17 deletions types/mempool/priority_nonce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (
"testing"
"time"

"cosmossdk.io/log"
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
"github.com/stretchr/testify/require"

"cosmossdk.io/log"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/mempool"
simtypes "github.com/cosmos/cosmos-sdk/types/simulation"
Expand Down Expand Up @@ -229,7 +229,7 @@ func (s *MempoolTestSuite) TestPriorityNonceTxOrder() {
}
for i, tt := range tests {
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
pool := mempool.NewPriorityMempool()
pool := mempool.NewPriorityMempool(mempool.NewDefaultTxPriority())

// create test txs and insert into mempool
for i, ts := range tt.txs {
Expand Down Expand Up @@ -275,7 +275,7 @@ func (s *MempoolTestSuite) TestPriorityTies() {
}

for i := 0; i < 100; i++ {
s.mempool = mempool.NewPriorityMempool()
s.mempool = mempool.NewPriorityMempool(mempool.NewDefaultTxPriority())
var shuffled []txSpec
for _, t := range txSet {
tx := txSpec{
Expand Down Expand Up @@ -372,9 +372,12 @@ func validateOrder(mtxs []sdk.Tx) error {

func (s *MempoolTestSuite) TestRandomGeneratedTxs() {
s.iterations = 0
s.mempool = mempool.NewPriorityMempool(mempool.PriorityNonceWithOnRead(func(tx sdk.Tx) {
s.iterations++
}))
s.mempool = mempool.NewPriorityMempool(
mempool.NewDefaultTxPriority(),
mempool.PriorityNonceWithOnRead(func(tx sdk.Tx) {
s.iterations++
}),
)
t := s.T()
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
seed := time.Now().UnixNano()
Expand Down Expand Up @@ -409,7 +412,7 @@ func (s *MempoolTestSuite) TestRandomGeneratedTxs() {

func (s *MempoolTestSuite) TestRandomWalkTxs() {
s.iterations = 0
s.mempool = mempool.NewPriorityMempool()
s.mempool = mempool.NewPriorityMempool(mempool.NewDefaultTxPriority())

t := s.T()
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
Expand Down Expand Up @@ -589,7 +592,7 @@ func TestPriorityNonceMempool_NextSenderTx(t *testing.T) {
accA := accounts[0].Address
accB := accounts[1].Address

mp := mempool.NewPriorityMempool()
mp := mempool.NewPriorityMempool(mempool.NewDefaultTxPriority())

txs := []testTx{
{priority: 20, nonce: 1, address: accA},
Expand Down Expand Up @@ -633,21 +636,21 @@ func TestNextSenderTx_TxLimit(t *testing.T) {
}

// unlimited
mp := mempool.NewPriorityMempool(mempool.PriorityNonceWithMaxTx(0))
mp := mempool.NewPriorityMempool(mempool.NewDefaultTxPriority(), mempool.PriorityNonceWithMaxTx(0))
for i, tx := range txs {
c := ctx.WithPriority(tx.priority)
require.NoError(t, mp.Insert(c, tx))
require.Equal(t, i+1, mp.CountTx())
}
mp = mempool.NewPriorityMempool()
mp = mempool.NewPriorityMempool(mempool.NewDefaultTxPriority())
for i, tx := range txs {
c := ctx.WithPriority(tx.priority)
require.NoError(t, mp.Insert(c, tx))
require.Equal(t, i+1, mp.CountTx())
}

// limit: 3
mp = mempool.NewPriorityMempool(mempool.PriorityNonceWithMaxTx(3))
mp = mempool.NewPriorityMempool(mempool.NewDefaultTxPriority(), mempool.PriorityNonceWithMaxTx(3))
for i, tx := range txs {
c := ctx.WithPriority(tx.priority)
err := mp.Insert(c, tx)
Expand All @@ -661,7 +664,7 @@ func TestNextSenderTx_TxLimit(t *testing.T) {
}

// disabled
mp = mempool.NewPriorityMempool(mempool.PriorityNonceWithMaxTx(-1))
mp = mempool.NewPriorityMempool(mempool.NewDefaultTxPriority(), mempool.PriorityNonceWithMaxTx(-1))
for _, tx := range txs {
c := ctx.WithPriority(tx.priority)
err := mp.Insert(c, tx)
Expand All @@ -683,7 +686,7 @@ func TestNextSenderTx_TxReplacement(t *testing.T) {
}

// test Priority with default mempool
mp := mempool.NewPriorityMempool()
mp := mempool.NewPriorityMempool(mempool.NewDefaultTxPriority())
for _, tx := range txs {
c := ctx.WithPriority(tx.priority)
require.NoError(t, mp.Insert(c, tx))
Expand All @@ -697,10 +700,13 @@ func TestNextSenderTx_TxReplacement(t *testing.T) {
// we set a TestTxReplacement rule which the priority of the new Tx must be 20% more than the priority of the old Tx
// otherwise, the Insert will return error
feeBump := 20
mp = mempool.NewPriorityMempool(mempool.PriorityNonceWithTxReplacement(func(op, np int64, oTx, nTx sdk.Tx) bool {
threshold := int64(100 + feeBump)
return np >= op*threshold/100
}))
mp = mempool.NewPriorityMempool(
mempool.NewDefaultTxPriority(),
mempool.PriorityNonceWithTxReplacement(func(op, np any, oTx, nTx sdk.Tx) bool {
threshold := int64(100 + feeBump)
return np.(int64) >= op.(int64)*threshold/100
}),
)

c := ctx.WithPriority(txs[0].priority)
require.NoError(t, mp.Insert(c, txs[0]))
Expand Down