diff --git a/src/ethereum/forks/amsterdam/block_access_lists/__init__.py b/src/ethereum/forks/amsterdam/block_access_lists/__init__.py index 294a83ecae..856ab832bc 100644 --- a/src/ethereum/forks/amsterdam/block_access_lists/__init__.py +++ b/src/ethereum/forks/amsterdam/block_access_lists/__init__.py @@ -19,7 +19,10 @@ ) from .tracker import ( StateChangeTracker, - set_transaction_index, + begin_call_frame, + commit_call_frame, + rollback_call_frame, + set_block_access_index, track_address_access, track_balance_change, track_code_change, @@ -37,9 +40,12 @@ "add_storage_read", "add_storage_write", "add_touched_account", + "begin_call_frame", "build_block_access_list", + "commit_call_frame", "compute_block_access_list_hash", - "set_transaction_index", + "rollback_call_frame", + "set_block_access_index", "rlp_encode_block_access_list", "track_address_access", "track_balance_change", diff --git a/src/ethereum/forks/amsterdam/block_access_lists/tracker.py b/src/ethereum/forks/amsterdam/block_access_lists/tracker.py index 7fe8735deb..1ad068a604 100644 --- a/src/ethereum/forks/amsterdam/block_access_lists/tracker.py +++ b/src/ethereum/forks/amsterdam/block_access_lists/tracker.py @@ -16,7 +16,7 @@ """ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, List, Set, Tuple from ethereum_types.bytes import Bytes, Bytes32 from ethereum_types.numeric import U64, U256, Uint @@ -37,6 +37,39 @@ from ..state import State # noqa: F401 +@dataclass +class CallFrameSnapshot: + """ + Snapshot of block access list state for a single call frame. + + Used to track changes within a call frame to enable proper handling + of reverts as specified in EIP-7928. + """ + + touched_addresses: Set[Address] = field(default_factory=set) + """Addresses touched during this call frame.""" + + storage_writes: Dict[Tuple[Address, Bytes32], U256] = field( + default_factory=dict + ) + """Storage writes made during this call frame.""" + + balance_changes: Set[Tuple[Address, BlockAccessIndex, U256]] = field( + default_factory=set + ) + """Balance changes made during this call frame.""" + + nonce_changes: Set[Tuple[Address, BlockAccessIndex, U64]] = field( + default_factory=set + ) + """Nonce changes made during this call frame.""" + + code_changes: Set[Tuple[Address, BlockAccessIndex, Bytes]] = field( + default_factory=set + ) + """Code changes made during this call frame.""" + + @dataclass class StateChangeTracker: """ @@ -70,16 +103,25 @@ class StateChangeTracker: 1..n for transactions, n+1 for post-execution). """ + call_frame_snapshots: List[CallFrameSnapshot] = field(default_factory=list) + """ + Stack of snapshots for nested call frames to handle reverts properly. + """ -def set_transaction_index( + +def set_block_access_index( tracker: StateChangeTracker, block_access_index: Uint ) -> None: """ Set the current block access index for tracking changes. Must be called before processing each transaction/system contract - to ensure changes - are associated with the correct block access index. + to ensure changes are associated with the correct block access index. + + Note: Block access indices differ from transaction indices: + - 0: Pre-execution (system contracts like beacon roots, block hashes) + - 1..n: Transactions (tx at index i gets block_access_index i+1) + - n+1: Post-execution (withdrawals, requests) Parameters ---------- @@ -221,6 +263,10 @@ def track_storage_write( BlockAccessIndex(tracker.current_block_access_index), value_bytes, ) + # Record in current call frame snapshot if exists + if tracker.call_frame_snapshots: + snapshot = tracker.call_frame_snapshots[-1] + snapshot.storage_writes[(address, key)] = new_value else: add_storage_read(tracker.block_access_list_builder, address, key) @@ -249,13 +295,21 @@ def track_balance_change( """ track_address_access(tracker, address) + block_access_index = BlockAccessIndex(tracker.current_block_access_index) add_balance_change( tracker.block_access_list_builder, address, - BlockAccessIndex(tracker.current_block_access_index), + block_access_index, new_balance, ) + # Record in current call frame snapshot if exists + if tracker.call_frame_snapshots: + snapshot = tracker.call_frame_snapshots[-1] + snapshot.balance_changes.add( + (address, block_access_index, new_balance) + ) + def track_nonce_change( tracker: StateChangeTracker, address: Address, new_nonce: Uint @@ -282,13 +336,20 @@ def track_nonce_change( [`CREATE2`]: ref:ethereum.forks.amsterdam.vm.instructions.system.create2 """ track_address_access(tracker, address) + block_access_index = BlockAccessIndex(tracker.current_block_access_index) + nonce_u64 = U64(new_nonce) add_nonce_change( tracker.block_access_list_builder, address, - BlockAccessIndex(tracker.current_block_access_index), - U64(new_nonce), + block_access_index, + nonce_u64, ) + # Record in current call frame snapshot if exists + if tracker.call_frame_snapshots: + snapshot = tracker.call_frame_snapshots[-1] + snapshot.nonce_changes.add((address, block_access_index, nonce_u64)) + def track_code_change( tracker: StateChangeTracker, address: Address, new_code: Bytes @@ -313,13 +374,19 @@ def track_code_change( [`CREATE2`]: ref:ethereum.forks.amsterdam.vm.instructions.system.create2 """ track_address_access(tracker, address) + block_access_index = BlockAccessIndex(tracker.current_block_access_index) add_code_change( tracker.block_access_list_builder, address, - BlockAccessIndex(tracker.current_block_access_index), + block_access_index, new_code, ) + # Record in current call frame snapshot if exists + if tracker.call_frame_snapshots: + snapshot = tracker.call_frame_snapshots[-1] + snapshot.code_changes.add((address, block_access_index, new_code)) + def finalize_transaction_changes( tracker: StateChangeTracker, state: "State" @@ -339,3 +406,120 @@ def finalize_transaction_changes( The current execution state. """ pass + + +def begin_call_frame(tracker: StateChangeTracker) -> None: + """ + Begin a new call frame for tracking reverts. + + Creates a new snapshot to track changes within this call frame. + This allows proper handling of reverts as specified in EIP-7928. + + Parameters + ---------- + tracker : + The state change tracker instance. + """ + tracker.call_frame_snapshots.append(CallFrameSnapshot()) + + +def rollback_call_frame(tracker: StateChangeTracker) -> None: + """ + Rollback changes from the current call frame. + + When a call reverts, this function: + - Converts storage writes to reads + - Removes balance, nonce, and code changes + - Preserves touched addresses + + This implements EIP-7928 revert handling where reverted writes + become reads and addresses remain in the access list. + + Parameters + ---------- + tracker : + The state change tracker instance. + """ + if not tracker.call_frame_snapshots: + return + + snapshot = tracker.call_frame_snapshots.pop() + builder = tracker.block_access_list_builder + + # Convert storage writes to reads + for (address, slot), _ in snapshot.storage_writes.items(): + # Remove the write from storage_changes + if address in builder.accounts: + account_data = builder.accounts[address] + if slot in account_data.storage_changes: + # Filter out changes from this call frame + account_data.storage_changes[slot] = [ + change + for change in account_data.storage_changes[slot] + if change.block_access_index + != tracker.current_block_access_index + ] + if not account_data.storage_changes[slot]: + del account_data.storage_changes[slot] + # Add as a read instead + account_data.storage_reads.add(slot) + + # Remove balance changes from this call frame + for address, block_access_index, new_balance in snapshot.balance_changes: + if address in builder.accounts: + account_data = builder.accounts[address] + # Filter out balance changes from this call frame + account_data.balance_changes = [ + change + for change in account_data.balance_changes + if not ( + change.block_access_index == block_access_index + and change.post_balance == new_balance + ) + ] + + # Remove nonce changes from this call frame + for address, block_access_index, new_nonce in snapshot.nonce_changes: + if address in builder.accounts: + account_data = builder.accounts[address] + # Filter out nonce changes from this call frame + account_data.nonce_changes = [ + change + for change in account_data.nonce_changes + if not ( + change.block_access_index == block_access_index + and change.new_nonce == new_nonce + ) + ] + + # Remove code changes from this call frame + for address, block_access_index, new_code in snapshot.code_changes: + if address in builder.accounts: + account_data = builder.accounts[address] + # Filter out code changes from this call frame + account_data.code_changes = [ + change + for change in account_data.code_changes + if not ( + change.block_access_index == block_access_index + and change.new_code == new_code + ) + ] + + # All touched addresses remain in the access list (already tracked) + + +def commit_call_frame(tracker: StateChangeTracker) -> None: + """ + Commit changes from the current call frame. + + Removes the current call frame snapshot without rolling back changes. + Called when a call completes successfully. + + Parameters + ---------- + tracker : + The state change tracker instance. + """ + if tracker.call_frame_snapshots: + tracker.call_frame_snapshots.pop() diff --git a/src/ethereum/forks/amsterdam/fork.py b/src/ethereum/forks/amsterdam/fork.py index 569d1c81ec..4ad0cb66d2 100644 --- a/src/ethereum/forks/amsterdam/fork.py +++ b/src/ethereum/forks/amsterdam/fork.py @@ -33,7 +33,7 @@ from .block_access_lists.builder import build_block_access_list from .block_access_lists.rlp_utils import compute_block_access_list_hash from .block_access_lists.tracker import ( - set_transaction_index, + set_block_access_index, track_balance_change, ) from .blocks import Block, Header, Log, Receipt, Withdrawal, encode_receipt @@ -764,9 +764,9 @@ def apply_body( """ block_output = vm.BlockOutput() - # Set system transaction index for pre-execution system contracts + # Set block access index for pre-execution system contracts # EIP-7928: System contracts use block_access_index 0 - set_transaction_index(block_env.state.change_tracker, Uint(0)) + set_block_access_index(block_env.state.change_tracker, Uint(0)) process_unchecked_system_transaction( block_env=block_env, @@ -785,7 +785,9 @@ def apply_body( # EIP-7928: Post-execution uses block_access_index len(transactions) + 1 post_execution_index = ulen(transactions) + Uint(1) - set_transaction_index(block_env.state.change_tracker, post_execution_index) + set_block_access_index( + block_env.state.change_tracker, post_execution_index + ) process_withdrawals(block_env, block_output, withdrawals) @@ -874,7 +876,8 @@ def process_transaction( Index of the transaction in the block. """ # EIP-7928: Transactions use block_access_index 1 to len(transactions) - set_transaction_index(block_env.state.change_tracker, index + Uint(1)) + # Transaction at index i gets block_access_index i+1 + set_block_access_index(block_env.state.change_tracker, index + Uint(1)) trie_set( block_output.transactions_trie, diff --git a/src/ethereum/forks/amsterdam/vm/interpreter.py b/src/ethereum/forks/amsterdam/vm/interpreter.py index fb893aaa6b..f217c8dafd 100644 --- a/src/ethereum/forks/amsterdam/vm/interpreter.py +++ b/src/ethereum/forks/amsterdam/vm/interpreter.py @@ -30,6 +30,12 @@ evm_trace, ) +from ..block_access_lists.tracker import ( + begin_call_frame, + commit_call_frame, + rollback_call_frame, + track_address_access, +) from ..blocks import Log from ..fork_types import Address from ..state import ( @@ -239,6 +245,11 @@ def process_message(message: Message) -> Evm: # take snapshot of state before processing the message begin_transaction(state, transient_storage) + if hasattr(state, 'change_tracker') and state.change_tracker: + begin_call_frame(state.change_tracker) + # Track target address access when processing a message + track_address_access(state.change_tracker, message.current_target) + if message.should_transfer_value and message.value != 0: move_ether( state, message.caller, message.current_target, message.value @@ -249,8 +260,12 @@ def process_message(message: Message) -> Evm: # revert state to the last saved checkpoint # since the message call resulted in an error rollback_transaction(state, transient_storage) + if hasattr(state, 'change_tracker') and state.change_tracker: + rollback_call_frame(state.change_tracker) else: commit_transaction(state, transient_storage) + if hasattr(state, 'change_tracker') and state.change_tracker: + commit_call_frame(state.change_tracker) return evm diff --git a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py index b3a60544fe..042f79555f 100644 --- a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py +++ b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py @@ -137,10 +137,10 @@ def compute_block_access_list_hash(self) -> Any: ) @property - def set_transaction_index(self) -> Any: - """set_transaction_index function of the fork""" + def set_block_access_index(self) -> Any: + """set_block_access_index function of the fork""" return ( - self._module("block_access_lists").set_transaction_index + self._module("block_access_lists").set_block_access_index ) @property diff --git a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py index cdb24194cc..ef5bf5abc8 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py @@ -253,7 +253,7 @@ def run_state_test(self) -> Any: def _run_blockchain_test(self, block_env: Any, block_output: Any) -> None: if self.fork.is_after_fork("ethereum.forks.amsterdam"): - self.fork.set_transaction_index( + self.fork.set_block_access_index( block_env.state.change_tracker, Uint(0) ) if self.fork.is_after_fork("ethereum.forks.prague"): @@ -295,7 +295,7 @@ def _run_blockchain_test(self, block_env: Any, block_output: Any) -> None: # post-execution use n + 1 post_execution_index = num_transactions + Uint(1) - self.fork.set_transaction_index( + self.fork.set_block_access_index( block_env.state.change_tracker, post_execution_index ) diff --git a/tests/amsterdam/test_bal_implementation.py b/tests/amsterdam/test_bal_implementation.py index 988c11dcc6..5320d7ec99 100644 --- a/tests/amsterdam/test_bal_implementation.py +++ b/tests/amsterdam/test_bal_implementation.py @@ -8,8 +8,6 @@ - Edge cases and error handling """ -from unittest.mock import MagicMock, patch - import pytest from ethereum_types.bytes import Bytes, Bytes20, Bytes32 from ethereum_types.numeric import U64, U256, Uint @@ -27,7 +25,7 @@ ) from ethereum.forks.amsterdam.block_access_lists.tracker import ( capture_pre_state, - set_transaction_index, + set_block_access_index, track_balance_change, track_code_change, track_nonce_change, @@ -194,14 +192,14 @@ def test_tracker_initialization(self) -> None: assert tracker.pre_storage_cache == {} assert tracker.current_block_access_index == 0 - def test_tracker_set_transaction_index(self) -> None: + def test_tracker_set_block_access_index(self) -> None: """Test setting block access index.""" builder = BlockAccessListBuilder() tracker = StateChangeTracker(builder) - set_transaction_index(tracker, 5) + set_block_access_index(tracker, 5) assert tracker.current_block_access_index == 5 - # Pre-storage cache should persist across transactions + # Pre-storage cache should be cleared for new block access index assert tracker.pre_storage_cache == {} @patch("ethereum.forks.amsterdam.state.get_storage") @@ -612,5 +610,229 @@ def test_address_sorting(self) -> None: assert account.address == sorted_addresses[i] +class TestValueCalls: + """Test value call scenarios including 0 ETH calls.""" + + def test_zero_eth_value_call_tracks_address_without_balance(self) -> None: + """Test that 0 ETH calls track recipient address without balance changes.""" + from ethereum.forks.amsterdam.block_access_lists.tracker import track_address_access + + builder = BlockAccessListBuilder() + tracker = StateChangeTracker(builder) + set_block_access_index(tracker, Uint(1)) + + recipient = Bytes20(b"\x02" * 20) + + # Track only the address access without balance change + track_address_access(tracker, recipient) + + block_access_list = build_block_access_list(builder) + + # Verify recipient is tracked without balance changes + recipient_found = False + for account in block_access_list.account_changes: + if account.address == recipient: + recipient_found = True + assert len(account.balance_changes) == 0 + break + + assert recipient_found + + def test_nonzero_eth_value_call_tracks_with_balance(self) -> None: + """Test that non-zero ETH calls track addresses with balance changes.""" + builder = BlockAccessListBuilder() + tracker = StateChangeTracker(builder) + set_block_access_index(tracker, Uint(1)) + + sender = Bytes20(b"\x01" * 20) + recipient = Bytes20(b"\x02" * 20) + + # Track balance changes for value transfer + track_balance_change(tracker, sender, U256(900)) + track_balance_change(tracker, recipient, U256(100)) + + block_access_list = build_block_access_list(builder) + + # Verify both addresses tracked with balance changes + sender_found = False + recipient_found = False + + for account in block_access_list.account_changes: + if account.address == sender: + sender_found = True + assert len(account.balance_changes) == 1 + assert account.balance_changes[0].post_balance == U256(900) + elif account.address == recipient: + recipient_found = True + assert len(account.balance_changes) == 1 + assert account.balance_changes[0].post_balance == U256(100) + + assert sender_found and recipient_found + + def test_multiple_zero_eth_calls_deduplication(self) -> None: + """Test that multiple 0 ETH calls to same address are deduplicated.""" + from ethereum.forks.amsterdam.block_access_lists.tracker import track_address_access + + builder = BlockAccessListBuilder() + tracker = StateChangeTracker(builder) + set_block_access_index(tracker, Uint(1)) + + recipient = Bytes20(b"\x02" * 20) + + # Multiple calls to same address + track_address_access(tracker, recipient) + track_address_access(tracker, recipient) + track_address_access(tracker, recipient) + + block_access_list = build_block_access_list(builder) + + # Verify address appears exactly once without balance changes + recipient_count = sum(1 for account in block_access_list.account_changes + if account.address == recipient) + assert recipient_count == 1 + + for account in block_access_list.account_changes: + if account.address == recipient: + assert len(account.balance_changes) == 0 + + +class TestRevertScenarios: + """Test block access list behavior during reverts.""" + + def test_storage_write_becomes_read_on_revert(self) -> None: + """Test that storage writes become reads when transaction reverts.""" + from ethereum.forks.amsterdam.block_access_lists.tracker import ( + begin_call_frame, + rollback_call_frame, + track_storage_write, + track_storage_read, + ) + + builder = BlockAccessListBuilder() + tracker = StateChangeTracker(builder) + set_block_access_index(tracker, Uint(1)) + + address = Bytes20(b"\x01" * 20) + slot1 = Bytes32(b"\x01" * 32) + slot2 = Bytes32(b"\x02" * 32) + + # Begin call frame + begin_call_frame(tracker) + + # Mock state for storage operations + class MockState: + pass + state = MockState() + + # Track storage operations that will be reverted + track_storage_read(tracker, address, slot1, state) # Read slot 0x01 + + # Storage write to slot 0x02 (will be reverted) + track_storage_write(tracker, address, slot2, U256(42), state) + + # Rollback the call frame (simulating revert) + rollback_call_frame(tracker) + + # Build and check the access list + block_access_list = build_block_access_list(builder) + + # Find the account in the access list + account_found = False + for account in block_access_list.account_changes: + if account.address == address: + account_found = True + # Both slots should be in storage_reads + assert slot1 in builder.accounts[address].storage_reads + assert slot2 in builder.accounts[address].storage_reads + # No storage changes should exist + assert len(account.storage_changes) == 0 + break + + assert account_found + + def test_balance_changes_removed_on_revert(self) -> None: + """Test that balance changes are removed on revert but address remains.""" + from ethereum.forks.amsterdam.block_access_lists.tracker import ( + begin_call_frame, + rollback_call_frame, + ) + + builder = BlockAccessListBuilder() + tracker = StateChangeTracker(builder) + set_block_access_index(tracker, Uint(1)) + + address = Bytes20(b"\x01" * 20) + + # Begin call frame + begin_call_frame(tracker) + + # Track balance change that will be reverted + track_balance_change(tracker, address, U256(1000)) + + # Rollback the call frame + rollback_call_frame(tracker) + + # Build and check the access list + block_access_list = build_block_access_list(builder) + + # Address should still be in access list but without balance changes + account_found = False + for account in block_access_list.account_changes: + if account.address == address: + account_found = True + assert len(account.balance_changes) == 0 + break + + assert account_found + + def test_nested_call_frames_with_partial_revert(self) -> None: + """Test nested call frames where inner frame reverts but outer succeeds.""" + from ethereum.forks.amsterdam.block_access_lists.tracker import ( + begin_call_frame, + commit_call_frame, + rollback_call_frame, + ) + + builder = BlockAccessListBuilder() + tracker = StateChangeTracker(builder) + set_block_access_index(tracker, Uint(1)) + + address1 = Bytes20(b"\x01" * 20) + address2 = Bytes20(b"\x02" * 20) + + # Outer call frame + begin_call_frame(tracker) + track_balance_change(tracker, address1, U256(900)) + + # Inner call frame (will be reverted) + begin_call_frame(tracker) + track_balance_change(tracker, address2, U256(100)) + + # Rollback inner frame + rollback_call_frame(tracker) + + # Commit outer frame + commit_call_frame(tracker) + + # Build and check the access list + block_access_list = build_block_access_list(builder) + + # address1 should have balance change, address2 should not + address1_found = False + address2_found = False + + for account in block_access_list.account_changes: + if account.address == address1: + address1_found = True + assert len(account.balance_changes) == 1 + assert account.balance_changes[0].post_balance == U256(900) + elif account.address == address2: + address2_found = True + assert len(account.balance_changes) == 0 + + assert address1_found + assert address2_found # Address2 touched but no changes + + if __name__ == "__main__": pytest.main([__file__, "-v"])