From 3628e375a6a357e8774bf788343a658b1e60240d Mon Sep 17 00:00:00 2001 From: Sam Wilson Date: Wed, 1 May 2024 09:38:10 -0400 Subject: [PATCH] Remove ensure from prague --- src/ethereum/prague/fork.py | 109 +++++++++--------- src/ethereum/prague/state.py | 4 +- src/ethereum/prague/trie.py | 5 +- .../prague/vm/instructions/environment.py | 9 +- src/ethereum/prague/vm/instructions/log.py | 4 +- src/ethereum/prague/vm/instructions/stack.py | 11 +- .../prague/vm/instructions/storage.py | 16 +-- src/ethereum/prague/vm/instructions/system.py | 19 ++- src/ethereum/prague/vm/interpreter.py | 7 +- .../vm/precompiled_contracts/alt_bn128.py | 13 +-- .../vm/precompiled_contracts/blake2f.py | 11 +- .../precompiled_contracts/point_evaluation.py | 17 +-- 12 files changed, 96 insertions(+), 129 deletions(-) diff --git a/src/ethereum/prague/fork.py b/src/ethereum/prague/fork.py index 450f4a32e7..29ce1958e3 100644 --- a/src/ethereum/prague/fork.py +++ b/src/ethereum/prague/fork.py @@ -19,7 +19,6 @@ from ethereum.crypto.elliptic_curve import SECP256K1N, secp256k1_recover from ethereum.crypto.hash import Hash32, keccak256 from ethereum.exceptions import InvalidBlock -from ethereum.utils.ensure import ensure from .. import rlp from ..base_types import U64, U256, Bytes, Uint @@ -178,10 +177,12 @@ def state_transition(chain: BlockChain, block: Block) -> None: """ parent_header = chain.blocks[-1].header excess_blob_gas = calculate_excess_blob_gas(parent_header) - ensure(block.header.excess_blob_gas == excess_blob_gas, InvalidBlock) + if block.header.excess_blob_gas != excess_blob_gas: + raise InvalidBlock validate_header(block.header, parent_header) - ensure(block.ommers == (), InvalidBlock) + if block.ommers != (): + raise InvalidBlock apply_body_output = apply_body( chain.state, get_last_256_block_hashes(chain), @@ -197,31 +198,20 @@ def state_transition(chain: BlockChain, block: Block) -> None: block.header.parent_beacon_block_root, excess_blob_gas, ) - ensure( - apply_body_output.block_gas_used == block.header.gas_used, InvalidBlock - ) - ensure( - apply_body_output.transactions_root == block.header.transactions_root, - InvalidBlock, - ) - ensure( - apply_body_output.state_root == block.header.state_root, InvalidBlock - ) - ensure( - apply_body_output.receipt_root == block.header.receipt_root, - InvalidBlock, - ) - ensure( - apply_body_output.block_logs_bloom == block.header.bloom, InvalidBlock - ) - ensure( - apply_body_output.withdrawals_root == block.header.withdrawals_root, - InvalidBlock, - ) - ensure( - apply_body_output.blob_gas_used == block.header.blob_gas_used, - InvalidBlock, - ) + if apply_body_output.block_gas_used != block.header.gas_used: + raise InvalidBlock + if apply_body_output.transactions_root != block.header.transactions_root: + raise InvalidBlock + if apply_body_output.state_root != block.header.state_root: + raise InvalidBlock + if apply_body_output.receipt_root != block.header.receipt_root: + raise InvalidBlock + if apply_body_output.block_logs_bloom != block.header.bloom: + raise InvalidBlock + if apply_body_output.withdrawals_root != block.header.withdrawals_root: + raise InvalidBlock + if apply_body_output.blob_gas_used != block.header.blob_gas_used: + raise InvalidBlock chain.blocks.append(block) if len(chain.blocks) > 255: @@ -256,11 +246,8 @@ def calculate_base_fee_per_gas( Base fee per gas for the block. """ parent_gas_target = parent_gas_limit // ELASTICITY_MULTIPLIER - - ensure( - check_gas_limit(block_gas_limit, parent_gas_limit), - InvalidBlock, - ) + if not check_gas_limit(block_gas_limit, parent_gas_limit): + raise InvalidBlock if parent_gas_used == parent_gas_target: expected_base_fee_per_gas = parent_base_fee_per_gas @@ -313,7 +300,8 @@ def validate_header(header: Header, parent_header: Header) -> None: parent_header : Parent Header of the header to check for correctness """ - ensure(header.gas_used <= header.gas_limit, InvalidBlock) + if header.gas_used > header.gas_limit: + raise InvalidBlock expected_base_fee_per_gas = calculate_base_fee_per_gas( header.gas_limit, @@ -321,19 +309,24 @@ def validate_header(header: Header, parent_header: Header) -> None: parent_header.gas_used, parent_header.base_fee_per_gas, ) - - ensure(expected_base_fee_per_gas == header.base_fee_per_gas, InvalidBlock) - - ensure(header.timestamp > parent_header.timestamp, InvalidBlock) - ensure(header.number == parent_header.number + 1, InvalidBlock) - ensure(len(header.extra_data) <= 32, InvalidBlock) - - ensure(header.difficulty == 0, InvalidBlock) - ensure(header.nonce == b"\x00\x00\x00\x00\x00\x00\x00\x00", InvalidBlock) - ensure(header.ommers_hash == EMPTY_OMMER_HASH, InvalidBlock) + if expected_base_fee_per_gas != header.base_fee_per_gas: + raise InvalidBlock + if header.timestamp <= parent_header.timestamp: + raise InvalidBlock + if header.number != parent_header.number + 1: + raise InvalidBlock + if len(header.extra_data) > 32: + raise InvalidBlock + if header.difficulty != 0: + raise InvalidBlock + if header.nonce != b"\x00\x00\x00\x00\x00\x00\x00\x00": + raise InvalidBlock + if header.ommers_hash != EMPTY_OMMER_HASH: + raise InvalidBlock block_parent_hash = keccak256(rlp.encode(parent_header)) - ensure(header.parent_hash == block_parent_hash, InvalidBlock) + if header.parent_hash != block_parent_hash: + raise InvalidBlock def check_transaction( @@ -423,10 +416,12 @@ def check_transaction( blob_versioned_hashes = tx.blob_versioned_hashes else: blob_versioned_hashes = () - - ensure(sender_account.nonce == tx.nonce, InvalidBlock) - ensure(sender_account.balance >= max_gas_fee + tx.value, InvalidBlock) - ensure(sender_account.code == bytearray(), InvalidBlock) + if sender_account.nonce != tx.nonce: + raise InvalidBlock + if sender_account.balance < max_gas_fee + tx.value: + raise InvalidBlock + if sender_account.code != bytearray(): + raise InvalidBlock return sender, effective_gas_price, blob_versioned_hashes @@ -681,8 +676,8 @@ def apply_body( block_logs += logs blob_gas_used += calculate_total_blob_gas(tx) - - ensure(blob_gas_used <= MAX_BLOB_GAS_PER_BLOCK, InvalidBlock) + if blob_gas_used > MAX_BLOB_GAS_PER_BLOCK: + raise InvalidBlock block_gas_used = block_gas_limit - gas_available block_logs_bloom = logs_bloom(block_logs) @@ -884,9 +879,10 @@ def recover_sender(chain_id: U64, tx: Transaction) -> Address: The address of the account that signed the transaction. """ r, s = tx.r, tx.s - - ensure(0 < r and r < SECP256K1N, InvalidBlock) - ensure(0 < s and s <= SECP256K1N // 2, InvalidBlock) + if 0 >= r or r >= SECP256K1N: + raise InvalidBlock + if 0 >= s or s > SECP256K1N // 2: + raise InvalidBlock if isinstance(tx, LegacyTransaction): v = tx.v @@ -895,9 +891,8 @@ def recover_sender(chain_id: U64, tx: Transaction) -> Address: r, s, v - 27, signing_hash_pre155(tx) ) else: - ensure( - v == 35 + chain_id * 2 or v == 36 + chain_id * 2, InvalidBlock - ) + if v != 35 + chain_id * 2 and v != 36 + chain_id * 2: + raise InvalidBlock public_key = secp256k1_recover( r, s, v - 35 - chain_id * 2, signing_hash_155(tx, chain_id) ) diff --git a/src/ethereum/prague/state.py b/src/ethereum/prague/state.py index 1d0fd8f476..58280b3cf9 100644 --- a/src/ethereum/prague/state.py +++ b/src/ethereum/prague/state.py @@ -20,7 +20,6 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple from ethereum.base_types import U256, Bytes, Uint, modify -from ethereum.utils.ensure import ensure from .blocks import Withdrawal from .fork_types import EMPTY_ACCOUNT, Account, Address, Root @@ -502,7 +501,8 @@ def move_ether( """ def reduce_sender_balance(sender: Account) -> None: - ensure(sender.balance >= amount, AssertionError) + if sender.balance < amount: + raise AssertionError sender.balance -= amount def increase_recipient_balance(recipient: Account) -> None: diff --git a/src/ethereum/prague/trie.py b/src/ethereum/prague/trie.py index 97fc9983d8..120aca5a26 100644 --- a/src/ethereum/prague/trie.py +++ b/src/ethereum/prague/trie.py @@ -31,7 +31,6 @@ from ethereum.cancun import trie as previous_trie from ethereum.crypto.hash import keccak256 -from ethereum.utils.ensure import ensure from ethereum.utils.hexadecimal import hex_to_bytes from .. import rlp @@ -349,8 +348,8 @@ def _prepare_trie( encoded_value = encode_node(value, get_storage_root(address)) else: encoded_value = encode_node(value) - # Empty values are represented by their absence - ensure(encoded_value != b"", AssertionError) + if encoded_value == b"": + raise AssertionError key: Bytes if trie.secured: # "secure" tries hash keys once before construction diff --git a/src/ethereum/prague/vm/instructions/environment.py b/src/ethereum/prague/vm/instructions/environment.py index ca3caca6b4..e1b135a681 100644 --- a/src/ethereum/prague/vm/instructions/environment.py +++ b/src/ethereum/prague/vm/instructions/environment.py @@ -14,7 +14,6 @@ from ethereum.base_types import U256, Bytes32, Uint from ethereum.crypto.hash import keccak256 -from ethereum.utils.ensure import ensure from ethereum.utils.numeric import ceil32 from ...fork_types import EMPTY_ACCOUNT @@ -441,12 +440,8 @@ def returndatacopy(evm: Evm) -> None: evm.memory, [(memory_start_index, size)] ) charge_gas(evm, GAS_VERY_LOW + copy_gas_cost + extend_memory.cost) - - # OPERATION - ensure( - Uint(return_data_start_position) + Uint(size) <= len(evm.return_data), - OutOfBoundsRead, - ) + if Uint(return_data_start_position) + Uint(size) > len(evm.return_data): + raise OutOfBoundsRead evm.memory += b"\x00" * extend_memory.expand_by value = evm.return_data[ diff --git a/src/ethereum/prague/vm/instructions/log.py b/src/ethereum/prague/vm/instructions/log.py index 43c0fbfbd7..ced8461e4a 100644 --- a/src/ethereum/prague/vm/instructions/log.py +++ b/src/ethereum/prague/vm/instructions/log.py @@ -14,7 +14,6 @@ from functools import partial from ethereum.base_types import U256 -from ethereum.utils.ensure import ensure from ...blocks import Log from .. import Evm @@ -68,7 +67,8 @@ def log_n(evm: Evm, num_topics: U256) -> None: # OPERATION evm.memory += b"\x00" * extend_memory.expand_by - ensure(not evm.message.is_static, WriteInStaticContext) + if evm.message.is_static: + raise WriteInStaticContext log_entry = Log( address=evm.message.current_target, topics=tuple(topics), diff --git a/src/ethereum/prague/vm/instructions/stack.py b/src/ethereum/prague/vm/instructions/stack.py index ca1084d1f4..d3c3e92723 100644 --- a/src/ethereum/prague/vm/instructions/stack.py +++ b/src/ethereum/prague/vm/instructions/stack.py @@ -15,7 +15,6 @@ from functools import partial from ethereum.base_types import U256 -from ethereum.utils.ensure import ensure from .. import Evm, stack from ..exceptions import StackUnderflowError @@ -98,9 +97,8 @@ def dup_n(evm: Evm, item_number: int) -> None: # GAS charge_gas(evm, GAS_VERY_LOW) - - # OPERATION - ensure(item_number < len(evm.stack), StackUnderflowError) + if item_number >= len(evm.stack): + raise StackUnderflowError data_to_duplicate = evm.stack[len(evm.stack) - 1 - item_number] stack.push(evm.stack, data_to_duplicate) @@ -131,9 +129,8 @@ def swap_n(evm: Evm, item_number: int) -> None: # GAS charge_gas(evm, GAS_VERY_LOW) - - # OPERATION - ensure(item_number < len(evm.stack), StackUnderflowError) + if item_number >= len(evm.stack): + raise StackUnderflowError evm.stack[-1], evm.stack[-1 - item_number] = ( evm.stack[-1 - item_number], evm.stack[-1], diff --git a/src/ethereum/prague/vm/instructions/storage.py b/src/ethereum/prague/vm/instructions/storage.py index 4c04acc9c7..15c3f49e88 100644 --- a/src/ethereum/prague/vm/instructions/storage.py +++ b/src/ethereum/prague/vm/instructions/storage.py @@ -12,7 +12,6 @@ Implementations of the EVM storage related instructions. """ from ethereum.base_types import Uint -from ethereum.utils.ensure import ensure from ...state import ( get_storage, @@ -78,9 +77,8 @@ def sstore(evm: Evm) -> None: # STACK key = pop(evm.stack).to_be_bytes32() new_value = pop(evm.stack) - - # GAS - ensure(evm.gas_left > GAS_CALL_STIPEND, OutOfGasError) + if evm.gas_left <= GAS_CALL_STIPEND: + raise OutOfGasError original_value = get_storage_original( evm.env.state, evm.message.current_target, key @@ -123,9 +121,8 @@ def sstore(evm: Evm) -> None: ) charge_gas(evm, gas_cost) - - # OPERATION - ensure(not evm.message.is_static, WriteInStaticContext) + if evm.message.is_static: + raise WriteInStaticContext set_storage(evm.env.state, evm.message.current_target, key, new_value) # PROGRAM COUNTER @@ -171,9 +168,8 @@ def tstore(evm: Evm) -> None: # GAS charge_gas(evm, GAS_WARM_ACCESS) - - # OPERATION - ensure(not evm.message.is_static, WriteInStaticContext) + if evm.message.is_static: + raise WriteInStaticContext set_transient_storage( evm.env.transient_storage, evm.message.current_target, key, new_value ) diff --git a/src/ethereum/prague/vm/instructions/system.py b/src/ethereum/prague/vm/instructions/system.py index 07f3ab043f..fa0e181775 100644 --- a/src/ethereum/prague/vm/instructions/system.py +++ b/src/ethereum/prague/vm/instructions/system.py @@ -12,7 +12,6 @@ Implementations of the EVM system related instructions. """ from ethereum.base_types import U256, Bytes0, Uint -from ethereum.utils.ensure import ensure from ethereum.utils.numeric import ceil32 from ...fork_types import Address @@ -79,15 +78,15 @@ def generic_create( call_data = memory_read_bytes( evm.memory, memory_start_position, memory_size ) - - ensure(len(call_data) <= 2 * MAX_CODE_SIZE, OutOfGasError) + if len(call_data) > 2 * MAX_CODE_SIZE: + raise OutOfGasError evm.accessed_addresses.add(contract_address) create_message_gas = max_message_call_gas(Uint(evm.gas_left)) evm.gas_left -= create_message_gas - - ensure(not evm.message.is_static, WriteInStaticContext) + if evm.message.is_static: + raise WriteInStaticContext evm.return_data = b"" sender_address = evm.message.current_target @@ -376,9 +375,8 @@ def call(evm: Evm) -> None: access_gas_cost + create_gas_cost + transfer_gas_cost, ) charge_gas(evm, message_call_gas.cost + extend_memory.cost) - - # OPERATION - ensure(not evm.message.is_static or value == U256(0), WriteInStaticContext) + if evm.message.is_static and value != U256(0): + raise WriteInStaticContext evm.memory += b"\x00" * extend_memory.expand_by sender_balance = get_account( evm.env.state, evm.message.current_target @@ -506,9 +504,8 @@ def selfdestruct(evm: Evm) -> None: gas_cost += GAS_SELF_DESTRUCT_NEW_ACCOUNT charge_gas(evm, gas_cost) - - # OPERATION - ensure(not evm.message.is_static, WriteInStaticContext) + if evm.message.is_static: + raise WriteInStaticContext originator = evm.message.current_target originator_balance = get_account(evm.env.state, originator).balance diff --git a/src/ethereum/prague/vm/interpreter.py b/src/ethereum/prague/vm/interpreter.py index 0207669793..b85965bf40 100644 --- a/src/ethereum/prague/vm/interpreter.py +++ b/src/ethereum/prague/vm/interpreter.py @@ -25,7 +25,6 @@ TransactionEnd, evm_trace, ) -from ethereum.utils.ensure import ensure from ..blocks import Log from ..fork_types import Address @@ -184,9 +183,11 @@ def process_create_message(message: Message, env: Environment) -> Evm: contract_code_gas = len(contract_code) * GAS_CODE_DEPOSIT try: if len(contract_code) > 0: - ensure(contract_code[0] != 0xEF, InvalidContractPrefix) + if contract_code[0] == 0xEF: + raise InvalidContractPrefix charge_gas(evm, contract_code_gas) - ensure(len(contract_code) <= MAX_CODE_SIZE, OutOfGasError) + if len(contract_code) > MAX_CODE_SIZE: + raise OutOfGasError except ExceptionalHalt as error: rollback_transaction(env.state, env.transient_storage) evm.gas_left = Uint(0) diff --git a/src/ethereum/prague/vm/precompiled_contracts/alt_bn128.py b/src/ethereum/prague/vm/precompiled_contracts/alt_bn128.py index 70a9e51b63..4181bb90a9 100644 --- a/src/ethereum/prague/vm/precompiled_contracts/alt_bn128.py +++ b/src/ethereum/prague/vm/precompiled_contracts/alt_bn128.py @@ -22,7 +22,6 @@ BNP2, pairing, ) -from ethereum.utils.ensure import ensure from ...vm import Evm from ...vm.gas import charge_gas @@ -139,14 +138,10 @@ def alt_bn128_pairing_check(evm: Evm) -> None: ) except ValueError: raise OutOfGasError() - ensure( - p.mul_by(ALT_BN128_CURVE_ORDER) == BNP.point_at_infinity(), - OutOfGasError, - ) - ensure( - q.mul_by(ALT_BN128_CURVE_ORDER) == BNP2.point_at_infinity(), - OutOfGasError, - ) + if p.mul_by(ALT_BN128_CURVE_ORDER) != BNP.point_at_infinity(): + raise OutOfGasError + if q.mul_by(ALT_BN128_CURVE_ORDER) != BNP2.point_at_infinity(): + raise OutOfGasError if p != BNP.point_at_infinity() and q != BNP2.point_at_infinity(): result = result * pairing(q, p) diff --git a/src/ethereum/prague/vm/precompiled_contracts/blake2f.py b/src/ethereum/prague/vm/precompiled_contracts/blake2f.py index 6af10909f0..0d86ba6e85 100644 --- a/src/ethereum/prague/vm/precompiled_contracts/blake2f.py +++ b/src/ethereum/prague/vm/precompiled_contracts/blake2f.py @@ -12,7 +12,6 @@ Implementation of the `Blake2` precompiled contract. """ from ethereum.crypto.blake2 import Blake2b -from ethereum.utils.ensure import ensure from ...vm import Evm from ...vm.gas import GAS_BLAKE2_PER_ROUND, charge_gas @@ -29,16 +28,14 @@ def blake2f(evm: Evm) -> None: The current EVM frame. """ data = evm.message.data - - # GAS - ensure(len(data) == 213, InvalidParameter) + if len(data) != 213: + raise InvalidParameter blake2b = Blake2b() rounds, h, m, t_0, t_1, f = blake2b.get_blake2_parameters(data) charge_gas(evm, GAS_BLAKE2_PER_ROUND * rounds) - - # OPERATION - ensure(f in [0, 1], InvalidParameter) + if f not in [0, 1]: + raise InvalidParameter evm.output = blake2b.compress(rounds, h, m, t_0, t_1, f) diff --git a/src/ethereum/prague/vm/precompiled_contracts/point_evaluation.py b/src/ethereum/prague/vm/precompiled_contracts/point_evaluation.py index 841056287f..d02fb83032 100644 --- a/src/ethereum/prague/vm/precompiled_contracts/point_evaluation.py +++ b/src/ethereum/prague/vm/precompiled_contracts/point_evaluation.py @@ -18,7 +18,6 @@ ) from ethereum.base_types import U256, Bytes -from ethereum.utils.ensure import ensure from ...vm import Evm from ...vm.exceptions import KZGProofError @@ -41,8 +40,8 @@ def point_evaluation(evm: Evm) -> None: """ data = evm.message.data - - ensure(len(data) == 192, KZGProofError) + if len(data) != 192: + raise KZGProofError versioned_hash = data[:32] z = data[32:64] @@ -52,13 +51,8 @@ def point_evaluation(evm: Evm) -> None: # GAS charge_gas(evm, GAS_POINT_EVALUATION) - - # OPERATION - # Verify commitment matches versioned_hash - ensure( - kzg_commitment_to_versioned_hash(commitment) == versioned_hash, - KZGProofError, - ) + if kzg_commitment_to_versioned_hash(commitment) != versioned_hash: + raise KZGProofError # Verify KZG proof with z and y in big endian format try: @@ -66,7 +60,8 @@ def point_evaluation(evm: Evm) -> None: except Exception as e: raise KZGProofError from e - ensure(kzg_proof_verification, KZGProofError) + if not kzg_proof_verification: + raise KZGProofError # Return FIELD_ELEMENTS_PER_BLOB and BLS_MODULUS as padded # 32 byte big endian values