diff --git a/common/celo_types.go b/common/celo_types.go index 20920bbed2..44ac2ef072 100644 --- a/common/celo_types.go +++ b/common/celo_types.go @@ -1,5 +1,34 @@ package common +import ( + "math/big" +) + var ( ZeroAddress = BytesToAddress([]byte{}) ) + +type ExchangeRates = map[Address]*big.Rat + +func IsCurrencyWhitelisted(exchangeRates ExchangeRates, feeCurrency *Address) bool { + if feeCurrency == nil { + return true + } + + // Check if fee currency is registered + _, ok := exchangeRates[*feeCurrency] + return ok +} + +func AreSameAddress(a, b *Address) bool { + // both are nil or point to the same address + if a == b { + return true + } + // if only one is nil + if a == nil || b == nil { + return false + } + // if they point to the same + return *a == *b +} diff --git a/common/celo_types_test.go b/common/celo_types_test.go new file mode 100644 index 0000000000..4ecfe47182 --- /dev/null +++ b/common/celo_types_test.go @@ -0,0 +1,47 @@ +package common + +import ( + "math/big" + "testing" +) + +var ( + currA = HexToAddress("0xA") + currB = HexToAddress("0xB") + currX = HexToAddress("0xF") + exchangeRates = ExchangeRates{ + currA: big.NewRat(47, 100), + currB: big.NewRat(45, 100), + } +) + +func TestIsWhitelisted(t *testing.T) { + tests := []struct { + name string + feeCurrency *Address + want bool + }{ + { + name: "no fee currency", + feeCurrency: nil, + want: true, + }, + { + name: "valid fee currency", + feeCurrency: &currA, + want: true, + }, + { + name: "invalid fee currency", + feeCurrency: &currX, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsCurrencyWhitelisted(exchangeRates, tt.feeCurrency); got != tt.want { + t.Errorf("IsWhitelisted() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/common/exchange/rates.go b/common/exchange/rates.go new file mode 100644 index 0000000000..9a93c62ac8 --- /dev/null +++ b/common/exchange/rates.go @@ -0,0 +1,102 @@ +package exchange + +import ( + "errors" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/log" +) + +var ( + unitRate = big.NewRat(1, 1) + // ErrNonWhitelistedFeeCurrency is returned if the currency specified to use for the fees + // isn't one of the currencies whitelisted for that purpose. + ErrNonWhitelistedFeeCurrency = errors.New("non-whitelisted fee currency address") +) + +// ConvertCurrency does an exchange conversion from currencyFrom to currencyTo of the value given. +func ConvertCurrency(exchangeRates common.ExchangeRates, val1 *big.Int, currencyFrom *common.Address, currencyTo *common.Address) *big.Int { + goldAmount, err := ConvertCurrencyToGold(exchangeRates, val1, currencyFrom) + if err != nil { + log.Error("Error trying to convert from currency to gold.", "value", val1, "fromCurrency", currencyFrom.Hex()) + } + toAmount, err := ConvertGoldToCurrency(exchangeRates, currencyTo, goldAmount) + if err != nil { + log.Error("Error trying to convert from gold to currency.", "value", goldAmount, "toCurrency", currencyTo.Hex()) + } + return toAmount +} + +func ConvertCurrencyToGold(exchangeRates common.ExchangeRates, currencyAmount *big.Int, feeCurrency *common.Address) (*big.Int, error) { + if feeCurrency == nil { + return currencyAmount, nil + } + exchangeRate, ok := exchangeRates[*feeCurrency] + if !ok { + return nil, ErrNonWhitelistedFeeCurrency + } + return new(big.Int).Div(new(big.Int).Mul(currencyAmount, exchangeRate.Denom()), exchangeRate.Num()), nil +} + +func ConvertGoldToCurrency(exchangeRates common.ExchangeRates, feeCurrency *common.Address, goldAmount *big.Int) (*big.Int, error) { + if feeCurrency == nil { + return goldAmount, nil + } + exchangeRate, ok := exchangeRates[*feeCurrency] + if !ok { + return nil, ErrNonWhitelistedFeeCurrency + } + return new(big.Int).Div(new(big.Int).Mul(goldAmount, exchangeRate.Num()), exchangeRate.Denom()), nil +} + +func getRate(exchangeRates common.ExchangeRates, feeCurrency *common.Address) (*big.Rat, error) { + if feeCurrency == nil { + return unitRate, nil + } + rate, ok := exchangeRates[*feeCurrency] + if !ok { + return nil, fmt.Errorf("fee currency not registered: %s", feeCurrency.Hex()) + } + return rate, nil +} + +// CompareValue compares values in different currencies (nil currency is native currency) +// returns -1 0 or 1 depending if val1 < val2, val1 == val2, or val1 > val2 respectively. +func CompareValue(exchangeRates common.ExchangeRates, val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) (int, error) { + // Short circuit if the fee currency is the same. + if feeCurrency1 == feeCurrency2 { + return val1.Cmp(val2), nil + } + + exchangeRate1, err := getRate(exchangeRates, feeCurrency1) + if err != nil { + return 0, err + } + exchangeRate2, err := getRate(exchangeRates, feeCurrency2) + if err != nil { + return 0, err + } + + // Below code block is basically evaluating this comparison: + // val1 * exchangeRate1.denominator / exchangeRate1.numerator < val2 * exchangeRate2.denominator / exchangeRate2.numerator + // It will transform that comparison to this, to remove having to deal with fractional values. + // val1 * exchangeRate1.denominator * exchangeRate2.numerator < val2 * exchangeRate2.denominator * exchangeRate1.numerator + leftSide := new(big.Int).Mul( + val1, + new(big.Int).Mul( + exchangeRate1.Denom(), + exchangeRate2.Num(), + ), + ) + rightSide := new(big.Int).Mul( + val2, + new(big.Int).Mul( + exchangeRate2.Denom(), + exchangeRate1.Num(), + ), + ) + + return leftSide.Cmp(rightSide), nil +} diff --git a/common/exchange/rates_test.go b/common/exchange/rates_test.go new file mode 100644 index 0000000000..3cda8eb530 --- /dev/null +++ b/common/exchange/rates_test.go @@ -0,0 +1,171 @@ +package exchange + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +var ( + currA = common.HexToAddress("0xA") + currB = common.HexToAddress("0xB") + currX = common.HexToAddress("0xF") + exchangeRates = common.ExchangeRates{ + currA: big.NewRat(47, 100), + currB: big.NewRat(45, 100), + } +) + +func TestCompareFees(t *testing.T) { + type args struct { + val1 *big.Int + feeCurrency1 *common.Address + val2 *big.Int + feeCurrency2 *common.Address + } + tests := []struct { + name string + args args + wantResult int + wantErr bool + }{ + // Native currency + { + name: "Same amount of native currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: nil, + }, + wantResult: 0, + }, { + name: "Different amounts of native currency 1", + args: args{ + val1: big.NewInt(2), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: nil, + }, + wantResult: 1, + }, { + name: "Different amounts of native currency 2", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(5), + feeCurrency2: nil, + }, + wantResult: -1, + }, + // Mixed currency + { + name: "Same amount of mixed currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + wantResult: -1, + }, { + name: "Different amounts of mixed currency 1", + args: args{ + val1: big.NewInt(100), + feeCurrency1: nil, + val2: big.NewInt(47), + feeCurrency2: &currA, + }, + wantResult: 0, + }, { + name: "Different amounts of mixed currency 2", + args: args{ + val1: big.NewInt(45), + feeCurrency1: &currB, + val2: big.NewInt(100), + feeCurrency2: nil, + }, + wantResult: 0, + }, + // Two fee currencies + { + name: "Same amount of same currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: &currA, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + wantResult: 0, + }, { + name: "Different amounts of same currency 1", + args: args{ + val1: big.NewInt(3), + feeCurrency1: &currA, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + wantResult: 1, + }, { + name: "Different amounts of same currency 2", + args: args{ + val1: big.NewInt(1), + feeCurrency1: &currA, + val2: big.NewInt(7), + feeCurrency2: &currA, + }, + wantResult: -1, + }, { + name: "Different amounts of different currencies 1", + args: args{ + val1: big.NewInt(47), + feeCurrency1: &currA, + val2: big.NewInt(45), + feeCurrency2: &currB, + }, + wantResult: 0, + }, { + name: "Different amounts of different currencies 2", + args: args{ + val1: big.NewInt(48), + feeCurrency1: &currA, + val2: big.NewInt(45), + feeCurrency2: &currB, + }, + wantResult: 1, + }, { + name: "Different amounts of different currencies 3", + args: args{ + val1: big.NewInt(47), + feeCurrency1: &currA, + val2: big.NewInt(46), + feeCurrency2: &currB, + }, + wantResult: -1, + }, + // Unregistered fee currency + { + name: "Different amounts of different currencies", + args: args{ + val1: big.NewInt(1), + feeCurrency1: &currA, + val2: big.NewInt(1), + feeCurrency2: &currX, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CompareValue(exchangeRates, tt.args.val1, tt.args.feeCurrency1, tt.args.val2, tt.args.feeCurrency2) + + if tt.wantErr && err == nil { + t.Error("Expected error in CompareValue()") + } + if got != tt.wantResult { + t.Errorf("CompareValue() = %v, want %v", got, tt.wantResult) + } + }) + } +} diff --git a/contracts/fee_currencies.go b/contracts/fee_currencies.go index b3d8e02fc0..417b1b082e 100644 --- a/contracts/fee_currencies.go +++ b/contracts/fee_currencies.go @@ -6,7 +6,6 @@ import ( "math/big" "github.com/ethereum/go-ethereum/accounts/abi" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/contracts/celo/abigen" "github.com/ethereum/go-ethereum/core/vm" @@ -26,35 +25,8 @@ const ( var ( tmpAddress = common.HexToAddress("0xce106a5") - - // ErrNonWhitelistedFeeCurrency is returned if the currency specified to use for the fees - // isn't one of the currencies whitelisted for that purpose. - ErrNonWhitelistedFeeCurrency = errors.New("non-whitelisted fee currency address") ) -// GetBalanceOf returns an account's balance on a given ERC20 currency -func GetBalanceOf(caller bind.ContractCaller, accountOwner common.Address, contractAddress common.Address) (result *big.Int, err error) { - token, err := abigen.NewFeeCurrencyCaller(contractAddress, caller) - if err != nil { - return nil, fmt.Errorf("failed to access FeeCurrency: %w", err) - } - - balance, err := token.BalanceOf(&bind.CallOpts{}, accountOwner) - if err != nil { - return nil, err - } - - return balance, nil -} - -func ConvertGoldToCurrency(exchangeRates map[common.Address]*big.Rat, feeCurrency *common.Address, goldAmount *big.Int) (*big.Int, error) { - exchangeRate, ok := exchangeRates[*feeCurrency] - if !ok { - return nil, ErrNonWhitelistedFeeCurrency - } - return new(big.Int).Div(new(big.Int).Mul(goldAmount, exchangeRate.Num()), exchangeRate.Denom()), nil -} - // Debits transaction fees from the transaction sender and stores them in the temporary address func DebitFees(evm *vm.EVM, feeCurrency *common.Address, address common.Address, amount *big.Int) error { if amount.Cmp(big.NewInt(0)) == 0 { diff --git a/core/blockchain_celo_test.go b/core/blockchain_celo_test.go index ef7bba02da..ccfd5c4587 100644 --- a/core/blockchain_celo_test.go +++ b/core/blockchain_celo_test.go @@ -21,8 +21,8 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/exchange" "github.com/ethereum/go-ethereum/consensus/ethash" - fee_currencies "github.com/ethereum/go-ethereum/contracts" contracts "github.com/ethereum/go-ethereum/contracts/celo" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" @@ -101,15 +101,15 @@ func testNativeTransferWithFeeCurrency(t *testing.T, scheme string) { state, _ := chain.State() backend := CeloBackend{ - chainConfig: chain.chainConfig, - state: state, + ChainConfig: chain.chainConfig, + State: state, } - exchangeRates, err := getExchangeRates(&backend) + exchangeRates, err := backend.GetExchangeRates() if err != nil { t.Fatal("could not get exchange rates") } - baseFeeInFeeCurrency, _ := fee_currencies.ConvertGoldToCurrency(exchangeRates, &FeeCurrencyAddr, block.BaseFee()) - actual, _ := fee_currencies.GetBalanceOf(&backend, block.Coinbase(), FeeCurrencyAddr) + baseFeeInFeeCurrency, _ := exchange.ConvertGoldToCurrency(exchangeRates, &FeeCurrencyAddr, block.BaseFee()) + actual, _ := backend.GetBalanceERC20(block.Coinbase(), FeeCurrencyAddr) // 3: Ensure that miner received only the tx's tip. expected := new(big.Int).SetUint64(block.GasUsed() * block.Transactions()[0].GasTipCap().Uint64()) @@ -118,7 +118,7 @@ func testNativeTransferWithFeeCurrency(t *testing.T, scheme string) { } // 4: Ensure the tx sender paid for the gasUsed * (tip + block baseFee). - actual, _ = fee_currencies.GetBalanceOf(&backend, addr1, FeeCurrencyAddr) + actual, _ = backend.GetBalanceERC20(addr1, FeeCurrencyAddr) actual = new(big.Int).Sub(funds, actual) expected = new(big.Int).SetUint64(block.GasUsed() * (block.Transactions()[0].GasTipCap().Uint64() + baseFeeInFeeCurrency.Uint64())) if actual.Cmp(expected) != 0 { @@ -126,7 +126,7 @@ func testNativeTransferWithFeeCurrency(t *testing.T, scheme string) { } // 5: Check that base fee has been moved to the fee handler. - actual, _ = fee_currencies.GetBalanceOf(&backend, contracts.FeeHandlerAddress, FeeCurrencyAddr) + actual, _ = backend.GetBalanceERC20(contracts.FeeHandlerAddress, FeeCurrencyAddr) expected = new(big.Int).SetUint64(block.GasUsed() * baseFeeInFeeCurrency.Uint64()) if actual.Cmp(expected) != 0 { t.Fatalf("fee handler balance incorrect: expected %d, got %d", expected, actual) diff --git a/core/celo_backend.go b/core/celo_backend.go index 2ed050a0f5..9b416c90bb 100644 --- a/core/celo_backend.go +++ b/core/celo_backend.go @@ -2,25 +2,30 @@ package core import ( "context" + "fmt" "math/big" "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" + contracts "github.com/ethereum/go-ethereum/contracts/celo" + "github.com/ethereum/go-ethereum/contracts/celo/abigen" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" ) // CeloBackend provide a partial ContractBackend implementation, so that we can // access core contracts during block processing. type CeloBackend struct { - chainConfig *params.ChainConfig - state vm.StateDB + ChainConfig *params.ChainConfig + State vm.StateDB } // ContractCaller implementation func (b *CeloBackend) CodeAt(ctx context.Context, contract common.Address, blockNumber *big.Int) ([]byte, error) { - return b.state.GetCode(contract), nil + return b.State.GetCode(contract), nil } func (b *CeloBackend) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { @@ -43,9 +48,69 @@ func (b *CeloBackend) CallContract(ctx context.Context, call ethereum.CallMsg, b txCtx := vm.TxContext{} vmConfig := vm.Config{} - readOnlyStateDB := ReadOnlyStateDB{StateDB: b.state} - evm := vm.NewEVM(blockCtx, txCtx, &readOnlyStateDB, b.chainConfig, vmConfig) + readOnlyStateDB := ReadOnlyStateDB{StateDB: b.State} + evm := vm.NewEVM(blockCtx, txCtx, &readOnlyStateDB, b.ChainConfig, vmConfig) ret, _, err := evm.StaticCall(vm.AccountRef(evm.Origin), *call.To, call.Data, call.Gas) return ret, err } + +// GetBalanceERC20 returns an account's balance on a given ERC20 currency +func (b *CeloBackend) GetBalanceERC20(accountOwner common.Address, contractAddress common.Address) (result *big.Int, err error) { + token, err := abigen.NewFeeCurrencyCaller(contractAddress, b) + if err != nil { + return nil, fmt.Errorf("failed to access FeeCurrency: %w", err) + } + + balance, err := token.BalanceOf(&bind.CallOpts{}, accountOwner) + if err != nil { + return nil, err + } + + return balance, nil +} + +// GetFeeBalance returns the account's balance from the specified feeCurrency +// (if feeCurrency is nil or ZeroAddress, native currency balance is returned). +func (b *CeloBackend) GetFeeBalance(account common.Address, feeCurrency *common.Address) *big.Int { + if feeCurrency == nil || *feeCurrency == common.ZeroAddress { + return b.State.GetBalance(account) + } + balance, err := b.GetBalanceERC20(account, *feeCurrency) + if err != nil { + log.Error("Error while trying to get ERC20 balance:", "cause", err, "contract", feeCurrency.Hex(), "account", account.Hex()) + } + return balance +} + +// GetExchangeRates returns the exchange rates for all gas currencies from CELO +func (b *CeloBackend) GetExchangeRates() (common.ExchangeRates, error) { + exchangeRates := map[common.Address]*big.Rat{} + whitelist, err := abigen.NewFeeCurrencyWhitelistCaller(contracts.FeeCurrencyWhitelistAddress, b) + if err != nil { + return exchangeRates, fmt.Errorf("Failed to access FeeCurrencyWhitelist: %w", err) + } + oracle, err := abigen.NewSortedOraclesCaller(contracts.SortedOraclesAddress, b) + if err != nil { + return exchangeRates, fmt.Errorf("Failed to access SortedOracle: %w", err) + } + + whitelistedTokens, err := whitelist.GetWhitelist(&bind.CallOpts{}) + if err != nil { + return exchangeRates, fmt.Errorf("Failed to get whitelisted tokens: %w", err) + } + for _, tokenAddress := range whitelistedTokens { + numerator, denominator, err := oracle.MedianRate(&bind.CallOpts{}, tokenAddress) + if err != nil { + log.Error("Failed to get medianRate for gas currency!", "err", err, "tokenAddress", tokenAddress.Hex()) + continue + } + if denominator.Sign() == 0 { + log.Error("Bad exchange rate for fee currency", "tokenAddress", tokenAddress.Hex(), "numerator", numerator, "denominator", denominator) + continue + } + exchangeRates[tokenAddress] = big.NewRat(numerator.Int64(), denominator.Int64()) + } + + return exchangeRates, nil +} diff --git a/core/celo_evm.go b/core/celo_evm.go index 7b78f1e75a..e9fc051d6e 100644 --- a/core/celo_evm.go +++ b/core/celo_evm.go @@ -15,7 +15,7 @@ import ( ) // Returns the exchange rates for all gas currencies from CELO -func getExchangeRates(caller *CeloBackend) (map[common.Address]*big.Rat, error) { +func GetExchangeRates(caller bind.ContractCaller) (common.ExchangeRates, error) { exchangeRates := map[common.Address]*big.Rat{} whitelist, err := abigen.NewFeeCurrencyWhitelistCaller(contracts.FeeCurrencyWhitelistAddress, caller) if err != nil { @@ -55,7 +55,7 @@ func setCeloFieldsInBlockContext(blockContext *vm.BlockContext, header *types.He // Add fee currency exchange rates var err error - blockContext.ExchangeRates, err = getExchangeRates(caller) + blockContext.ExchangeRates, err = caller.GetExchangeRates() if err != nil { log.Error("Error fetching exchange rates!", "err", err) } diff --git a/core/state_processor.go b/core/state_processor.go index e686942b29..dbb0921eab 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -22,9 +22,9 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/exchange" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus/misc" - fee_currencies "github.com/ethereum/go-ethereum/contracts" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" @@ -156,7 +156,7 @@ func applyTransaction(msg *Message, config *params.ChainConfig, gp *GasPool, sta if tx.Type() == types.CeloDynamicFeeTxType { alternativeBaseFee := evm.Context.BaseFee if msg.FeeCurrency != nil { - alternativeBaseFee, err = fee_currencies.ConvertGoldToCurrency(evm.Context.ExchangeRates, msg.FeeCurrency, evm.Context.BaseFee) + alternativeBaseFee, err = exchange.ConvertGoldToCurrency(evm.Context.ExchangeRates, msg.FeeCurrency, evm.Context.BaseFee) if err != nil { return nil, err } diff --git a/core/state_transition.go b/core/state_transition.go index c055dad624..020f09cb14 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -23,6 +23,7 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/exchange" cmath "github.com/ethereum/go-ethereum/common/math" fee_currencies "github.com/ethereum/go-ethereum/contracts" contracts "github.com/ethereum/go-ethereum/contracts/celo" @@ -209,7 +210,7 @@ func TransactionToMessage(tx *types.Transaction, s types.Signer, baseFee *big.In if baseFee != nil { if msg.FeeCurrency != nil { var err error - baseFee, err = fee_currencies.ConvertGoldToCurrency(exchangeRates, msg.FeeCurrency, baseFee) + baseFee, err = exchange.ConvertGoldToCurrency(exchangeRates, msg.FeeCurrency, baseFee) if err != nil { return nil, err } @@ -293,7 +294,7 @@ func (st *StateTransition) buyGas() error { // L1 data fee needs to be converted in fee currency if st.msg.FeeCurrency != nil && l1Cost != nil { // Existence of the fee currency has been checked in `preCheck` - l1Cost, _ = fee_currencies.ConvertGoldToCurrency(st.evm.Context.ExchangeRates, st.msg.FeeCurrency, l1Cost) + l1Cost, _ = exchange.ConvertGoldToCurrency(st.evm.Context.ExchangeRates, st.msg.FeeCurrency, l1Cost) } } if l1Cost != nil { @@ -345,10 +346,10 @@ func (st *StateTransition) canPayFee(checkAmount *big.Int) error { } } else { backend := &CeloBackend{ - chainConfig: st.evm.ChainConfig(), - state: st.state, + ChainConfig: st.evm.ChainConfig(), + State: st.state, } - balance, err := fee_currencies.GetBalanceOf(backend, st.msg.From, *st.msg.FeeCurrency) + balance, err := backend.GetBalanceERC20(st.msg.From, *st.msg.FeeCurrency) if err != nil { return err } @@ -440,10 +441,10 @@ func (st *StateTransition) preCheck() error { if !st.evm.ChainConfig().IsCel2(st.evm.Context.Time) { return ErrCel2NotEnabled } else { - isWhiteListed := st.evm.Context.IsCurrencyWhitelisted(msg.FeeCurrency) + isWhiteListed := common.IsCurrencyWhitelisted(st.evm.Context.ExchangeRates, msg.FeeCurrency) if !isWhiteListed { log.Trace("fee currency not whitelisted", "fee currency address", msg.FeeCurrency) - return fee_currencies.ErrNonWhitelistedFeeCurrency + return exchange.ErrNonWhitelistedFeeCurrency } } } @@ -703,7 +704,7 @@ func (st *StateTransition) distributeTxFees() error { } } else { if l1Cost != nil { - l1Cost, _ = fee_currencies.ConvertGoldToCurrency(st.evm.Context.ExchangeRates, feeCurrency, l1Cost) + l1Cost, _ = exchange.ConvertGoldToCurrency(st.evm.Context.ExchangeRates, feeCurrency, l1Cost) } if err := fee_currencies.CreditFees(st.evm, feeCurrency, from, st.evm.Context.Coinbase, feeHandlerAddress, params.OptimismL1FeeRecipient, refund, tipTxFee, baseTxFee, l1Cost); err != nil { log.Error("Error crediting", "from", from, "coinbase", st.evm.Context.Coinbase, "feeHandler", feeHandlerAddress) @@ -725,7 +726,7 @@ func (st *StateTransition) calculateBaseFee() *big.Int { if st.msg.FeeCurrency != nil { // Existence of the fee currency has been checked in `preCheck` - baseFee, _ = fee_currencies.ConvertGoldToCurrency(st.evm.Context.ExchangeRates, st.msg.FeeCurrency, baseFee) + baseFee, _ = exchange.ConvertGoldToCurrency(st.evm.Context.ExchangeRates, st.msg.FeeCurrency, baseFee) } return baseFee diff --git a/core/txpool/blobpool/blobpool.go b/core/txpool/blobpool/blobpool.go index 9966d5528a..99a5bc81a5 100644 --- a/core/txpool/blobpool/blobpool.go +++ b/core/txpool/blobpool/blobpool.go @@ -316,8 +316,9 @@ type BlobPool struct { lock sync.RWMutex // Mutex protecting the pool during reorg handling - // Celo - feeCurrencyValidator txpool.FeeCurrencyValidator + // Celo specific + celoBackend *core.CeloBackend // For fee currency balances & exchange rate calculation + currentRates common.ExchangeRates // current exchange rates for fee currencies } // New creates a new blob transaction pool to gather, sort and filter inbound @@ -334,8 +335,6 @@ func New(config Config, chain BlockChain) *BlobPool { lookup: make(map[common.Hash]uint64), index: make(map[common.Address][]*blobTxMeta), spent: make(map[common.Address]*uint256.Int), - - feeCurrencyValidator: txpool.NewFeeCurrencyValidator(), } } @@ -375,6 +374,7 @@ func (p *BlobPool) Init(gasTip *big.Int, head *types.Header, reserve txpool.Addr return err } p.head, p.state = head, state + p.recreateCeloProperties() // Index all transactions on disk and delete anything inprocessable var fails []uint64 @@ -761,6 +761,7 @@ func (p *BlobPool) Reset(oldHead, newHead *types.Header) { } p.head = newHead p.state = statedb + p.recreateCeloProperties() // Run the reorg between the old and new head and figure out which accounts // need to be rechecked and which transactions need to be readded @@ -1050,8 +1051,7 @@ func (p *BlobPool) validateTx(tx *types.Transaction) error { MaxSize: txMaxSize, MinTip: p.gasTip.ToBig(), } - var fcv txpool.FeeCurrencyValidator = nil // TODO: create with proper value - if err := txpool.CeloValidateTransaction(tx, p.head, p.signer, baseOpts, p.state, fcv); err != nil { + if err := txpool.CeloValidateTransaction(tx, p.head, p.signer, baseOpts, p.currentRates); err != nil { return err } // Ensure the transaction adheres to the stateful pool filters (nonce, balance) @@ -1084,15 +1084,12 @@ func (p *BlobPool) validateTx(tx *types.Transaction) error { } return nil }, + ExistingBalance: func(addr common.Address, feeCurrency *common.Address) *big.Int { + return p.celoBackend.GetFeeBalance(addr, feeCurrency) + }, } - // Adapt to celo validation options - celoOpts := &txpool.CeloValidationOptionsWithState{ - ValidationOptionsWithState: *stateOpts, - FeeCurrencyValidator: p.feeCurrencyValidator, - } - - if err := txpool.ValidateTransactionWithState(tx, p.signer, celoOpts); err != nil { + if err := txpool.ValidateTransactionWithState(tx, p.signer, stateOpts); err != nil { return err } // If the transaction replaces an existing one, ensure that price bumps are diff --git a/core/txpool/blobpool/celo_blobpool.go b/core/txpool/blobpool/celo_blobpool.go new file mode 100644 index 0000000000..85e7572dbd --- /dev/null +++ b/core/txpool/blobpool/celo_blobpool.go @@ -0,0 +1,18 @@ +package blobpool + +import ( + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/log" +) + +func (pool *BlobPool) recreateCeloProperties() { + pool.celoBackend = &core.CeloBackend{ + ChainConfig: pool.chain.Config(), + State: pool.state, + } + currentRates, err := pool.celoBackend.GetExchangeRates() + if err != nil { + log.Error("Error trying to get exchange rates in txpool.", "cause", err) + } + pool.currentRates = currentRates +} diff --git a/core/txpool/celo_validation.go b/core/txpool/celo_validation.go index 5c0f7b27ac..98aec4c749 100644 --- a/core/txpool/celo_validation.go +++ b/core/txpool/celo_validation.go @@ -5,41 +5,12 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/params" ) var NonWhitelistedFeeCurrencyError = errors.New("Fee currency given is not whitelisted at current block") -// FeeCurrencyValidator validates currency whitelisted status at the specified -// block number. -type FeeCurrencyValidator interface { - IsWhitelisted(st *state.StateDB, feeCurrency *common.Address) bool - // Balance returns the feeCurrency balance of the address specified, in the given state. - // If feeCurrency is nil, the native currency balance has to be returned. - Balance(st *state.StateDB, address common.Address, feeCurrency *common.Address) *big.Int -} - -func NewFeeCurrencyValidator() FeeCurrencyValidator { - return &feeval{} -} - -type feeval struct { -} - -func (f *feeval) IsWhitelisted(st *state.StateDB, feeCurrency *common.Address) bool { - // TODO: implement proper validation for all currencies - // Hardcoded for the moment - return true - //return feeCurrency == nil -} - -func (f *feeval) Balance(st *state.StateDB, address common.Address, feeCurrency *common.Address) *big.Int { - // TODO: implement proper balance retrieval for fee currencies - return st.GetBalance(address) -} - // AcceptSet is a set of accepted transaction types for a transaction subpool. type AcceptSet = map[uint8]struct{} @@ -77,34 +48,13 @@ func (cvo *CeloValidationOptions) Accepts(txType uint8) bool { // This check is public to allow different transaction pools to check the basic // rules without duplicating code and running the risk of missed updates. func CeloValidateTransaction(tx *types.Transaction, head *types.Header, - signer types.Signer, opts *CeloValidationOptions, st *state.StateDB, fcv FeeCurrencyValidator) error { + signer types.Signer, opts *CeloValidationOptions, rates common.ExchangeRates) error { if err := ValidateTransaction(tx, head, signer, opts); err != nil { return err } - if IsFeeCurrencyTx(tx) { - if !fcv.IsWhitelisted(st, tx.FeeCurrency()) { - return NonWhitelistedFeeCurrencyError - } + if !common.IsCurrencyWhitelisted(rates, tx.FeeCurrency()) { + return NonWhitelistedFeeCurrencyError } - return nil -} - -// IsFeeCurrencyTxType returns true if and only if the transaction type -// given can handle custom gas fee currencies. -func IsFeeCurrencyTxType(t uint8) bool { - return t == types.CeloDynamicFeeTxType -} -// IsFeeCurrencyTx returns true if this transaction specifies a custom -// gas fee currency. -func IsFeeCurrencyTx(tx *types.Transaction) bool { - return IsFeeCurrencyTxType(tx.Type()) && tx.FeeCurrency() != nil -} - -// See: txpool.ValidationOptionsWithState -type CeloValidationOptionsWithState struct { - ValidationOptionsWithState - - // FeeCurrencyValidator allows for balance check of non native fee currencies. - FeeCurrencyValidator FeeCurrencyValidator + return nil } diff --git a/core/txpool/legacypool/celo_legacypool.go b/core/txpool/legacypool/celo_legacypool.go new file mode 100644 index 0000000000..a89dc75b70 --- /dev/null +++ b/core/txpool/legacypool/celo_legacypool.go @@ -0,0 +1,43 @@ +package legacypool + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" +) + +// filter Filters transactions from the given list, according to remaining balance (per currency) +// and gasLimit. Returns drops and invalid txs. +func (pool *LegacyPool) filter(list *list, addr common.Address, gasLimit uint64) (types.Transactions, types.Transactions) { + // CELO: drop all transactions that no longer have a whitelisted currency + dropsWhitelist, invalidsWhitelist := list.FilterWhitelisted(pool.currentRates) + // Check from which currencies we need to get balances + currenciesInList := list.FeeCurrencies() + drops, invalids := list.Filter(pool.getBalances(addr, currenciesInList), gasLimit) + totalDrops := append(dropsWhitelist, drops...) + totalInvalids := append(invalidsWhitelist, invalids...) + return totalDrops, totalInvalids +} + +func (pool *LegacyPool) getBalances(address common.Address, currencies []common.Address) map[common.Address]*big.Int { + balances := make(map[common.Address]*big.Int, len(currencies)) + for _, curr := range currencies { + balances[curr] = pool.celoBackend.GetFeeBalance(address, &curr) + } + return balances +} + +func (pool *LegacyPool) recreateCeloProperties() { + pool.celoBackend = &core.CeloBackend{ + ChainConfig: pool.chainconfig, + State: pool.currentState, + } + currentRates, err := pool.celoBackend.GetExchangeRates() + if err != nil { + log.Error("Error trying to get exchange rates in txpool.", "cause", err) + } + pool.currentRates = currentRates +} diff --git a/core/txpool/legacypool/celo_list.go b/core/txpool/legacypool/celo_list.go new file mode 100644 index 0000000000..251849a54a --- /dev/null +++ b/core/txpool/legacypool/celo_list.go @@ -0,0 +1,158 @@ +package legacypool + +import ( + "math" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/exchange" + "github.com/ethereum/go-ethereum/core/types" +) + +func (l *list) FilterWhitelisted(rates common.ExchangeRates) (types.Transactions, types.Transactions) { + removed := l.txs.Filter(func(tx *types.Transaction) bool { + return !common.IsCurrencyWhitelisted(rates, tx.FeeCurrency()) + }) + + if len(removed) == 0 { + return nil, nil + } + + invalid := l.dropInvalidsAfterRemovalAndReheap(removed) + l.subTotalCost(removed) + l.subTotalCost(invalid) + return removed, invalid +} + +func (l *list) dropInvalidsAfterRemovalAndReheap(removed types.Transactions) types.Transactions { + var invalids types.Transactions + // If the list was strict, filter anything above the lowest nonce + // Note that the 'invalid' txs have no intersection with the 'removed' txs + if l.strict { + lowest := uint64(math.MaxUint64) + for _, tx := range removed { + if nonce := tx.Nonce(); lowest > nonce { + lowest = nonce + } + } + invalids = l.txs.filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) + } + l.txs.reheap() + return invalids +} + +func (l *list) FeeCurrencies() []common.Address { + currencySet := make(map[common.Address]interface{}) + for _, tx := range l.txs.items { + // native currency (nil) represented as Zero address + currencySet[getCurrencyKey(tx.FeeCurrency())] = struct{}{} + } + currencies := make([]common.Address, 0, len(currencySet)) + for curr := range currencySet { + currencies = append(currencies, curr) + } + return currencies +} + +func getCurrencyKey(feeCurrency *common.Address) common.Address { + if feeCurrency == nil { + return common.ZeroAddress + } + return *feeCurrency +} + +func (c *list) totalCostVar(feeCurrency *common.Address) *big.Int { + key := getCurrencyKey(feeCurrency) + if tc, ok := c.totalCost[key]; ok { + return tc + } + newTc := big.NewInt(0) + c.totalCost[key] = newTc + return newTc +} + +func (c *list) TotalCostFor(feeCurrency *common.Address) *big.Int { + if tc, ok := c.totalCost[getCurrencyKey(feeCurrency)]; ok { + return new(big.Int).Set(tc) + } + return big.NewInt(0) +} + +func (c *list) costCapFor(feeCurrency *common.Address) *big.Int { + if tc, ok := c.costCap[getCurrencyKey(feeCurrency)]; ok { + return tc + } + return big.NewInt(0) +} + +func (c *list) updateCostCapFor(feeCurrency *common.Address, possibleCap *big.Int) { + currentCap := c.costCapFor(feeCurrency) + if possibleCap.Cmp(currentCap) > 0 { + c.costCap[getCurrencyKey(feeCurrency)] = possibleCap + } +} + +func (c *list) costCapsLowerThan(costLimits map[common.Address]*big.Int) bool { + for curr, cap := range c.costCap { + limit, ok := costLimits[curr] + if !ok || limit == nil { + // If there's no limit for the currency we can assume the limit is zero + return cap.Cmp(common.Big0) == 0 + } + if cap.Cmp(limit) > 0 { + return false + } + } + return true +} + +func (c *list) setCapsTo(caps map[common.Address]*big.Int) { + c.costCap = make(map[common.Address]*big.Int) + for curr, cap := range caps { + if cap == nil || cap.Cmp(common.Big0) == 0 { + c.costCap[curr] = big.NewInt(0) + } else { + c.costCap[curr] = new(big.Int).Set(cap) + } + } +} + +type TxComparator func(a, b *types.Transaction, baseFee *big.Int) int + +func (p *pricedList) compareWithRates(a, b *types.Transaction, goldBaseFee *big.Int) int { + if goldBaseFee != nil { + tipA := effectiveTip(p.rates, goldBaseFee, a) + tipB := effectiveTip(p.rates, goldBaseFee, b) + result, _ := exchange.CompareValue(p.rates, tipA, a.FeeCurrency(), tipB, b.FeeCurrency()) + return result + } + + // Compare fee caps if baseFee is not specified or effective tips are equal + feeA := a.GasFeeCap() + feeB := b.GasFeeCap() + c, _ := exchange.CompareValue(p.rates, feeA, a.FeeCurrency(), feeB, b.FeeCurrency()) + if c != 0 { + return c + } + + // Compare tips if effective tips and fee caps are equal + tipCapA := a.GasTipCap() + tipCapB := b.GasTipCap() + result, _ := exchange.CompareValue(p.rates, tipCapA, a.FeeCurrency(), tipCapB, b.FeeCurrency()) + return result +} + +func baseFeeInCurrency(rates common.ExchangeRates, goldBaseFee *big.Int, feeCurrency *common.Address) *big.Int { + // can ignore the whitelist error since txs with non whitelisted currencies + // are pruned + baseFee, _ := exchange.ConvertGoldToCurrency(rates, feeCurrency, goldBaseFee) + return baseFee +} + +func effectiveTip(rates common.ExchangeRates, goldBaseFee *big.Int, tx *types.Transaction) *big.Int { + if tx.FeeCurrency() == nil { + return tx.EffectiveGasTipValue(goldBaseFee) + } + baseFee := baseFeeInCurrency(rates, goldBaseFee, tx.FeeCurrency()) + return tx.EffectiveGasTipValue(baseFee) +} diff --git a/core/txpool/legacypool/celo_list_test.go b/core/txpool/legacypool/celo_list_test.go new file mode 100644 index 0000000000..ce0a60dda4 --- /dev/null +++ b/core/txpool/legacypool/celo_list_test.go @@ -0,0 +1,198 @@ +package legacypool + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/stretchr/testify/assert" +) + +func txC(nonce int, feeCap int, tipCap int, gas int, currency *common.Address) *types.Transaction { + return types.NewTx(&types.CeloDynamicFeeTx{ + GasFeeCap: big.NewInt(int64(feeCap)), + GasTipCap: big.NewInt(int64(tipCap)), + FeeCurrency: currency, + Gas: uint64(gas), + Nonce: uint64(nonce), + }) +} + +func TestListFeeCost(t *testing.T) { + curr1 := common.HexToAddress("0002") + curr2 := common.HexToAddress("0004") + curr3 := common.HexToAddress("0006") + rates := common.ExchangeRates{ + curr1: big.NewRat(2, 1), + curr2: big.NewRat(4, 1), + curr3: big.NewRat(6, 1), + } + // Insert the transactions in a random order + list := newList(false) + + list.Add(txC(7, 1, 1, 10000, &curr1), DefaultConfig.PriceBump, nil, rates) + assert.Equal(t, int64(10000), list.TotalCostFor(&curr1).Int64()) + + toBeRemoved := txC(8, 2, 1, 15000, &curr2) + list.Add(toBeRemoved, DefaultConfig.PriceBump, nil, rates) + assert.Equal(t, int64(30000), list.TotalCostFor(&curr2).Int64()) + assert.Equal(t, int64(10000), list.TotalCostFor(&curr1).Int64()) + + list.Add(txC(9, 3, 2, 5000, &curr3), DefaultConfig.PriceBump, nil, rates) + assert.Equal(t, int64(15000), list.TotalCostFor(&curr3).Int64()) + assert.Equal(t, int64(30000), list.TotalCostFor(&curr2).Int64()) + assert.Equal(t, int64(10000), list.TotalCostFor(&curr1).Int64()) + + // Add another tx from curr1, check it adds properly + list.Add(txC(10, 1, 1, 10000, &curr1), DefaultConfig.PriceBump, nil, rates) + assert.Equal(t, int64(15000), list.TotalCostFor(&curr3).Int64()) + assert.Equal(t, int64(30000), list.TotalCostFor(&curr2).Int64()) + assert.Equal(t, int64(20000), list.TotalCostFor(&curr1).Int64()) + + // Remove a tx from curr2, check it subtracts properly + removed, _ := list.Remove(toBeRemoved) + assert.True(t, removed) + + assert.Equal(t, int64(15000), list.TotalCostFor(&curr3).Int64()) + assert.Equal(t, int64(0), list.TotalCostFor(&curr2).Int64()) + assert.Equal(t, int64(20000), list.TotalCostFor(&curr1).Int64()) +} + +func TestFilterWhitelisted(t *testing.T) { + curr1 := common.HexToAddress("0002") + curr2 := common.HexToAddress("0004") + curr3 := common.HexToAddress("0006") + rates := common.ExchangeRates{ + curr1: big.NewRat(2, 1), + curr2: big.NewRat(4, 1), + curr3: big.NewRat(6, 1), + } + + list := newList(false) + list.Add(txC(7, 1, 1, 10000, &curr1), DefaultConfig.PriceBump, nil, rates) + toBeRemoved := txC(8, 2, 1, 15000, &curr2) + list.Add(toBeRemoved, DefaultConfig.PriceBump, nil, rates) + list.Add(txC(9, 1, 1, 10000, &curr1), DefaultConfig.PriceBump, nil, rates) + assert.Equal(t, int64(30000), list.TotalCostFor(&curr2).Int64()) + + removed, invalids := list.FilterWhitelisted(common.ExchangeRates{curr1: nil, curr3: nil}) + assert.Len(t, removed, 1) + assert.Len(t, invalids, 0) + assert.Equal(t, removed[0], toBeRemoved) + assert.Equal(t, int64(0), list.TotalCostFor(&curr2).Int64()) +} + +func TestFilterWhitelistedStrict(t *testing.T) { + curr1 := common.HexToAddress("0002") + curr2 := common.HexToAddress("0004") + curr3 := common.HexToAddress("0006") + rates := common.ExchangeRates{ + curr1: big.NewRat(2, 1), + curr2: big.NewRat(4, 1), + curr3: big.NewRat(6, 1), + } + + list := newList(true) + list.Add(txC(7, 1, 1, 10000, &curr1), DefaultConfig.PriceBump, nil, rates) + toBeRemoved := txC(8, 2, 1, 15000, &curr2) + list.Add(toBeRemoved, DefaultConfig.PriceBump, nil, rates) + toBeInvalid := txC(9, 1, 1, 10000, &curr3) + list.Add(toBeInvalid, DefaultConfig.PriceBump, nil, rates) + + removed, invalids := list.FilterWhitelisted(common.ExchangeRates{curr1: nil, curr3: nil}) + assert.Len(t, removed, 1) + assert.Len(t, invalids, 1) + assert.Equal(t, removed[0], toBeRemoved) + assert.Equal(t, invalids[0], toBeInvalid) + assert.Equal(t, int64(0), list.TotalCostFor(&curr2).Int64()) + assert.Equal(t, int64(0), list.TotalCostFor(&curr3).Int64()) + assert.Equal(t, int64(10000), list.TotalCostFor(&curr1).Int64()) +} + +func TestFilterBalance(t *testing.T) { + curr1 := common.HexToAddress("0002") + curr2 := common.HexToAddress("0004") + curr3 := common.HexToAddress("0006") + rates := common.ExchangeRates{ + curr1: big.NewRat(2, 1), + curr2: big.NewRat(4, 1), + curr3: big.NewRat(6, 1), + } + + list := newList(false) + // each tx costs 10000 in each currency + list.Add(txC(7, 1, 1, 10000, &curr1), DefaultConfig.PriceBump, nil, rates) + toBeRemoved := txC(8, 1, 1, 10000, &curr2) + list.Add(toBeRemoved, DefaultConfig.PriceBump, nil, rates) + list.Add(txC(9, 1, 1, 10000, &curr3), DefaultConfig.PriceBump, nil, rates) + + removed, invalids := list.Filter(map[common.Address]*big.Int{ + curr1: big.NewInt(10000), + curr2: big.NewInt(9999), + curr3: big.NewInt(10000), + }, 15000) + assert.Len(t, removed, 1) + assert.Len(t, invalids, 0) + assert.Equal(t, removed[0], toBeRemoved) + assert.Equal(t, int64(0), list.TotalCostFor(&curr2).Int64()) +} + +func TestFilterBalanceStrict(t *testing.T) { + curr1 := common.HexToAddress("0002") + curr2 := common.HexToAddress("0004") + curr3 := common.HexToAddress("0006") + rates := common.ExchangeRates{ + curr1: big.NewRat(2, 1), + curr2: big.NewRat(4, 1), + curr3: big.NewRat(6, 1), + } + + list := newList(true) + // each tx costs 10000 in each currency + list.Add(txC(7, 1, 1, 10000, &curr1), DefaultConfig.PriceBump, nil, rates) + toBeRemoved := txC(8, 1, 1, 10000, &curr2) + list.Add(toBeRemoved, DefaultConfig.PriceBump, nil, rates) + toBeInvalid := txC(9, 1, 1, 10000, &curr3) + list.Add(toBeInvalid, DefaultConfig.PriceBump, nil, rates) + + removed, invalids := list.Filter(map[common.Address]*big.Int{ + curr1: big.NewInt(10001), + curr2: big.NewInt(9999), + curr3: big.NewInt(10001), + }, 15000) + assert.Len(t, removed, 1) + assert.Len(t, invalids, 1) + assert.Equal(t, removed[0], toBeRemoved) + assert.Equal(t, invalids[0], toBeInvalid) + assert.Equal(t, int64(0), list.TotalCostFor(&curr2).Int64()) + assert.Equal(t, int64(0), list.TotalCostFor(&curr3).Int64()) +} + +func TestFilterBalanceGasLimit(t *testing.T) { + curr1 := common.HexToAddress("0002") + curr2 := common.HexToAddress("0004") + curr3 := common.HexToAddress("0006") + rates := common.ExchangeRates{ + curr1: big.NewRat(2, 1), + curr2: big.NewRat(4, 1), + curr3: big.NewRat(6, 1), + } + + list := newList(false) + // each tx costs 10000 in each currency + list.Add(txC(7, 1, 1, 10000, &curr1), DefaultConfig.PriceBump, nil, rates) + toBeRemoved := txC(8, 1, 1, 10001, &curr2) + list.Add(toBeRemoved, DefaultConfig.PriceBump, nil, rates) + list.Add(txC(9, 1, 1, 10000, &curr3), DefaultConfig.PriceBump, nil, rates) + + removed, invalids := list.Filter(map[common.Address]*big.Int{ + curr1: big.NewInt(20000), + curr2: big.NewInt(20000), + curr3: big.NewInt(20000), + }, 10000) + assert.Len(t, removed, 1) + assert.Len(t, invalids, 0) + assert.Equal(t, removed[0], toBeRemoved) + assert.Equal(t, int64(0), list.TotalCostFor(&curr2).Int64()) +} diff --git a/core/txpool/legacypool/legacypool.go b/core/txpool/legacypool/legacypool.go index c3ddd22d2c..9c07feba72 100644 --- a/core/txpool/legacypool/legacypool.go +++ b/core/txpool/legacypool/legacypool.go @@ -241,8 +241,9 @@ type LegacyPool struct { l1CostFn txpool.L1CostFunc // To apply L1 costs as rollup, optional field, may be nil. - // Celo - feeCurrencyValidator txpool.FeeCurrencyValidator + // Celo specific + celoBackend *core.CeloBackend // For fee currency balances & exchange rate calculation + currentRates common.ExchangeRates // current exchange rates for fee currencies } type txpoolResetRequest struct { @@ -271,9 +272,6 @@ func New(config Config, chain BlockChain) *LegacyPool { reorgDoneCh: make(chan chan struct{}), reorgShutdownCh: make(chan struct{}), initDoneCh: make(chan struct{}), - - // CELO fields - feeCurrencyValidator: txpool.NewFeeCurrencyValidator(), } pool.locals = newAccountSet(pool.signer) for _, addr := range config.Locals { @@ -323,6 +321,7 @@ func (pool *LegacyPool) Init(gasTip *big.Int, head *types.Header, reserve txpool pool.currentHead.Store(head) pool.currentState = statedb pool.pendingNonces = newNoncer(statedb) + pool.recreateCeloProperties() // Start the reorg loop early, so it can handle requests generated during // journal loading. @@ -637,7 +636,7 @@ func (pool *LegacyPool) validateTxBasics(tx *types.Transaction, local bool) erro if local { opts.MinTip = new(big.Int) } - if err := txpool.CeloValidateTransaction(tx, pool.currentHead.Load(), pool.signer, opts, pool.currentState, pool.feeCurrencyValidator); err != nil { + if err := txpool.CeloValidateTransaction(tx, pool.currentHead.Load(), pool.signer, opts, pool.currentRates); err != nil { return err } return nil @@ -662,7 +661,7 @@ func (pool *LegacyPool) validateTx(tx *types.Transaction, local bool) error { }, ExistingExpenditure: func(addr common.Address) *big.Int { if list := pool.pending[addr]; list != nil { - return list.totalcost + return list.TotalCostFor(tx.FeeCurrency()) } return new(big.Int) }, @@ -681,15 +680,12 @@ func (pool *LegacyPool) validateTx(tx *types.Transaction, local bool) error { return nil }, L1CostFn: pool.l1CostFn, + ExistingBalance: func(addr common.Address, feeCurrency *common.Address) *big.Int { + return pool.celoBackend.GetFeeBalance(addr, feeCurrency) + }, } - // Adapt to celo validation options - celoOpts := &txpool.CeloValidationOptionsWithState{ - ValidationOptionsWithState: *opts, - FeeCurrencyValidator: pool.feeCurrencyValidator, - } - - if err := txpool.ValidateTransactionWithState(tx, pool.signer, celoOpts); err != nil { + if err := txpool.ValidateTransactionWithState(tx, pool.signer, opts); err != nil { return err } return nil @@ -810,7 +806,7 @@ func (pool *LegacyPool) add(tx *types.Transaction, local bool) (replaced bool, e // Try to replace an existing transaction in the pending pool if list := pool.pending[from]; list != nil && list.Contains(tx.Nonce()) { // Nonce already pending, check if required price bump is met - inserted, old := list.Add(tx, pool.config.PriceBump, pool.l1CostFn) + inserted, old := list.Add(tx, pool.config.PriceBump, pool.l1CostFn, pool.currentRates) if !inserted { pendingDiscardMeter.Mark(1) return false, txpool.ErrReplaceUnderpriced @@ -884,7 +880,7 @@ func (pool *LegacyPool) enqueueTx(hash common.Hash, tx *types.Transaction, local if pool.queue[from] == nil { pool.queue[from] = newList(false) } - inserted, old := pool.queue[from].Add(tx, pool.config.PriceBump, pool.l1CostFn) + inserted, old := pool.queue[from].Add(tx, pool.config.PriceBump, pool.l1CostFn, pool.currentRates) if !inserted { // An older transaction was better, discard this queuedDiscardMeter.Mark(1) @@ -938,7 +934,7 @@ func (pool *LegacyPool) promoteTx(addr common.Address, hash common.Hash, tx *typ } list := pool.pending[addr] - inserted, old := list.Add(tx, pool.config.PriceBump, pool.l1CostFn) + inserted, old := list.Add(tx, pool.config.PriceBump, pool.l1CostFn, pool.currentRates) if !inserted { // An older transaction was better, discard this pool.all.Remove(hash) @@ -1332,7 +1328,7 @@ func (pool *LegacyPool) runReorg(done chan struct{}, reset *txpoolResetRequest, if reset.newHead != nil { if pool.chainconfig.IsLondon(new(big.Int).Add(reset.newHead.Number, big.NewInt(1))) { pendingBaseFee := eip1559.CalcBaseFee(pool.chainconfig, reset.newHead, reset.newHead.Time+1) - pool.priced.SetBaseFee(pendingBaseFee) + pool.priced.SetBaseFeeAndRates(pendingBaseFee, pool.currentRates) } else { pool.priced.Reheap() } @@ -1462,6 +1458,7 @@ func (pool *LegacyPool) reset(oldHead, newHead *types.Header) { pool.currentHead.Store(newHead) pool.currentState = statedb pool.pendingNonces = newNoncer(statedb) + pool.recreateCeloProperties() costFn := types.NewL1CostFunc(pool.chainconfig, statedb) pool.l1CostFn = func(dataGas types.RollupGasData) *big.Int { @@ -1495,16 +1492,16 @@ func (pool *LegacyPool) promoteExecutables(accounts []common.Address) []*types.T pool.all.Remove(hash) } log.Trace("Removed old queued transactions", "count", len(forwards)) - balance := pool.currentState.GetBalance(addr) - if !list.Empty() && pool.l1CostFn != nil { - // Reduce the cost-cap by L1 rollup cost of the first tx if necessary. Other txs will get filtered out afterwards. - el := list.txs.FirstElement() - if l1Cost := pool.l1CostFn(el.RollupDataGas()); l1Cost != nil { - balance = new(big.Int).Sub(balance, l1Cost) // negative big int is fine - } - } // Drop all transactions that are too costly (low balance or out of gas) - drops, _ := list.Filter(balance, gasLimit) + // var l1Cost *big.Int + // TODO: manage l1cost in list + // if !list.Empty() && pool.l1CostFn != nil { + // // Reduce the cost-cap by L1 rollup cost of the first tx if necessary. Other txs will get filtered out afterwards. + // el := list.txs.FirstElement() + // l1Cost = pool.l1CostFn(el.RollupDataGas()) + // } + // Drop all transactions that are too costly (low balance or out of gas) + drops, _ := pool.filter(list, addr, gasLimit) for _, tx := range drops { hash := tx.Hash() pool.all.Remove(hash) @@ -1704,16 +1701,15 @@ func (pool *LegacyPool) demoteUnexecutables() { pool.all.Remove(hash) log.Trace("Removed old pending transaction", "hash", hash) } - balance := pool.currentState.GetBalance(addr) - if !list.Empty() && pool.l1CostFn != nil { - // Reduce the cost-cap by L1 rollup cost of the first tx if necessary. Other txs will get filtered out afterwards. - el := list.txs.FirstElement() - if l1Cost := pool.l1CostFn(el.RollupDataGas()); l1Cost != nil { - balance = new(big.Int).Sub(balance, l1Cost) // negative big int is fine - } - } + // TODO: manage l1cost + // var l1Cost *big.Int + // if !list.Empty() && pool.l1CostFn != nil { + // // Reduce the cost-cap by L1 rollup cost of the first tx if necessary. Other txs will get filtered out afterwards. + // el := list.txs.FirstElement() + // l1Cost = pool.l1CostFn(el.RollupDataGas()) + // } // Drop all transactions that are too costly (low balance or out of gas), and queue any invalids back for later - drops, invalids := list.Filter(balance, gasLimit) + drops, invalids := pool.filter(list, addr, gasLimit) for _, tx := range drops { hash := tx.Hash() log.Trace("Removed unpayable pending transaction", "hash", hash) diff --git a/core/txpool/legacypool/legacypool_test.go b/core/txpool/legacypool/legacypool_test.go index 43dfeee92c..11d73b2d64 100644 --- a/core/txpool/legacypool/legacypool_test.go +++ b/core/txpool/legacypool/legacypool_test.go @@ -198,8 +198,8 @@ func validatePoolInternals(pool *LegacyPool) error { if nonce := pool.pendingNonces.get(addr); nonce != last+1 { return fmt.Errorf("pending nonce mismatch: have %v, want %v", nonce, last+1) } - if txs.totalcost.Cmp(common.Big0) < 0 { - return fmt.Errorf("totalcost went negative: %v", txs.totalcost) + if txs.TotalCostFor(nil).Cmp(common.Big0) < 0 { + return fmt.Errorf("totalcost went negative: %v", txs.TotalCostFor(nil)) } } return nil @@ -2030,7 +2030,7 @@ func TestDualHeapEviction(t *testing.T) { add(false) for baseFee = 0; baseFee <= 1000; baseFee += 100 { - pool.priced.SetBaseFee(big.NewInt(int64(baseFee))) + pool.priced.SetBaseFeeAndRates(big.NewInt(int64(baseFee)), nil) add(true) check(highCap, "fee cap") add(false) diff --git a/core/txpool/legacypool/list.go b/core/txpool/legacypool/list.go index c61f7a0f8c..8e6fc303d3 100644 --- a/core/txpool/legacypool/list.go +++ b/core/txpool/legacypool/list.go @@ -18,7 +18,6 @@ package legacypool import ( "container/heap" - "math" "math/big" "sort" "sync" @@ -26,6 +25,7 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/exchange" "github.com/ethereum/go-ethereum/core/txpool" "github.com/ethereum/go-ethereum/core/types" ) @@ -281,9 +281,11 @@ type list struct { strict bool // Whether nonces are strictly continuous or not txs *sortedMap // Heap indexed sorted hash map of the transactions - costcap *big.Int // Price of the highest costing transaction (reset only if exceeds balance) - gascap uint64 // Gas limit of the highest spending transaction (reset only if exceeds block limit) - totalcost *big.Int // Total cost of all transactions in the list + costCap map[common.Address]*big.Int // Price of the highest costing transaction per currency (reset only if exceeds balance) + gascap uint64 // Gas limit of the highest spending transaction (reset only if exceeds block limit) + + // Celo additions for multi currency + totalCost map[common.Address]*big.Int // Total cost of all transactions in the list (by currency) } // newList create a new transaction list for maintaining nonce-indexable fast, @@ -292,8 +294,8 @@ func newList(strict bool) *list { return &list{ strict: strict, txs: newSortedMap(), - costcap: new(big.Int), - totalcost: new(big.Int), + costCap: make(map[common.Address]*big.Int), + totalCost: make(map[common.Address]*big.Int), } } @@ -308,11 +310,12 @@ func (l *list) Contains(nonce uint64) bool { // // If the new transaction is accepted into the list, the lists' cost and gas // thresholds are also potentially updated. -func (l *list) Add(tx *types.Transaction, priceBump uint64, l1CostFn txpool.L1CostFunc) (bool, *types.Transaction) { +func (l *list) Add(tx *types.Transaction, priceBump uint64, _ txpool.L1CostFunc, rates common.ExchangeRates) (bool, *types.Transaction) { // If there's an older better transaction, abort old := l.txs.Get(tx.Nonce()) if old != nil { - if old.GasFeeCapCmp(tx) >= 0 || old.GasTipCapCmp(tx) >= 0 { + // Short circuit when it's clear that the new tx is worse + if common.AreSameAddress(old.FeeCurrency(), tx.FeeCurrency()) && (old.GasFeeCapCmp(tx) >= 0 || old.GasTipCapCmp(tx) >= 0) { return false, nil } // thresholdFeeCap = oldFC * (100 + priceBump) / 100 @@ -325,27 +328,33 @@ func (l *list) Add(tx *types.Transaction, priceBump uint64, l1CostFn txpool.L1Co thresholdFeeCap := aFeeCap.Div(aFeeCap, b) thresholdTip := aTip.Div(aTip, b) + var thresholdFeeCapInCurrency = thresholdFeeCap + var thresholdTipInCurrency = thresholdTip + if tx.FeeCurrency() != old.FeeCurrency() { + thresholdFeeCapInCurrency = exchange.ConvertCurrency(rates, thresholdFeeCap, old.FeeCurrency(), tx.FeeCurrency()) + thresholdTipInCurrency = exchange.ConvertCurrency(rates, thresholdTip, old.FeeCurrency(), tx.FeeCurrency()) + } // We have to ensure that both the new fee cap and tip are higher than the // old ones as well as checking the percentage threshold to ensure that // this is accurate for low (Wei-level) gas price replacements. - if tx.GasFeeCapIntCmp(thresholdFeeCap) < 0 || tx.GasTipCapIntCmp(thresholdTip) < 0 { + if tx.GasFeeCapIntCmp(thresholdFeeCapInCurrency) < 0 || tx.GasTipCapIntCmp(thresholdTipInCurrency) < 0 { return false, nil } // Old is being replaced, subtract old cost l.subTotalCost([]*types.Transaction{old}) } // Add new tx cost to totalcost - l.totalcost.Add(l.totalcost, tx.Cost()) - if l1CostFn != nil { - if l1Cost := l1CostFn(tx.RollupDataGas()); l1Cost != nil { // add rollup cost - l.totalcost.Add(l.totalcost, l1Cost) - } - } + tc := l.totalCostVar(tx.FeeCurrency()) + tc.Add(tc, tx.Cost()) + // TODO: manage l1 cost + // if l1CostFn != nil { + // if l1Cost := l1CostFn(tx.RollupDataGas()); l1Cost != nil { // add rollup cost + // tc.Add(tc, l1Cost) + // } + // } // Otherwise overwrite the old transaction with the current one l.txs.Put(tx) - if cost := tx.Cost(); l.costcap.Cmp(cost) < 0 { - l.costcap = cost - } + l.updateCostCapFor(tx.FeeCurrency(), tx.Cost()) if gas := tx.Gas(); l.gascap < gas { l.gascap = gas } @@ -370,33 +379,24 @@ func (l *list) Forward(threshold uint64) types.Transactions { // a point in calculating all the costs or if the balance covers all. If the threshold // is lower than the costgas cap, the caps will be reset to a new high after removing // the newly invalidated transactions. -func (l *list) Filter(costLimit *big.Int, gasLimit uint64) (types.Transactions, types.Transactions) { +func (l *list) Filter(costLimits map[common.Address]*big.Int, gasLimit uint64) (types.Transactions, types.Transactions) { // If all transactions are below the threshold, short circuit - if l.costcap.Cmp(costLimit) <= 0 && l.gascap <= gasLimit { + if l.costCapsLowerThan(costLimits) && l.gascap <= gasLimit { return nil, nil } - l.costcap = new(big.Int).Set(costLimit) // Lower the caps to the thresholds + l.setCapsTo(costLimits) // Lower the caps to the thresholds l.gascap = gasLimit // Filter out all the transactions above the account's funds removed := l.txs.Filter(func(tx *types.Transaction) bool { - return tx.Gas() > gasLimit || tx.Cost().Cmp(costLimit) > 0 + return tx.Gas() > gasLimit || tx.Cost().Cmp(l.costCapFor(tx.FeeCurrency())) > 0 }) if len(removed) == 0 { return nil, nil } - var invalids types.Transactions - // If the list was strict, filter anything above the lowest nonce - if l.strict { - lowest := uint64(math.MaxUint64) - for _, tx := range removed { - if nonce := tx.Nonce(); lowest > nonce { - lowest = nonce - } - } - invalids = l.txs.filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) - } + + invalids := l.dropInvalidsAfterRemovalAndReheap(removed) // Reset total cost l.subTotalCost(removed) l.subTotalCost(invalids) @@ -471,7 +471,8 @@ func (l *list) LastElement() *types.Transaction { // total cost of all transactions. func (l *list) subTotalCost(txs []*types.Transaction) { for _, tx := range txs { - l.totalcost.Sub(l.totalcost, tx.Cost()) + tc := l.totalCostVar(tx.FeeCurrency()) + tc.Sub(tc, tx.Cost()) } } @@ -482,6 +483,8 @@ func (l *list) subTotalCost(txs []*types.Transaction) { type priceHeap struct { baseFee *big.Int // heap should always be re-sorted after baseFee is changed list []*types.Transaction + + txComparator TxComparator } func (h *priceHeap) Len() int { return len(h.list) } @@ -499,18 +502,7 @@ func (h *priceHeap) Less(i, j int) bool { } func (h *priceHeap) cmp(a, b *types.Transaction) int { - if h.baseFee != nil { - // Compare effective tips if baseFee is specified - if c := a.EffectiveGasTipCmp(b, h.baseFee); c != 0 { - return c - } - } - // Compare fee caps if baseFee is not specified or effective tips are equal - if c := a.GasFeeCapCmp(b); c != 0 { - return c - } - // Compare tips if effective tips and fee caps are equal - return a.GasTipCapCmp(b) + return h.txComparator(a, b, h.baseFee) } func (h *priceHeap) Push(x interface{}) { @@ -545,6 +537,9 @@ type pricedList struct { all *lookup // Pointer to the map of all transactions urgent, floating priceHeap // Heaps of prices of all the stored **remote** transactions reheapMu sync.Mutex // Mutex asserts that only one routine is reheaping the list + + // Celo specific + rates common.ExchangeRates // current exchange rates } const ( @@ -553,11 +548,13 @@ const ( floatingRatio = 1 ) -// newPricedList creates a new price-sorted transaction heap. func newPricedList(all *lookup) *pricedList { - return &pricedList{ + p := &pricedList{ all: all, } + p.floating.txComparator = p.compareWithRates + p.urgent.txComparator = p.compareWithRates + return p } // Put inserts a new transaction into the heap. @@ -684,9 +681,10 @@ func (l *pricedList) Reheap() { reheapTimer.Update(time.Since(start)) } -// SetBaseFee updates the base fee and triggers a re-heap. Note that Removed is not +// SetBaseFeeAndRates updates the base fee and triggers a re-heap. Note that Removed is not // necessary to call right before SetBaseFee when processing a new block. -func (l *pricedList) SetBaseFee(baseFee *big.Int) { +func (l *pricedList) SetBaseFeeAndRates(baseFee *big.Int, rates common.ExchangeRates) { l.urgent.baseFee = baseFee + l.rates = rates l.Reheap() } diff --git a/core/txpool/legacypool/list_test.go b/core/txpool/legacypool/list_test.go index b1f6ec305d..ebfb8e3abf 100644 --- a/core/txpool/legacypool/list_test.go +++ b/core/txpool/legacypool/list_test.go @@ -17,7 +17,6 @@ package legacypool import ( - "math/big" "math/rand" "testing" @@ -38,7 +37,7 @@ func TestStrictListAdd(t *testing.T) { // Insert the transactions in a random order list := newList(true) for _, v := range rand.Perm(len(txs)) { - list.Add(txs[v], DefaultConfig.PriceBump, nil) + list.Add(txs[v], DefaultConfig.PriceBump, nil, nil) } // Verify internal state if len(list.txs.items) != len(txs) { @@ -60,13 +59,11 @@ func BenchmarkListAdd(b *testing.B) { txs[i] = transaction(uint64(i), 0, key) } // Insert the transactions in a random order - priceLimit := big.NewInt(int64(DefaultConfig.PriceLimit)) b.ResetTimer() for i := 0; i < b.N; i++ { list := newList(true) for _, v := range rand.Perm(len(txs)) { - list.Add(txs[v], DefaultConfig.PriceBump, nil) - list.Filter(priceLimit, DefaultConfig.PriceBump) + list.Add(txs[v], DefaultConfig.PriceBump, nil, nil) } } } diff --git a/core/txpool/validation.go b/core/txpool/validation.go index 33cc1df78b..d8e67a30b3 100644 --- a/core/txpool/validation.go +++ b/core/txpool/validation.go @@ -192,7 +192,7 @@ func validateBlobSidecar(hashes []common.Hash, sidecar *types.BlobTxSidecar) err // ValidationOptionsWithState define certain differences between stateful transaction // validation across the different pools without having to duplicate those checks. type ValidationOptionsWithState struct { - State *state.StateDB // State database to check nonces and balances against + State *state.StateDB // State database to check nonces // FirstNonceGap is an optional callback to retrieve the first nonce gap in // the list of pooled transactions of a specific account. If this method is @@ -215,6 +215,9 @@ type ValidationOptionsWithState struct { // L1CostFn is an optional extension, to validate L1 rollup costs of a tx L1CostFn L1CostFunc + + // ExistingBalance for a currency, to check for balance to cover transaction costs. + ExistingBalance func(addr common.Address, feeCurrency *common.Address) *big.Int } // ValidateTransactionWithState is a helper method to check whether a transaction @@ -222,7 +225,7 @@ type ValidationOptionsWithState struct { // // This check is public to allow different transaction pools to check the stateful // rules without duplicating code and running the risk of missed updates. -func ValidateTransactionWithState(tx *types.Transaction, signer types.Signer, opts *CeloValidationOptionsWithState) error { +func ValidateTransactionWithState(tx *types.Transaction, signer types.Signer, opts *ValidationOptionsWithState) error { // Ensure the transaction adheres to nonce ordering from, err := signer.Sender(tx) // already validated (and cached), but cleaner to check if err != nil { @@ -242,7 +245,7 @@ func ValidateTransactionWithState(tx *types.Transaction, signer types.Signer, op } // Ensure the transactor has enough funds to cover the transaction costs var ( - balance = opts.FeeCurrencyValidator.Balance(opts.State, from, tx.FeeCurrency()) + balance = opts.ExistingBalance(from, tx.FeeCurrency()) cost = tx.Cost() ) if opts.L1CostFn != nil { diff --git a/core/vm/evm.go b/core/vm/evm.go index 3266f3c241..71ae7ef5b0 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -93,17 +93,7 @@ type BlockContext struct { Random *common.Hash // Provides information for PREVRANDAO // Celo specific information - ExchangeRates map[common.Address]*big.Rat -} - -func (bc BlockContext) IsCurrencyWhitelisted(feeCurrency *common.Address) bool { - if feeCurrency == nil { - return true - } - - // Check if fee currency is registered - _, ok := bc.ExchangeRates[*feeCurrency] - return ok + ExchangeRates common.ExchangeRates } // TxContext provides the EVM with information about a transaction. diff --git a/e2e_test/js-tests/test_viem_tx.mjs b/e2e_test/js-tests/test_viem_tx.mjs index 04b5bf23b0..db814a69cf 100644 --- a/e2e_test/js-tests/test_viem_tx.mjs +++ b/e2e_test/js-tests/test_viem_tx.mjs @@ -34,6 +34,56 @@ const walletClient = createWalletClient({ transport: http(), }); +const testNonceBump = async (firstCap, firstCurrency, secondCap, secondCurrency, shouldReplace) => { + const syncBarrierRequest = await walletClient.prepareTransactionRequest({ + account, + to: "0x00000000000000000000000000000000DeaDBeef", + value: 2, + gas: 22000, + }) + const firstTxHash = await walletClient.sendTransaction({ + account, + to: "0x00000000000000000000000000000000DeaDBeef", + value: 2, + gas: 90000, + maxFeePerGas: firstCap, + maxPriorityFeePerGas: firstCap, + nonce: syncBarrierRequest.nonce + 1, + feeCurrency: firstCurrency, + }); + var secondTxHash; + try { + secondTxHash = await walletClient.sendTransaction({ + account, + to: "0x00000000000000000000000000000000DeaDBeef", + value: 3, + gas: 90000, + maxFeePerGas: secondCap, + maxPriorityFeePerGas: secondCap, + nonce: syncBarrierRequest.nonce + 1, + feeCurrency: secondCurrency, + }); + } catch (err) { + // If shouldReplace, no error should be thrown + // If shouldReplace == false, exactly the underpriced error should be thrown + if (err.cause.details != 'replacement transaction underpriced' || shouldReplace) { + throw err; // Only throw if unexpected error. + } + } + const syncBarrierSignature = await walletClient.signTransaction(syncBarrierRequest); + const barrierTxHash = await walletClient.sendRawTransaction({ + serializedTransaction: syncBarrierSignature, + }) + await publicClient.waitForTransactionReceipt({ hash: barrierTxHash }); + if (shouldReplace) { + // The new transaction was included. + await publicClient.waitForTransactionReceipt({ hash: secondTxHash }); + } else { + // The original transaction was not replaced. + await publicClient.waitForTransactionReceipt({ hash: firstTxHash }); + } +} + describe("viem send tx", () => { it("send basic tx and check receipt", async () => { const request = await walletClient.prepareTransactionRequest({ @@ -62,4 +112,47 @@ describe("viem send tx", () => { }); const receipt = await publicClient.waitForTransactionReceipt({ hash }); }).timeout(10_000); + + it("send overlapping nonce tx in different currencies", async () => { + const priceBump = 1.10 + const rate = 2; + // Native to FEE_CURRENCY + const nativeCap = 30_000_000_000; + const bumpCurrencyCap = BigInt(Math.round(nativeCap * rate * priceBump)); + const failToBumpCurrencyCap = BigInt(Math.round(nativeCap * rate * priceBump) - 1); + // FEE_CURRENCY to Native + const currencyCap = 60_000_000_000; + const bumpNativeCap = BigInt(Math.round((currencyCap * priceBump) / rate)); + const failToBumpNativeCap = BigInt(Math.round((currencyCap * priceBump) / rate) - 1); + const tokenCurrency = process.env.FEE_CURRENCY; + const nativeCurrency = null; + await testNonceBump(nativeCap, nativeCurrency, bumpCurrencyCap, tokenCurrency, true); + await testNonceBump(nativeCap, nativeCurrency, failToBumpCurrencyCap, tokenCurrency, false); + await testNonceBump(currencyCap, tokenCurrency, bumpNativeCap, nativeCurrency, true); + await testNonceBump(currencyCap, tokenCurrency, failToBumpNativeCap, nativeCurrency, false); + }).timeout(10_000); + + it("send tx with non-whitelisted fee currency", async () => { + const request = await walletClient.prepareTransactionRequest({ + account, + to: "0x00000000000000000000000000000000DeaDBeef", + value: 2, + gas: 90000, + feeCurrency: "0x000000000000000000000000000000000badc310", + }); + const signature = await walletClient.signTransaction(request); + try { + await walletClient.sendRawTransaction({ + serializedTransaction: signature, + }); + assert.fail("Failed to filter nonwhitelisted feeCurrency"); + } catch(err) { + // TODO: find a better way to check the error type + if (err.cause.details == "Fee currency given is not whitelisted at current block") { + // Test success + } else { + throw err + } + } + }).timeout(10_000); });