diff --git a/common/types.go b/common/types.go index bf74e43716..8c517a5834 100644 --- a/common/types.go +++ b/common/types.go @@ -472,3 +472,5 @@ func (d *Decimal) UnmarshalJSON(input []byte) error { return err } } + +type ExchangeRates = map[Address]*big.Rat diff --git a/core/celo_backend.go b/core/celo_backend.go index 1d5a367d31..a54b1bf58c 100644 --- a/core/celo_backend.go +++ b/core/celo_backend.go @@ -14,14 +14,14 @@ import ( // CeloBackend provide a partial ContractBackend implementation, so that we can // access core contracts during block processing. type CeloBackend struct { - chainConfig *params.ChainConfig - state *state.StateDB + ChainConfig *params.ChainConfig + State *state.StateDB } // ContractCaller implementation func (b *CeloBackend) CodeAt(ctx context.Context, contract common.Address, blockNumber *big.Int) ([]byte, error) { - return b.state.GetCode(contract), nil + return b.State.GetCode(contract), nil } func (b *CeloBackend) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { @@ -44,7 +44,7 @@ func (b *CeloBackend) CallContract(ctx context.Context, call ethereum.CallMsg, b txCtx := vm.TxContext{} vmConfig := vm.Config{} - evm := vm.NewEVM(blockCtx, txCtx, b.state, b.chainConfig, vmConfig) + evm := vm.NewEVM(blockCtx, txCtx, b.State, b.ChainConfig, vmConfig) ret, _, err := evm.StaticCall(vm.AccountRef(evm.Origin), *call.To, call.Data, call.Gas) return ret, err diff --git a/core/celo_evm.go b/core/celo_evm.go index 4db27bed00..0fb33078a4 100644 --- a/core/celo_evm.go +++ b/core/celo_evm.go @@ -16,7 +16,7 @@ import ( ) // Returns the exchange rates for all gas currencies from CELO -func getExchangeRates(caller *CeloBackend) (map[common.Address]*big.Rat, error) { +func GetExchangeRates(caller bind.ContractCaller) (common.ExchangeRates, error) { exchangeRates := map[common.Address]*big.Rat{} whitelist, err := abigen.NewFeeCurrencyWhitelistCaller(contracts.FeeCurrencyWhitelistAddress, caller) if err != nil { @@ -57,7 +57,7 @@ func setCeloFieldsInBlockContext(blockContext *vm.BlockContext, header *types.He // Add fee currency exchange rates var err error - blockContext.ExchangeRates, err = getExchangeRates(caller) + blockContext.ExchangeRates, err = GetExchangeRates(caller) if err != nil { log.Error("Error fetching exchange rates!", "err", err) } diff --git a/core/txpool/legacypool/celo.go b/core/txpool/legacypool/celo.go new file mode 100644 index 0000000000..049ac27699 --- /dev/null +++ b/core/txpool/legacypool/celo.go @@ -0,0 +1,92 @@ +package legacypool + +import ( + "errors" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/contracts/celo/abigen" +) + +var ( + unitRate = big.NewRat(1, 1) +) + +// IsWhitelisted checks if a given fee currency is whitelisted +func IsWhitelisted(exchangeRates common.ExchangeRates, feeCurrency *common.Address) bool { + if feeCurrency == nil { + return true + } + _, ok := exchangeRates[*feeCurrency] + return ok +} + +// Compares values in different currencies +// nil currency => native currency +func CompareValue(exchangeRates common.ExchangeRates, val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) (int, error) { + // Short circuit if the fee currency is the same. + if areEqualAddresses(feeCurrency1, feeCurrency2) { + return val1.Cmp(val2), nil + } + + var exchangeRate1, exchangeRate2 *big.Rat + var ok bool + if feeCurrency1 == nil { + exchangeRate1 = unitRate + } else { + exchangeRate1, ok = exchangeRates[*feeCurrency1] + if !ok { + return 0, fmt.Errorf("fee currency not registered: %s", feeCurrency1.Hex()) + } + } + + if feeCurrency2 == nil { + exchangeRate2 = unitRate + } else { + exchangeRate2, ok = exchangeRates[*feeCurrency2] + if !ok { + return 0, fmt.Errorf("fee currency not registered: %s", feeCurrency1.Hex()) + } + } + + // Below code block is basically evaluating this comparison: + // val1 * exchangeRate1.denominator / exchangeRate1.numerator < val2 * exchangeRate2.denominator / exchangeRate2.numerator + // It will transform that comparison to this, to remove having to deal with fractional values. + // val1 * exchangeRate1.denominator * exchangeRate2.numerator < val2 * exchangeRate2.denominator * exchangeRate1.numerator + leftSide := new(big.Int).Mul( + val1, + new(big.Int).Mul( + exchangeRate1.Denom(), + exchangeRate2.Num(), + ), + ) + rightSide := new(big.Int).Mul( + val2, + new(big.Int).Mul( + exchangeRate2.Denom(), + exchangeRate1.Num(), + ), + ) + + return leftSide.Cmp(rightSide), nil +} + +func areEqualAddresses(addr1, addr2 *common.Address) bool { + return (addr1 == nil && addr2 == nil) || (addr1 != nil && addr2 != nil && *addr1 == *addr2) +} + +func GetBalanceOf(backend *bind.ContractCaller, account common.Address, feeCurrency common.Address) (*big.Int, error) { + token, err := abigen.NewFeeCurrencyCaller(feeCurrency, *backend) + if err != nil { + return nil, errors.New("failed to access fee currency token") + } + + balance, err := token.BalanceOf(&bind.CallOpts{}, account) + if err != nil { + return nil, errors.New("failed to access token balance") + } + + return balance, nil +} diff --git a/core/txpool/legacypool/celo_test.go b/core/txpool/legacypool/celo_test.go new file mode 100644 index 0000000000..49fa943f70 --- /dev/null +++ b/core/txpool/legacypool/celo_test.go @@ -0,0 +1,175 @@ +package legacypool + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +var ( + currA = common.HexToAddress("0xA") + currB = common.HexToAddress("0xB") + currX = common.HexToAddress("0xF") + exchangeRates = common.ExchangeRates{ + currA: big.NewRat(47, 100), + currB: big.NewRat(45, 100), + } +) + +func TestCompareFees(t *testing.T) { + type args struct { + val1 *big.Int + feeCurrency1 *common.Address + val2 *big.Int + feeCurrency2 *common.Address + } + tests := []struct { + name string + args args + wantResult int + wantErr bool + }{ + // Native currency + { + name: "Same amount of native currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: nil, + }, + wantResult: 0, + }, { + name: "Different amounts of native currency 1", + args: args{ + val1: big.NewInt(2), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: nil, + }, + wantResult: 1, + }, { + name: "Different amounts of native currency 2", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(5), + feeCurrency2: nil, + }, + wantResult: -1, + }, + // Mixed currency + { + name: "Same amount of mixed currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + wantResult: -1, + }, { + name: "Different amounts of mixed currency 1", + args: args{ + val1: big.NewInt(100), + feeCurrency1: nil, + val2: big.NewInt(47), + feeCurrency2: &currA, + }, + wantResult: 0, + }, { + name: "Different amounts of mixed currency 2", + args: args{ + val1: big.NewInt(100), + feeCurrency1: nil, + val2: big.NewInt(45), + feeCurrency2: &currB, + }, + wantResult: 0, + }, + // Two fee currencies + { + name: "Same amount of same currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: &currA, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + wantResult: 0, + }, { + name: "Different amounts of same currency 1", + args: args{ + val1: big.NewInt(3), + feeCurrency1: &currA, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + wantResult: 1, + }, { + name: "Different amounts of same currency 2", + args: args{ + val1: big.NewInt(1), + feeCurrency1: &currA, + val2: big.NewInt(7), + feeCurrency2: &currA, + }, + wantResult: -1, + }, + // Unregistered fee currency + { + name: "Different amounts of different currencies", + args: args{ + val1: big.NewInt(1), + feeCurrency1: &currA, + val2: big.NewInt(1), + feeCurrency2: &currX, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CompareValue(exchangeRates, tt.args.val1, tt.args.feeCurrency1, tt.args.val2, tt.args.feeCurrency2) + + if tt.wantErr && err == nil { + t.Error("Expected error in celoContextImpl.CompareFees") + } + if got != tt.wantResult { + t.Errorf("celoContextImpl.CompareFees() = %v, want %v", got, tt.wantResult) + } + }) + } +} + +func TestIsWhitelisted(t *testing.T) { + tests := []struct { + name string + feeCurrency *common.Address + want bool + }{ + { + name: "no fee currency", + feeCurrency: nil, + want: true, + }, + { + name: "valid fee currency", + feeCurrency: &currA, + want: true, + }, + { + name: "invalid fee currency", + feeCurrency: &currX, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsWhitelisted(exchangeRates, tt.feeCurrency); got != tt.want { + t.Errorf("IsWhitelisted() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/core/vm/evm.go b/core/vm/evm.go index cb65356349..e1844b5f3e 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -93,7 +93,7 @@ type BlockContext struct { ExcessBlobGas *uint64 // ExcessBlobGas field in the header, needed to compute the data // Celo specific information - ExchangeRates map[common.Address]*big.Rat + ExchangeRates common.ExchangeRates } // TxContext provides the EVM with information about a transaction.