diff --git a/CHANGELOG.md b/CHANGELOG.md index 998477579..f683644a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ - [\#591](https://github.com/cosmos/evm/pull/591) CheckTxHandler should handle "invalid nonce" tx - [\#643](https://github.com/cosmos/evm/pull/643) Support for mnemonic source (file, stdin,etc) flag in key add command. - [\#645](https://github.com/cosmos/evm/pull/645) Align precise bank keeper for correct decimal conversion in evmd. +- [\#656](https://github.com/cosmos/evm/pull/656) Fix race condition in concurrent usage of mempool StateAt and NotifyNewBlock methods. ### IMPROVEMENTS diff --git a/mempool/blockchain.go b/mempool/blockchain.go index bd4f2f02c..72b434f3d 100644 --- a/mempool/blockchain.go +++ b/mempool/blockchain.go @@ -3,6 +3,7 @@ package mempool import ( "fmt" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -24,8 +25,8 @@ import ( ) var ( - _ txpool.BlockChain = Blockchain{} - _ legacypool.BlockChain = Blockchain{} + _ txpool.BlockChain = &Blockchain{} + _ legacypool.BlockChain = &Blockchain{} ) // Blockchain implements the BlockChain interface required by Ethereum transaction pools. @@ -42,12 +43,13 @@ type Blockchain struct { blockGasLimit uint64 previousHeaderHash common.Hash latestCtx sdk.Context + mu sync.RWMutex } -// newBlockchain creates a new Blockchain instance that bridges Cosmos SDK state with Ethereum mempools. +// NewBlockchain creates a new Blockchain instance that bridges Cosmos SDK state with Ethereum mempools. // The getCtxCallback function provides access to Cosmos SDK contexts at different heights, vmKeeper manages EVM state, // and feeMarketKeeper handles fee market operations like base fee calculations. -func newBlockchain(ctx func(height int64, prove bool) (sdk.Context, error), logger log.Logger, vmKeeper VMKeeperI, feeMarketKeeper FeeMarketKeeperI, blockGasLimit uint64) *Blockchain { +func NewBlockchain(ctx func(height int64, prove bool) (sdk.Context, error), logger log.Logger, vmKeeper VMKeeperI, feeMarketKeeper FeeMarketKeeperI, blockGasLimit uint64) *Blockchain { // Add the blockchain name to the logger logger = logger.With(log.ModuleKey, "Blockchain") @@ -70,7 +72,7 @@ func newBlockchain(ctx func(height int64, prove bool) (sdk.Context, error), logg // Config returns the Ethereum chain configuration. It should only be called after the chain is initialized. // This provides the necessary parameters for EVM execution and transaction validation. -func (b Blockchain) Config() *params.ChainConfig { +func (b *Blockchain) Config() *params.ChainConfig { return evmtypes.GetEthChainConfig() } @@ -78,7 +80,7 @@ func (b Blockchain) Config() *params.ChainConfig { // It constructs an Ethereum-compatible header from the current Cosmos SDK context, // including block height, timestamp, gas limits, and base fee (if London fork is active). // Returns a zero header as placeholder if the context is not yet available. -func (b Blockchain) CurrentBlock() *types.Header { +func (b *Blockchain) CurrentBlock() *types.Header { ctx, err := b.GetLatestContext() if err != nil { return b.zeroHeader @@ -86,7 +88,8 @@ func (b Blockchain) CurrentBlock() *types.Header { blockHeight := ctx.BlockHeight() // prevent the reorg from triggering after a restart since previousHeaderHash is stored as an in-memory variable - if blockHeight > 1 && b.previousHeaderHash == (common.Hash{}) { + previousHeaderHash := b.getPreviousHeaderHash() + if blockHeight > 1 && previousHeaderHash == (common.Hash{}) { return b.zeroHeader } @@ -99,7 +102,7 @@ func (b Blockchain) CurrentBlock() *types.Header { Time: uint64(blockTime), // #nosec G115 -- overflow not a concern with unix time GasLimit: b.blockGasLimit, GasUsed: gasUsed, - ParentHash: b.previousHeaderHash, + ParentHash: previousHeaderHash, Root: appHash, // we actually don't care that this isn't the getCtxCallback header, as long as we properly track roots and parent roots to prevent the reorg from triggering Difficulty: big.NewInt(0), // 0 difficulty on PoS } @@ -139,7 +142,7 @@ func (b Blockchain) CurrentBlock() *types.Header { // Cosmos chains have instant finality, so this method should only be called for the genesis block (block 0) // or block 1, as reorgs never occur. Any other call indicates a bug in the mempool logic. // Panics if called for blocks beyond block 1, as this would indicate an attempted reorg. -func (b Blockchain) GetBlock(_ common.Hash, _ uint64) *types.Block { +func (b *Blockchain) GetBlock(_ common.Hash, _ uint64) *types.Block { currBlock := b.CurrentBlock() blockNumber := currBlock.Number.Int64() @@ -161,7 +164,7 @@ func (b Blockchain) GetBlock(_ common.Hash, _ uint64) *types.Block { // SubscribeChainHeadEvent allows subscribers to receive notifications when new blocks are finalized. // Returns a subscription that will receive ChainHeadEvent notifications via the provided channel. -func (b Blockchain) SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription { +func (b *Blockchain) SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription { b.logger.Debug("new chain head event subscription created") return b.chainHeadFeed.Subscribe(ch) } @@ -170,19 +173,19 @@ func (b Blockchain) SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event func (b *Blockchain) NotifyNewBlock() { latestCtx, err := b.newLatestContext() if err != nil { - b.latestCtx = sdk.Context{} + b.setLatestContext(sdk.Context{}) b.logger.Debug("failed to get latest context, notifying chain head", "error", err) } - b.latestCtx = latestCtx + b.setLatestContext(latestCtx) header := b.CurrentBlock() headerHash := header.Hash() b.logger.Debug("notifying new block", "block_number", header.Number.String(), "block_hash", headerHash.Hex(), - "previous_hash", b.previousHeaderHash.Hex()) + "previous_hash", b.getPreviousHeaderHash().Hex()) - b.previousHeaderHash = headerHash + b.setPreviousHeaderHash(headerHash) b.chainHeadFeed.Send(core.ChainHeadEvent{Header: header}) b.logger.Debug("chain head event sent to feed") @@ -192,7 +195,7 @@ func (b *Blockchain) NotifyNewBlock() { // In practice, this always returns the most recent state since the mempool // only needs current state for validation. Historical state access is not supported // as it's never required by the txpool. -func (b Blockchain) StateAt(hash common.Hash) (vm.StateDB, error) { +func (b *Blockchain) StateAt(hash common.Hash) (vm.StateDB, error) { b.logger.Debug("StateAt called", "requested_hash", hash.Hex()) // This is returned at block 0, before the context is available. @@ -215,10 +218,30 @@ func (b Blockchain) StateAt(hash common.Hash) (vm.StateDB, error) { return stateDB, nil } +func (b *Blockchain) getPreviousHeaderHash() common.Hash { + b.mu.RLock() + defer b.mu.RUnlock() + return b.previousHeaderHash +} + +func (b *Blockchain) setPreviousHeaderHash(h common.Hash) { + b.mu.Lock() + defer b.mu.Unlock() + b.previousHeaderHash = h +} + +func (b *Blockchain) setLatestContext(ctx sdk.Context) { + b.mu.Lock() + defer b.mu.Unlock() + b.latestCtx = ctx +} + // GetLatestContext returns the latest context as updated by the block, // or attempts to retrieve it again if unavailable. -func (b Blockchain) GetLatestContext() (sdk.Context, error) { +func (b *Blockchain) GetLatestContext() (sdk.Context, error) { b.logger.Debug("getting latest context") + b.mu.RLock() + defer b.mu.RUnlock() if b.latestCtx.Context() != nil { return b.latestCtx, nil @@ -229,7 +252,7 @@ func (b Blockchain) GetLatestContext() (sdk.Context, error) { // newLatestContext retrieves the most recent query context from the application. // This provides access to the current blockchain state for transaction validation and execution. -func (b Blockchain) newLatestContext() (sdk.Context, error) { +func (b *Blockchain) newLatestContext() (sdk.Context, error) { b.logger.Debug("getting latest context") ctx, err := b.getCtxCallback(0, false) diff --git a/mempool/blockchain_test.go b/mempool/blockchain_test.go new file mode 100644 index 000000000..d4f854802 --- /dev/null +++ b/mempool/blockchain_test.go @@ -0,0 +1,123 @@ +package mempool_test + +import ( + "math/big" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + + "github.com/cosmos/evm/config" + "github.com/cosmos/evm/mempool" + "github.com/cosmos/evm/mempool/mocks" + "github.com/cosmos/evm/x/vm/statedb" + vmtypes "github.com/cosmos/evm/x/vm/types" + + "cosmossdk.io/log" + storetypes "cosmossdk.io/store/types" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// createMockContext creates a basic mock context for testing +func createMockContext() sdk.Context { + return sdk.Context{}. + WithBlockTime(time.Now()). + WithBlockHeader(cmtproto.Header{AppHash: []byte("00000000000000000000000000000000")}). + WithBlockHeight(1) +} + +// TestBlockchainRaceCondition tests concurrent access to NotifyNewBlock and StateAt +// to ensure there are no race conditions between these operations. +func TestBlockchainRaceCondition(t *testing.T) { + logger := log.NewNopLogger() + + // Create mock keepers using generated mocks + mockVMKeeper := mocks.NewVmKeeper(t) + mockFeeMarketKeeper := mocks.NewFeeMarketKeeper(t) + + // Set up mock expectations for methods that will be called + mockVMKeeper.On("GetBaseFee", mock.Anything).Return(big.NewInt(1000000000)).Maybe() // 1 gwei + mockFeeMarketKeeper.On("GetBlockGasWanted", mock.Anything).Return(uint64(10000000)).Maybe() // 10M gas + mockVMKeeper.On("GetParams", mock.Anything).Return(vmtypes.DefaultParams()).Maybe() + mockVMKeeper.On("GetAccount", mock.Anything, common.Address{}).Return(&statedb.Account{}).Maybe() + mockVMKeeper.On("GetState", mock.Anything, common.Address{}, common.Hash{}).Return(common.Hash{}).Maybe() + mockVMKeeper.On("GetCode", mock.Anything, common.Hash{}).Return([]byte{}).Maybe() + mockVMKeeper.On("GetCodeHash", mock.Anything, common.Address{}).Return(common.Hash{}).Maybe() + mockVMKeeper.On("ForEachStorage", mock.Anything, common.Address{}, mock.AnythingOfType("func(common.Hash, common.Hash) bool")).Maybe() + mockVMKeeper.On("KVStoreKeys").Return(make(map[string]*storetypes.KVStoreKey)).Maybe() + + err := vmtypes.NewEVMConfigurator().WithEVMCoinInfo(config.ChainsCoinInfo[config.EighteenDecimalsChainID]).Configure() + require.NoError(t, err) + + // Mock context callback that returns a valid context + getCtxCallback := func(height int64, prove bool) (sdk.Context, error) { + return createMockContext(), nil + } + + blockchain := mempool.NewBlockchain( + getCtxCallback, + logger, + mockVMKeeper, + mockFeeMarketKeeper, + 21000000, // block gas limit + ) + + const numIterations = 100 + var wg sync.WaitGroup + + // Channel to collect any errors from goroutines + errChan := make(chan error, numIterations*2) + + // Start goroutine that calls NotifyNewBlock repeatedly + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < numIterations; i++ { + blockchain.NotifyNewBlock() + // Small delay to allow interleaving + time.Sleep(time.Microsecond) + } + }() + + // Start goroutine that calls StateAt repeatedly + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < numIterations; i++ { + hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + _, err := blockchain.StateAt(hash) + if err != nil { + errChan <- err + return + } + // Small delay to allow interleaving + time.Sleep(time.Microsecond) + } + }() + + // Wait for both goroutines to complete + wg.Wait() + close(errChan) + + // Check for any errors + for err := range errChan { + require.NoError(t, err) + } + + // Basic validation - ensure blockchain still functions correctly after concurrent access + header := blockchain.CurrentBlock() + require.NotNil(t, header) + require.Equal(t, int64(1), header.Number.Int64()) + + // Ensure StateAt still works after concurrent access + hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + stateDB, err := blockchain.StateAt(hash) + require.NoError(t, err) + require.NotNil(t, stateDB) +} diff --git a/mempool/mempool.go b/mempool/mempool.go index a3a7aa62c..583b42b23 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -108,7 +108,7 @@ func NewExperimentalEVMMempool(getCtxCallback func(height int64, prove bool) (sd config.BlockGasLimit = fallbackBlockGasLimit } - blockchain = newBlockchain(getCtxCallback, logger, vmKeeper, feeMarketKeeper, config.BlockGasLimit) + blockchain = NewBlockchain(getCtxCallback, logger, vmKeeper, feeMarketKeeper, config.BlockGasLimit) // Create txPool from configuration legacyConfig := legacypool.DefaultConfig diff --git a/mempool/mocks/FeeMarketKeeper.go b/mempool/mocks/FeeMarketKeeper.go new file mode 100644 index 000000000..99b900563 --- /dev/null +++ b/mempool/mocks/FeeMarketKeeper.go @@ -0,0 +1,45 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + types "github.com/cosmos/cosmos-sdk/types" + mock "github.com/stretchr/testify/mock" +) + +// FeeMarketKeeper is an autogenerated mock type for the FeeMarketKeeperI type +type FeeMarketKeeper struct { + mock.Mock +} + +// GetBlockGasWanted provides a mock function with given fields: ctx +func (_m *FeeMarketKeeper) GetBlockGasWanted(ctx types.Context) uint64 { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetBlockGasWanted") + } + + var r0 uint64 + if rf, ok := ret.Get(0).(func(types.Context) uint64); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// NewFeeMarketKeeper creates a new instance of FeeMarketKeeper. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewFeeMarketKeeper(t interface { + mock.TestingT + Cleanup(func()) +}) *FeeMarketKeeper { + mock := &FeeMarketKeeper{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mempool/mocks/VMKeeper.go b/mempool/mocks/VMKeeper.go new file mode 100644 index 000000000..0644b8293 --- /dev/null +++ b/mempool/mocks/VMKeeper.go @@ -0,0 +1,243 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + big "math/big" + + mempool "github.com/cosmos/evm/mempool" + common "github.com/ethereum/go-ethereum/common" + + mock "github.com/stretchr/testify/mock" + + statedb "github.com/cosmos/evm/x/vm/statedb" + + storetypes "cosmossdk.io/store/types" + + types "github.com/cosmos/cosmos-sdk/types" + + vmtypes "github.com/cosmos/evm/x/vm/types" +) + +// VmKeeper is an autogenerated mock type for the VMKeeperI type +type VmKeeper struct { + mock.Mock +} + +// DeleteAccount provides a mock function with given fields: ctx, addr +func (_m *VmKeeper) DeleteAccount(ctx types.Context, addr common.Address) error { + ret := _m.Called(ctx, addr) + + if len(ret) == 0 { + panic("no return value specified for DeleteAccount") + } + + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, common.Address) error); ok { + r0 = rf(ctx, addr) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeleteCode provides a mock function with given fields: ctx, codeHash +func (_m *VmKeeper) DeleteCode(ctx types.Context, codeHash []byte) { + _m.Called(ctx, codeHash) +} + +// DeleteState provides a mock function with given fields: ctx, addr, key +func (_m *VmKeeper) DeleteState(ctx types.Context, addr common.Address, key common.Hash) { + _m.Called(ctx, addr, key) +} + +// ForEachStorage provides a mock function with given fields: ctx, addr, cb +func (_m *VmKeeper) ForEachStorage(ctx types.Context, addr common.Address, cb func(common.Hash, common.Hash) bool) { + _m.Called(ctx, addr, cb) +} + +// GetAccount provides a mock function with given fields: ctx, addr +func (_m *VmKeeper) GetAccount(ctx types.Context, addr common.Address) *statedb.Account { + ret := _m.Called(ctx, addr) + + if len(ret) == 0 { + panic("no return value specified for GetAccount") + } + + var r0 *statedb.Account + if rf, ok := ret.Get(0).(func(types.Context, common.Address) *statedb.Account); ok { + r0 = rf(ctx, addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*statedb.Account) + } + } + + return r0 +} + +// GetBaseFee provides a mock function with given fields: ctx +func (_m *VmKeeper) GetBaseFee(ctx types.Context) *big.Int { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetBaseFee") + } + + var r0 *big.Int + if rf, ok := ret.Get(0).(func(types.Context) *big.Int); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*big.Int) + } + } + + return r0 +} + +// GetCode provides a mock function with given fields: ctx, codeHash +func (_m *VmKeeper) GetCode(ctx types.Context, codeHash common.Hash) []byte { + ret := _m.Called(ctx, codeHash) + + if len(ret) == 0 { + panic("no return value specified for GetCode") + } + + var r0 []byte + if rf, ok := ret.Get(0).(func(types.Context, common.Hash) []byte); ok { + r0 = rf(ctx, codeHash) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// GetCodeHash provides a mock function with given fields: ctx, addr +func (_m *VmKeeper) GetCodeHash(ctx types.Context, addr common.Address) common.Hash { + ret := _m.Called(ctx, addr) + + if len(ret) == 0 { + panic("no return value specified for GetCodeHash") + } + + var r0 common.Hash + if rf, ok := ret.Get(0).(func(types.Context, common.Address) common.Hash); ok { + r0 = rf(ctx, addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Hash) + } + } + + return r0 +} + +// GetParams provides a mock function with given fields: ctx +func (_m *VmKeeper) GetParams(ctx types.Context) vmtypes.Params { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetParams") + } + + var r0 vmtypes.Params + if rf, ok := ret.Get(0).(func(types.Context) vmtypes.Params); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(vmtypes.Params) + } + + return r0 +} + +// GetState provides a mock function with given fields: ctx, addr, key +func (_m *VmKeeper) GetState(ctx types.Context, addr common.Address, key common.Hash) common.Hash { + ret := _m.Called(ctx, addr, key) + + if len(ret) == 0 { + panic("no return value specified for GetState") + } + + var r0 common.Hash + if rf, ok := ret.Get(0).(func(types.Context, common.Address, common.Hash) common.Hash); ok { + r0 = rf(ctx, addr, key) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Hash) + } + } + + return r0 +} + +// KVStoreKeys provides a mock function with no fields +func (_m *VmKeeper) KVStoreKeys() map[string]*storetypes.KVStoreKey { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for KVStoreKeys") + } + + var r0 map[string]*storetypes.KVStoreKey + if rf, ok := ret.Get(0).(func() map[string]*storetypes.KVStoreKey); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]*storetypes.KVStoreKey) + } + } + + return r0 +} + +// SetAccount provides a mock function with given fields: ctx, addr, account +func (_m *VmKeeper) SetAccount(ctx types.Context, addr common.Address, account statedb.Account) error { + ret := _m.Called(ctx, addr, account) + + if len(ret) == 0 { + panic("no return value specified for SetAccount") + } + + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, common.Address, statedb.Account) error); ok { + r0 = rf(ctx, addr, account) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetCode provides a mock function with given fields: ctx, codeHash, code +func (_m *VmKeeper) SetCode(ctx types.Context, codeHash []byte, code []byte) { + _m.Called(ctx, codeHash, code) +} + +// SetEvmMempool provides a mock function with given fields: evmMempool +func (_m *VmKeeper) SetEvmMempool(evmMempool *mempool.ExperimentalEVMMempool) { + _m.Called(evmMempool) +} + +// SetState provides a mock function with given fields: ctx, addr, key, value +func (_m *VmKeeper) SetState(ctx types.Context, addr common.Address, key common.Hash, value []byte) { + _m.Called(ctx, addr, key, value) +} + +// NewVmKeeper creates a new instance of VmKeeper. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewVmKeeper(t interface { + mock.TestingT + Cleanup(func()) +}) *VmKeeper { + mock := &VmKeeper{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}