diff --git a/evm/chain.py b/evm/chain.py index 7447c296f3..d9d3de86ac 100644 --- a/evm/chain.py +++ b/evm/chain.py @@ -96,12 +96,6 @@ def configure(cls, name, vm_configuration, **overrides): # # Convenience and Helpers # - def get_state_db(self): - """ - Passthrough helper to the current VM class. - """ - return self.get_vm().state_db - def get_block(self): """ Passthrough helper to the current VM class. diff --git a/evm/logic/call.py b/evm/logic/call.py index 8d06f6c10f..0c7a3512cb 100644 --- a/evm/logic/call.py +++ b/evm/logic/call.py @@ -56,9 +56,8 @@ def __call__(self, computation): computation.gas_meter.consume_gas(child_msg_gas_fee, reason=self.mnemonic) # Pre-call checks - sender_balance = computation.state_db.get_balance( - computation.msg.storage_address, - ) + with computation.vm.state_db(read_only=True) as state_db: + sender_balance = state_db.get_balance(computation.msg.storage_address) insufficient_funds = should_transfer_value and sender_balance < value stack_too_deep = computation.msg.depth + 1 > constants.STACK_DEPTH_LIMIT @@ -82,10 +81,11 @@ def __call__(self, computation): computation.gas_meter.return_gas(child_msg_gas) computation.stack.push(0) else: - if code_address: - code = computation.state_db.get_code(code_address) - else: - code = computation.state_db.get_code(to) + with computation.vm.state_db(read_only=True) as state_db: + if code_address: + code = state_db.get_code(code_address) + else: + code = state_db.get_code(to) child_msg_kwargs = { 'gas': child_msg_gas, @@ -123,7 +123,8 @@ def __call__(self, computation): class Call(BaseCall): def compute_msg_extra_gas(self, computation, gas, to, value): - account_exists = computation.state_db.account_exists(to) + with computation.vm.state_db(read_only=True) as state_db: + account_exists = state_db.account_exists(to) transfer_gas_fee = constants.GAS_CALLVALUE if value else 0 create_gas_fee = constants.GAS_NEWACCOUNT if not account_exists else 0 return transfer_gas_fee + create_gas_fee diff --git a/evm/logic/context.py b/evm/logic/context.py index 9402e5b58d..9da4e2a26f 100644 --- a/evm/logic/context.py +++ b/evm/logic/context.py @@ -13,7 +13,8 @@ def balance(computation): addr = force_bytes_to_address(computation.stack.pop(type_hint=constants.BYTES)) - balance = computation.state_db.get_balance(addr) + with computation.vm.state_db(read_only=True) as state_db: + balance = state_db.get_balance(addr) computation.stack.push(balance) @@ -107,7 +108,8 @@ def gasprice(computation): def extcodesize(computation): account = force_bytes_to_address(computation.stack.pop(type_hint=constants.BYTES)) - code_size = len(computation.state_db.get_code(account)) + with computation.vm.state_db(read_only=True) as state_db: + code_size = len(state_db.get_code(account)) computation.stack.push(code_size) @@ -130,7 +132,8 @@ def extcodecopy(computation): reason='EXTCODECOPY: word gas cost', ) - code = computation.state_db.get_code(account) + with computation.vm.state_db(read_only=True) as state_db: + code = state_db.get_code(account) code_bytes = code[code_start_position:code_start_position + size] padded_code_bytes = pad_right(code_bytes, size, b'\x00') diff --git a/evm/logic/storage.py b/evm/logic/storage.py index 6bf56d6909..6712036a6b 100644 --- a/evm/logic/storage.py +++ b/evm/logic/storage.py @@ -8,10 +8,11 @@ def sstore(computation): slot, value = computation.stack.pop(num_items=2, type_hint=constants.UINT256) - current_value = computation.state_db.get_storage( - address=computation.msg.storage_address, - slot=slot, - ) + with computation.vm.state_db(read_only=True) as state_db: + current_value = state_db.get_storage( + address=computation.msg.storage_address, + slot=slot, + ) is_currently_empty = not bool(current_value) is_going_to_be_empty = not bool(value) @@ -42,18 +43,20 @@ def sstore(computation): if gas_refund: computation.gas_meter.refund_gas(gas_refund) - computation.state_db.set_storage( - address=computation.msg.storage_address, - slot=slot, - value=value, - ) + with computation.vm.state_db() as state_db: + state_db.set_storage( + address=computation.msg.storage_address, + slot=slot, + value=value, + ) def sload(computation): slot = computation.stack.pop(type_hint=constants.UINT256) - value = computation.state_db.get_storage( - address=computation.msg.storage_address, - slot=slot, - ) + with computation.vm.state_db(read_only=True) as state_db: + value = state_db.get_storage( + address=computation.msg.storage_address, + slot=slot, + ) computation.stack.push(value) diff --git a/evm/logic/system.py b/evm/logic/system.py index 84dd0c8550..79e7328c6c 100644 --- a/evm/logic/system.py +++ b/evm/logic/system.py @@ -28,25 +28,24 @@ def suicide(computation): def suicide_eip150(computation): beneficiary = force_bytes_to_address(computation.stack.pop(type_hint=constants.BYTES)) - if not computation.state_db.account_exists(beneficiary): - computation.gas_meter.consume_gas( - constants.GAS_SUICIDE_NEWACCOUNT, reason=mnemonics.SUICIDE) + with computation.vm.state_db(read_only=True) as state_db: + if not state_db.account_exists(beneficiary): + computation.gas_meter.consume_gas( + constants.GAS_SUICIDE_NEWACCOUNT, reason=mnemonics.SUICIDE) _suicide(computation, beneficiary) def _suicide(computation, beneficiary): - local_balance = computation.state_db.get_balance(computation.msg.storage_address) - beneficiary_balance = computation.state_db.get_balance(beneficiary) - - # 1st: Transfer to beneficiary - computation.state_db.set_balance( - beneficiary, - local_balance + beneficiary_balance, - ) - # 2nd: Zero the balance of the address being deleted (must come after - # sending to beneficiary in case the contract named itself as the - # beneficiary. - computation.state_db.set_balance(computation.msg.storage_address, 0) + with computation.vm.state_db() as state_db: + local_balance = state_db.get_balance(computation.msg.storage_address) + beneficiary_balance = state_db.get_balance(beneficiary) + + # 1st: Transfer to beneficiary + state_db.set_balance(beneficiary, local_balance + beneficiary_balance) + # 2nd: Zero the balance of the address being deleted (must come after + # sending to beneficiary in case the contract named itself as the + # beneficiary. + state_db.set_balance(computation.msg.storage_address, 0) # 3rd: Register the account to be deleted computation.register_account_for_deletion(computation.msg.storage_address) @@ -66,8 +65,9 @@ def __call__(self, computation): computation.extend_memory(start_position, size) - insufficient_funds = computation.state_db.get_balance( - computation.msg.storage_address) < value + with computation.vm.state_db(read_only=True) as state_db: + insufficient_funds = state_db.get_balance( + computation.msg.storage_address) < value stack_too_deep = computation.msg.depth + 1 > constants.STACK_DEPTH_LIMIT if insufficient_funds or stack_too_deep: @@ -80,8 +80,8 @@ def __call__(self, computation): computation.gas_meter.gas_remaining) computation.gas_meter.consume_gas(create_msg_gas, reason="CREATE") - creation_nonce = computation.state_db.get_nonce( - computation.msg.storage_address) + with computation.vm.state_db(read_only=True) as state_db: + creation_nonce = state_db.get_nonce(computation.msg.storage_address) contract_address = generate_contract_address( computation.msg.storage_address, creation_nonce) diff --git a/evm/utils/fixture_tests.py b/evm/utils/fixture_tests.py index 25970142ff..1b96f4e431 100644 --- a/evm/utils/fixture_tests.py +++ b/evm/utils/fixture_tests.py @@ -372,7 +372,6 @@ def setup_state_db(desired_state, state_db): state_db.set_nonce(account, nonce) state_db.set_code(account, code) state_db.set_balance(account, balance) - return state_db def verify_state_db(expected_state, state_db): diff --git a/evm/vm/base.py b/evm/vm/base.py index 0599707462..ca669c0c79 100644 --- a/evm/vm/base.py +++ b/evm/vm/base.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +from contextlib import contextmanager import logging from evm.constants import ( @@ -56,19 +57,17 @@ def configure(cls, ) return type(name, (cls,), overrides) - _block = None - state_db = None - - @property - def block(self): - if self._block is None: - raise AttributeError("No block property set") - return self._block - - @block.setter - def block(self, value): - self._block = value - self.state_db = State(db=self.db, root_hash=value.header.state_root) + @contextmanager + def state_db(self, read_only=False): + state = State(db=self.db, root_hash=self.block.header.state_root) + yield state + if read_only: + # TODO: This is a bit of a hack; ideally we should raise an error whenever the + # callsite tries to call a State method that modifies it. + assert state.root_hash == self.block.header.state_root + elif self.block.header.state_root != state.root_hash: + self.logger.debug("Updating block's state_root to %s", state.root_hash) + self.block.header.state_root = state.root_hash # # Logging @@ -85,12 +84,7 @@ def apply_transaction(self, transaction): Apply the transaction to the vm in the current block. """ computation = self.execute_transaction(transaction) - # NOTE: mutation. Needed in order to update self.state_db, so we should be able to get rid - # of this once we fix https://github.com/pipermerriam/py-evm/issues/67 - self.block = self.block.add_transaction( - transaction=transaction, - computation=computation, - ) + self.block.add_transaction(transaction, computation) return computation def execute_transaction(self, transaction): @@ -151,12 +145,15 @@ def mine_block(self, *args, **kwargs): block = self.block block.mine(*args, **kwargs) - if block.number > 0: - block_reward = self.get_block_reward(block.number) + ( - len(block.uncles) * self.get_nephew_reward(block.number) - ) + if block.number == 0: + return block + + block_reward = self.get_block_reward(block.number) + ( + len(block.uncles) * self.get_nephew_reward(block.number) + ) - self.state_db.delta_balance(block.header.coinbase, block_reward) + with self.state_db() as state_db: + state_db.delta_balance(block.header.coinbase, block_reward) self.logger.debug( "BLOCK REWARD: %s -> %s", block_reward, @@ -167,17 +164,13 @@ def mine_block(self, *args, **kwargs): uncle_reward = BLOCK_REWARD * ( UNCLE_DEPTH_PENALTY_FACTOR + uncle.block_number - block.number ) // UNCLE_DEPTH_PENALTY_FACTOR - self.state_db.delta_balance(uncle.coinbase, uncle_reward) + state_db.delta_balance(uncle.coinbase, uncle_reward) self.logger.debug( "UNCLE REWARD REWARD: %s -> %s", uncle_reward, uncle.coinbase, ) - self.logger.debug('BEFORE ROOT: %s', block.header.state_root) - block.header.state_root = self.state_db.root_hash - self.logger.debug('STATE_ROOT: %s', block.header.state_root) - return block # @@ -264,7 +257,8 @@ def snapshot(self): TODO: This needs to do more than just snapshot the state_db but this is a start. """ - return self.state_db.snapshot() + with self.state_db(read_only=True) as state_db: + return state_db.snapshot() def revert(self, snapshot): """ @@ -272,7 +266,8 @@ def revert(self, snapshot): TODO: This needs to do more than just snapshot the state_db but this is a start. """ - return self.state_db.revert(snapshot) + with self.state_db() as state_db: + return state_db.revert(snapshot) # # Opcode API diff --git a/evm/vm/computation.py b/evm/vm/computation.py index 061e9a2970..30808b3b0a 100644 --- a/evm/vm/computation.py +++ b/evm/vm/computation.py @@ -95,13 +95,6 @@ def is_origin_computation(self): """ return self.msg.is_origin - @property - def state_db(self): - """ - Convenience access to the state database - """ - return self.vm.state_db - # # Execution # diff --git a/evm/vm/flavors/frontier/__init__.py b/evm/vm/flavors/frontier/__init__.py index 3475dc5c06..0a8a8c5409 100644 --- a/evm/vm/flavors/frontier/__init__.py +++ b/evm/vm/flavors/frontier/__init__.py @@ -55,28 +55,29 @@ def _execute_frontier_transaction(vm, transaction): vm.validate_transaction(transaction) gas_cost = transaction.gas * transaction.gas_price - sender_balance = vm.state_db.get_balance(transaction.sender) + with vm.state_db() as state_db: + sender_balance = state_db.get_balance(transaction.sender) - # Buy Gas - vm.state_db.set_balance(transaction.sender, sender_balance - gas_cost) + # Buy Gas + state_db.set_balance(transaction.sender, sender_balance - gas_cost) - # Increment Nonce - vm.state_db.increment_nonce(transaction.sender) + # Increment Nonce + state_db.increment_nonce(transaction.sender) - # Setup VM Message - message_gas = transaction.gas - transaction.intrensic_gas + # Setup VM Message + message_gas = transaction.gas - transaction.intrensic_gas - if transaction.to == constants.CREATE_CONTRACT_ADDRESS: - contract_address = generate_contract_address( - transaction.sender, - vm.state_db.get_nonce(transaction.sender) - 1, - ) - data = b'' - code = transaction.data - else: - contract_address = None - data = transaction.data - code = vm.state_db.get_code(transaction.to) + if transaction.to == constants.CREATE_CONTRACT_ADDRESS: + contract_address = generate_contract_address( + transaction.sender, + state_db.get_nonce(transaction.sender) - 1, + ) + data = b'' + code = transaction.data + else: + contract_address = None + data = transaction.data + code = state_db.get_code(transaction.to) if vm.logger: vm.logger.info( @@ -122,11 +123,12 @@ def _execute_frontier_transaction(vm, transaction): transaction_fee = transaction.gas * transaction.gas_price if vm.logger: vm.logger.debug('TRANSACTION FEE: %s', transaction_fee) - coinbase_balance = vm.state_db.get_balance(vm.block.header.coinbase) - vm.state_db.set_balance( - vm.block.header.coinbase, - coinbase_balance + transaction_fee, - ) + with vm.state_db() as state_db: + coinbase_balance = state_db.get_balance(vm.block.header.coinbase) + state_db.set_balance( + vm.block.header.coinbase, + coinbase_balance + transaction_fee, + ) else: # Suicide Refunds num_deletions = len(computation.get_accounts_for_deletion()) @@ -148,8 +150,9 @@ def _execute_frontier_transaction(vm, transaction): encode_hex(message.sender), ) - sender_balance = vm.state_db.get_balance(message.sender) - vm.state_db.set_balance(message.sender, sender_balance + gas_refund_amount) + with vm.state_db() as state_db: + sender_balance = state_db.get_balance(message.sender) + state_db.set_balance(message.sender, sender_balance + gas_refund_amount) # Miner Fees transaction_fee = (transaction.gas - gas_remaining - gas_refund) * transaction.gas_price @@ -159,11 +162,12 @@ def _execute_frontier_transaction(vm, transaction): transaction_fee, encode_hex(vm.block.header.coinbase), ) - coinbase_balance = vm.state_db.get_balance(vm.block.header.coinbase) - vm.state_db.set_balance( - vm.block.header.coinbase, - coinbase_balance + transaction_fee, - ) + with vm.state_db() as state_db: + coinbase_balance = state_db.get_balance(vm.block.header.coinbase) + state_db.set_balance( + vm.block.header.coinbase, + coinbase_balance + transaction_fee, + ) # Suicides for account, beneficiary in computation.get_accounts_for_deletion(): @@ -172,10 +176,9 @@ def _execute_frontier_transaction(vm, transaction): if vm.logger is not None: vm.logger.debug('DELETING ACCOUNT: %s', encode_hex(account)) - vm.state_db.set_balance(account, 0) - vm.state_db.delete_account(account) - - vm.block.header.state_root = vm.state_db.root_hash + with vm.state_db() as state_db: + state_db.set_balance(account, 0) + state_db.delete_account(account) return computation @@ -187,19 +190,20 @@ def _apply_frontier_message(vm, message): raise StackDepthLimit("Stack depth limit reached") if message.should_transfer_value and message.value: - sender_balance = vm.state_db.get_balance(message.sender) + with vm.state_db() as state_db: + sender_balance = state_db.get_balance(message.sender) - if sender_balance < message.value: - raise InsufficientFunds( - "Insufficient funds: {0} < {1}".format(sender_balance, message.value) - ) + if sender_balance < message.value: + raise InsufficientFunds( + "Insufficient funds: {0} < {1}".format(sender_balance, message.value) + ) - sender_balance -= message.value - vm.state_db.set_balance(message.sender, sender_balance) + sender_balance -= message.value + state_db.set_balance(message.sender, sender_balance) - recipient_balance = vm.state_db.get_balance(message.storage_address) - recipient_balance += message.value - vm.state_db.set_balance(message.storage_address, recipient_balance) + recipient_balance = state_db.get_balance(message.storage_address) + recipient_balance += message.value + state_db.set_balance(message.storage_address, recipient_balance) if vm.logger is not None: vm.logger.debug( @@ -209,8 +213,9 @@ def _apply_frontier_message(vm, message): encode_hex(message.storage_address), ) - if not vm.state_db.account_exists(message.storage_address): - vm.state_db.touch_account(message.storage_address) + with vm.state_db() as state_db: + if not state_db.account_exists(message.storage_address): + state_db.touch_account(message.storage_address) computation = vm.apply_computation(message) @@ -246,14 +251,15 @@ def _apply_frontier_computation(vm, message): def _apply_frontier_create_message(vm, message): - if vm.state_db.get_balance(message.storage_address) > 0: - vm.state_db.set_nonce(message.storage_address, 0) - vm.state_db.set_code(message.storage_address, b'') - # TODO: figure out whether the following line is correct. - vm.state_db.delete_storage(message.storage_address) + with vm.state_db() as state_db: + if state_db.get_balance(message.storage_address) > 0: + state_db.set_nonce(message.storage_address, 0) + state_db.set_code(message.storage_address, b'') + # TODO: figure out whether the following line is correct. + state_db.delete_storage(message.storage_address) - if message.sender != message.origin: - vm.state_db.increment_nonce(message.sender) + if message.sender != message.origin: + state_db.increment_nonce(message.sender) computation = vm.apply_message(message) @@ -278,7 +284,8 @@ def _apply_frontier_create_message(vm, message): encode_hex(message.storage_address), contract_code, ) - computation.state_db.set_code(message.storage_address, contract_code) + with vm.state_db() as state_db: + state_db.set_code(message.storage_address, contract_code) return computation diff --git a/evm/vm/flavors/frontier/blocks.py b/evm/vm/flavors/frontier/blocks.py index 3bc3782f1c..069f4360a6 100644 --- a/evm/vm/flavors/frontier/blocks.py +++ b/evm/vm/flavors/frontier/blocks.py @@ -278,7 +278,7 @@ def add_transaction(self, transaction, computation): gas_used = self.header.gas_used + tx_gas_used receipt = Receipt( - state_root=computation.state_db.root_hash, + state_root=self.header.state_root, gas_used=gas_used, logs=logs, ) @@ -295,7 +295,6 @@ def add_transaction(self, transaction, computation): self.bloom_filter |= receipt.bloom self.header.transaction_root = self.transaction_db.root_hash - self.header.state_root = computation.state_db.root_hash self.header.receipt_root = self.receipt_db.root_hash self.header.bloom = int(self.bloom_filter) self.header.gas_used = gas_used diff --git a/evm/vm/flavors/frontier/validation.py b/evm/vm/flavors/frontier/validation.py index 50ca56b8e1..6ce5b43170 100644 --- a/evm/vm/flavors/frontier/validation.py +++ b/evm/vm/flavors/frontier/validation.py @@ -5,7 +5,8 @@ def validate_frontier_transaction(vm, transaction): gas_cost = transaction.gas * transaction.gas_price - sender_balance = vm.state_db.get_balance(transaction.sender) + with vm.state_db(read_only=True) as state_db: + sender_balance = state_db.get_balance(transaction.sender) if sender_balance < gas_cost: raise ValidationError( @@ -20,5 +21,6 @@ def validate_frontier_transaction(vm, transaction): if vm.block.header.gas_used + transaction.gas > vm.block.header.gas_limit: raise ValidationError("Transaction exceeds gas limit") - if vm.state_db.get_nonce(transaction.sender) != transaction.nonce: - raise ValidationError("Invalid transaction nonce") + with vm.state_db(read_only=True) as state_db: + if state_db.get_nonce(transaction.sender) != transaction.nonce: + raise ValidationError("Invalid transaction nonce") diff --git a/evm/vm/flavors/homestead/__init__.py b/evm/vm/flavors/homestead/__init__.py index 258d4ad6d2..159708ba6d 100644 --- a/evm/vm/flavors/homestead/__init__.py +++ b/evm/vm/flavors/homestead/__init__.py @@ -19,13 +19,14 @@ def _apply_homestead_create_message(vm, message): - if vm.state_db.account_exists(message.storage_address): - vm.state_db.set_nonce(message.storage_address, 0) - vm.state_db.set_code(message.storage_address, b'') - vm.state_db.delete_storage(message.storage_address) + with vm.state_db() as state_db: + if state_db.account_exists(message.storage_address): + state_db.set_nonce(message.storage_address, 0) + state_db.set_code(message.storage_address, b'') + state_db.delete_storage(message.storage_address) - if message.sender != message.origin: - vm.state_db.increment_nonce(message.sender) + if message.sender != message.origin: + state_db.increment_nonce(message.sender) snapshot = vm.snapshot() @@ -55,7 +56,8 @@ def _apply_homestead_create_message(vm, message): encode_hex(message.storage_address), contract_code, ) - computation.state_db.set_code(message.storage_address, contract_code) + with vm.state_db() as state_db: + state_db.set_code(message.storage_address, contract_code) return computation diff --git a/evm/vm/flavors/homestead/headers.py b/evm/vm/flavors/homestead/headers.py index c1ed98f472..73f8415f29 100644 --- a/evm/vm/flavors/homestead/headers.py +++ b/evm/vm/flavors/homestead/headers.py @@ -68,11 +68,12 @@ def configure_homestead_header(vm, **header_params): # there we'd need to manually instantiate the State and update # header.state_root after we're done. if vm.support_dao_fork and header.block_number == vm.dao_fork_block_number: - for account in dao_drain_list: - account = decode_hex(account) - balance = vm.state_db.get_balance(account) - vm.state_db.delta_balance(dao_refund_contract, balance) - vm.state_db.set_balance(account, 0) + with vm.state_db() as state_db: + for account in dao_drain_list: + account = decode_hex(account) + balance = state_db.get_balance(account) + state_db.delta_balance(dao_refund_contract, balance) + state_db.set_balance(account, 0) return header diff --git a/tests/core/chain-object/test_chain.py b/tests/core/chain-object/test_chain.py index 8c0142c892..1284c7ae35 100644 --- a/tests/core/chain-object/test_chain.py +++ b/tests/core/chain-object/test_chain.py @@ -20,11 +20,12 @@ def test_import_block_validation(chain): # noqa: F811 tx = imported_block.transactions[0] assert tx.value == 10 vm = chain.get_vm() - assert vm.state_db.get_balance( - decode_hex("095e7baea6a6c7c4c2dfeb977efac326af552d87")) == tx.value - tx_gas = tx.gas_price * constants.GAS_TX - assert vm.state_db.get_balance(chain.funded_address) == ( - chain.funded_address_initial_balance - tx.value - tx_gas) + with vm.state_db(read_only=True) as state_db: + assert state_db.get_balance( + decode_hex("095e7baea6a6c7c4c2dfeb977efac326af552d87")) == tx.value + tx_gas = tx.gas_price * constants.GAS_TX + assert state_db.get_balance(chain.funded_address) == ( + chain.funded_address_initial_balance - tx.value - tx_gas) def test_import_block(chain_without_block_validation): # noqa: F811 diff --git a/tests/core/helpers.py b/tests/core/helpers.py index d31a4a1461..bcbe510514 100644 --- a/tests/core/helpers.py +++ b/tests/core/helpers.py @@ -4,7 +4,8 @@ def new_transaction(vm, from_, to, amount, private_key, gas_price=10, gas=100000 The transaction will be signed with the given private key. """ - nonce = vm.state_db.get_nonce(from_) + with vm.state_db(read_only=True) as state_db: + nonce = state_db.get_nonce(from_) tx = vm.create_unsigned_transaction( nonce=nonce, gas_price=gas_price, gas=gas, to=to, value=amount, data=b'') return tx.as_signed_transaction(private_key) diff --git a/tests/core/vm/test_vm.py b/tests/core/vm/test_vm.py index f08c130334..39047d4169 100644 --- a/tests/core/vm/test_vm.py +++ b/tests/core/vm/test_vm.py @@ -1,3 +1,5 @@ +import pytest + from eth_utils import decode_hex from evm import constants @@ -17,21 +19,21 @@ def test_apply_transaction(chain_without_block_validation): # noqa: F811 computation = vm.apply_transaction(tx) assert computation.error is None tx_gas = tx.gas_price * constants.GAS_TX - assert vm.state_db.get_balance(from_) == ( - chain.funded_address_initial_balance - amount - tx_gas) - assert vm.state_db.get_balance(recipient) == amount + with vm.state_db(read_only=True) as state_db: + assert state_db.get_balance(from_) == ( + chain.funded_address_initial_balance - amount - tx_gas) + assert state_db.get_balance(recipient) == amount block = vm.block assert block.transactions[tx_idx] == tx assert block.header.gas_used == constants.GAS_TX - assert block.header.state_root == computation.state_db.root_hash def test_mine_block(chain_without_block_validation): # noqa: F811 chain = chain_without_block_validation # noqa: F811 vm = chain.get_vm() block = vm.mine_block() - assert vm.state_db.get_balance(block.header.coinbase) == constants.BLOCK_REWARD - assert block.header.state_root == vm.state_db.root_hash + with vm.state_db(read_only=True) as state_db: + assert state_db.get_balance(block.header.coinbase) == constants.BLOCK_REWARD def test_import_block(chain_without_block_validation): # noqa: F811 @@ -46,3 +48,21 @@ def test_import_block(chain_without_block_validation): # noqa: F811 parent_vm = chain.get_parent_chain(vm.block).get_vm() block = parent_vm.import_block(vm.block) assert block.transactions == [tx] + + +def test_state_db(chain_without_block_validation): # noqa: F811 + vm = chain_without_block_validation.get_vm() + address = decode_hex('0xa94f5374fce5edbc8e2a8697c15331677e6ebf0c') + initial_state_root = vm.block.header.state_root + + with vm.state_db(read_only=True) as state_db: + state_db.get_balance(address) + assert vm.block.header.state_root == initial_state_root + + with vm.state_db() as state_db: + state_db.set_balance(address, 10) + assert vm.block.header.state_root != initial_state_root + + with pytest.raises(AssertionError): + with vm.state_db(read_only=True) as state_db: + state_db.set_balance(address, 0) diff --git a/tests/json-fixtures/test_blockchain.py b/tests/json-fixtures/test_blockchain.py index bf9ec903d9..351bfbe80a 100644 --- a/tests/json-fixtures/test_blockchain.py +++ b/tests/json-fixtures/test_blockchain.py @@ -189,4 +189,5 @@ def test_blockchain_fixtures(fixture_name, fixture): latest_block_hash = chain.get_canonical_block_by_number(chain.get_block().number - 1).hash assert latest_block_hash == fixture['lastblockhash'] - verify_state_db(fixture['postState'], chain.get_state_db()) + with chain.get_vm().state_db(read_only=True) as state_db: + verify_state_db(fixture['postState'], state_db) diff --git a/tests/json-fixtures/test_state.py b/tests/json-fixtures/test_state.py index 2e6a2a1e48..faa0ca542f 100644 --- a/tests/json-fixtures/test_state.py +++ b/tests/json-fixtures/test_state.py @@ -163,8 +163,8 @@ def test_state_fixtures(fixture_name, fixture): db = get_db_backend() chain = ChainForTesting(db=db, header=header) - state_db = setup_state_db(fixture['pre'], chain.get_state_db()) - chain.header.state_root = state_db.root_hash + with chain.get_vm().state_db() as state_db: + setup_state_db(fixture['pre'], state_db) unsigned_transaction = chain.create_unsigned_transaction( nonce=fixture['transaction']['nonce'], @@ -202,4 +202,5 @@ def test_state_fixtures(fixture_name, fixture): else: assert computation.output == expected_output - verify_state_db(fixture['post'], chain.get_state_db()) + with chain.get_vm().state_db(read_only=True) as state_db: + verify_state_db(fixture['post'], state_db) diff --git a/tests/json-fixtures/test_vm.py b/tests/json-fixtures/test_vm.py index 7e87dc99dd..2ab282bc79 100644 --- a/tests/json-fixtures/test_vm.py +++ b/tests/json-fixtures/test_vm.py @@ -142,8 +142,10 @@ def test_vm_fixtures(fixture_name, fixture): timestamp=fixture['env']['currentTimestamp'], ) chain = ChainForTesting(db=db, header=header) - state_db = setup_state_db(fixture['pre'], chain.get_state_db()) - chain.header.state_root = state_db.root_hash + vm = chain.get_vm() + with vm.state_db() as state_db: + setup_state_db(fixture['pre'], state_db) + code = state_db.get_code(fixture['exec']['address']) message = Message( origin=fixture['exec']['origin'], @@ -151,11 +153,10 @@ def test_vm_fixtures(fixture_name, fixture): sender=fixture['exec']['caller'], value=fixture['exec']['value'], data=fixture['exec']['data'], - code=chain.get_state_db().get_code(fixture['exec']['address']), + code=code, gas=fixture['exec']['gas'], gas_price=fixture['exec']['gasPrice'], ) - vm = chain.get_vm() computation = vm.apply_computation(message) if 'post' in fixture: @@ -207,4 +208,5 @@ def test_vm_fixtures(fixture_name, fixture): assert isinstance(computation.error, VMError) post_state = fixture['pre'] - verify_state_db(post_state, vm.state_db) + with vm.state_db(read_only=True) as state_db: + verify_state_db(post_state, state_db)