Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/ethereum/forks/osaka/fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,12 @@ def check_transaction(

if Uint(sender_account.balance) < max_gas_fee + Uint(tx.value):
raise InsufficientBalanceError("insufficient sender balance")
if sender_account.code and not is_valid_delegation(sender_account.code):
raise InvalidSenderError("not EOA")
if sender_account.code:
track_bytecode_access(

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

block_env.state, sender_account.code, sender_address
)
if not is_valid_delegation(sender_account.code):
raise InvalidSenderError("not EOA")

return (
sender_address,
Expand Down Expand Up @@ -1063,7 +1067,8 @@ def increase_recipient_balance(recipient: Account) -> None:
rlp.encode(wd),
)

modify_state(block_env.state, wd.address, increase_recipient_balance)
if wd.amount != 0:

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the withdrawal is for 0 ETH, then we avoid this call to avoid bloating the witness. The stateles validator can skip this withdrawal if won't modify the state, thus the witness data isn't required.

modify_state(block_env.state, wd.address, increase_recipient_balance)


def check_gas_limit(gas_limit: Uint, parent_gas_limit: Uint) -> bool:
Expand Down
64 changes: 55 additions & 9 deletions src/ethereum/forks/osaka/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ def destroy_account(state: State, address: Address) -> None:
set_account(state, address, None)


def track_bytecode_access(state: State, code: Bytes) -> None:
def track_bytecode_access(
state: State, code: Bytes, address: Optional[Address] = None
) -> None:
"""
Track bytecode access for execution witness generation.

Expand All @@ -310,16 +312,28 @@ def track_bytecode_access(state: State, code: Bytes) -> None:
The state with optional witness tracking.
code : Bytes
The bytecode being accessed.
address : Optional[Address]
Account address whose code is being accessed. When provided,
only pre-block code for this address is included in witness
bytecodes.

"""
# Skip if witness mode disabled or empty bytecode (EOAs)
if state._witness_state is None or len(code) == 0:
return

ws = state._witness_state

# Exclude bytecode that was created or changed during the current block.
if address is not None:
pre_account = ws.pre_state_accounts.get(address)
if pre_account is None or pre_account.code != code:
return

# Compute hash and store for deduplication
code_hash = Bytes32(keccak256(code))
if code_hash not in state._witness_state.accessed_bytecodes:
state._witness_state.accessed_bytecodes[code_hash] = code
if code_hash not in ws.accessed_bytecodes:
ws.accessed_bytecodes[code_hash] = code


def track_block_hash_access(state: State, block_number: Uint) -> None:
Expand Down Expand Up @@ -878,6 +892,28 @@ def is_witness_mode_enabled(state: State) -> bool:
return state._witness_state is not None


def _debug_print_witness_state(ws: WitnessState) -> None:

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a helper that I needed to debug all things -- we can remove it eventually.

"""Print dirty/accessed accounts and storage in hex format."""
_hex = lambda b: "0x" + b.hex()
_hex_set = lambda s: (
"{\n"
+ "".join(f" {_hex(x)}\n" for x in sorted(s, key=lambda b: b.hex()))
+ " }"
)
_hex_dict_set = lambda d: (
"{\n"
+ "".join(
f" {_hex(k)}: {_hex_set(v)}\n"
for k, v in sorted(d.items(), key=lambda kv: kv[0].hex())
)
+ " }"
)
print("dirty_accounts:", _hex_set(ws.dirty_accounts))
print("dirty_storage:", _hex_dict_set(ws.dirty_storage))
print("accessed_accounts:", _hex_set(ws.accessed_accounts))
print("accessed_storage:", _hex_dict_set(ws.accessed_storage))


def _build_witness_mpts(state: State) -> None:
"""
Build and cache the IncrementalMPTs for witness generation.
Expand All @@ -899,6 +935,8 @@ def _build_witness_mpts(state: State) -> None:
if ws._main_mpt is not None:
return

_debug_print_witness_state(ws)

# Build pre-block storage MPTs
storage_mpts: Dict[Address, IncrementalMPT[Bytes32, U256]] = {}
for address, data in ws.pre_state_storages.items():
Expand All @@ -918,12 +956,19 @@ def get_pre_storage_root(address: Address) -> Root:
get_storage_root=get_pre_storage_root,
)

# 1. Do read-only storages accesses
for address, accessed_keys in ws.accessed_storage.items():
# 1. Traverse all accessed and dirty storage keys on the pre-state
# MPTs to capture pre-state trie nodes in the witness. This must
# happen before any writes since writes mutate the tree in-place.
all_storage_reads: Dict[Address, Set[Bytes32]] = {}
for address, keys in ws.accessed_storage.items():
all_storage_reads.setdefault(address, set()).update(keys)
for address, keys in ws.dirty_storage.items():
all_storage_reads.setdefault(address, set()).update(keys)
Comment on lines +959 to +966

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related to the first bullet I mentioned in the PR description.

Any pre-state MPT proof should be captured before any state tree change. If we don't do this, we might capture "transient MPT nodes" that we edited while post-state root calculation.


for address, keys in all_storage_reads.items():
if address not in storage_mpts:
continue

for key in accessed_keys:
for key in keys:
mpt_get(storage_mpts[address], key)

# 2. Apply dirty storage to storages (writes)
Expand Down Expand Up @@ -954,8 +999,9 @@ def get_pre_storage_root(address: Address) -> Root:
# - Storage changed (storage root changed) - tracked in dirty_storage
all_dirty_accounts = ws.dirty_accounts | set(ws.dirty_storage.keys())

# 3. Traverse accounts that were read
for address in ws.accessed_accounts:
# 3. Traverse all accessed and dirty accounts on the pre-state MPT
# to capture pre-state trie nodes before writes mutate the tree.
for address in ws.accessed_accounts | all_dirty_accounts:
mpt_get(main_mpt, address)

# 4. Apply dirty accounts
Expand Down
65 changes: 50 additions & 15 deletions src/ethereum/forks/osaka/trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class MutableLeafNode:
value: Bytes
_hash: Optional[Bytes] = None # Cached hash, invalidated on change
_rlp: Optional[Bytes] = None # Cached RLP encoding
_dirty: bool = False # True if created during execution (not pre-state)

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the PR description, this is a new boolena marker for all types of nodes in the MPT tree.

This is required, since whenever we capture potential sibilings in branch compressions we must be sure they existed in the pre-state tree. If they were edited or created during post-state root calculation, they must not be in the witness.



@dataclass
Expand All @@ -156,6 +157,7 @@ class MutableExtensionNode:
child: "MutableNode"
_hash: Optional[Bytes] = None
_rlp: Optional[Bytes] = None
_dirty: bool = False # True if created during execution (not pre-state)


@dataclass
Expand All @@ -166,6 +168,7 @@ class MutableBranchNode:
value: Bytes # Value if key terminates at this branch
_hash: Optional[Bytes] = None
_rlp: Optional[Bytes] = None
_dirty: bool = False # True if created during execution (not pre-state)


MutableNode = Union[
Expand Down Expand Up @@ -722,6 +725,10 @@ def _record_witness(
if node is None:
return

# Skip nodes created during execution (not in pre-state)
if node._dirty:
return

# Record the key if provided
if key is not None:
witness.accessed_keys.add(key)
Expand Down Expand Up @@ -950,11 +957,11 @@ def _mpt_insert_node(

Returns the new/updated node for this position.
"""
_record_witness(mpt.witness, node)

if node is None:
# Empty slot - create new leaf
return MutableLeafNode(rest_of_key=key[level:], value=value)
return MutableLeafNode(
rest_of_key=key[level:], value=value, _dirty=True
)

_invalidate_hash(node)

Expand Down Expand Up @@ -982,6 +989,7 @@ def _insert_into_leaf(
if existing_key == remaining_key:
# Same key - update value
node.value = value
node._dirty = True
return node

# Keys differ - need to create branch
Expand All @@ -997,7 +1005,9 @@ def _insert_into_leaf(
value,
)
return MutableExtensionNode(
key_segment=existing_key[:prefix_len], child=branch
key_segment=existing_key[:prefix_len],
child=branch,
_dirty=True,
)
else:
# No common prefix - create branch directly
Expand All @@ -1017,15 +1027,21 @@ def _create_branch_from_two_leaves(
branch_value = value1
else:
idx1 = key1[0]
children[idx1] = MutableLeafNode(rest_of_key=key1[1:], value=value1)
children[idx1] = MutableLeafNode(
rest_of_key=key1[1:], value=value1, _dirty=True
)

if len(key2) == 0:
branch_value = value2
else:
idx2 = key2[0]
children[idx2] = MutableLeafNode(rest_of_key=key2[1:], value=value2)
children[idx2] = MutableLeafNode(
rest_of_key=key2[1:], value=value2, _dirty=True
)

return MutableBranchNode(children=children, value=branch_value)
return MutableBranchNode(
children=children, value=branch_value, _dirty=True
)


def _insert_into_extension(
Expand All @@ -1045,14 +1061,17 @@ def _insert_into_extension(
node.child = _mpt_insert_node(
mpt, node.child, key, value, level + Uint(prefix_len)
)
node._dirty = True
return node

# Extension needs to be split
if prefix_len > 0:
# Partial match - create new extension for common prefix
new_child = _split_extension(node, remaining_key, value, prefix_len)
return MutableExtensionNode(
key_segment=segment[:prefix_len], child=new_child
key_segment=segment[:prefix_len],
child=new_child,
_dirty=True,
)
else:
# No common prefix - create branch at this level
Expand Down Expand Up @@ -1080,7 +1099,9 @@ def _split_extension(
# Multiple nibbles - create new extension
idx = segment_after_prefix[0]
children[idx] = MutableExtensionNode(
key_segment=segment_after_prefix[1:], child=node.child
key_segment=segment_after_prefix[1:],
child=node.child,
_dirty=True,
)

# Place new value
Expand All @@ -1091,13 +1112,17 @@ def _split_extension(
idx = key_after_prefix[0]
if children[idx] is None:
children[idx] = MutableLeafNode(
rest_of_key=key_after_prefix[1:], value=value
rest_of_key=key_after_prefix[1:],
value=value,
_dirty=True,
)
else:
# Need to merge with existing child (shouldn't happen normally)
raise AssertionError("Unexpected collision during split")

return MutableBranchNode(children=children, value=branch_value)
return MutableBranchNode(
children=children, value=branch_value, _dirty=True
)


def _insert_into_branch(
Expand All @@ -1113,13 +1138,15 @@ def _insert_into_branch(
if len(remaining_key) == 0:
# Value terminates at this branch
node.value = value
node._dirty = True
return node

# Recurse into appropriate child
child_idx = remaining_key[0]
node.children[child_idx] = _mpt_insert_node(
mpt, node.children[child_idx], key, value, level + Uint(1)
)
node._dirty = True
return node


Expand All @@ -1134,8 +1161,6 @@ def _mpt_delete_node(

Returns the updated node (may be different type or None).
"""
_record_witness(mpt.witness, node)

if node is None:
return None

Expand Down Expand Up @@ -1180,15 +1205,18 @@ def _delete_from_extension(
return MutableExtensionNode(
key_segment=segment + new_child.key_segment,
child=new_child.child,
_dirty=True,
)
elif isinstance(new_child, MutableLeafNode):
# Merge extension into leaf
return MutableLeafNode(
rest_of_key=segment + new_child.rest_of_key,
value=new_child.value,
_dirty=True,
)

node.child = new_child
node._dirty = True
return node


Expand All @@ -1211,6 +1239,7 @@ def _delete_from_branch(
mpt, node.children[child_idx], key, level + Uint(1)
)

node._dirty = True
# Check if branch can be collapsed
return _collapse_branch(mpt, node)

Expand All @@ -1233,19 +1262,25 @@ def _collapse_branch(
return MutableLeafNode(
rest_of_key=nibble + child.rest_of_key,
value=child.value,
_dirty=True,
)
elif isinstance(child, MutableExtensionNode):
return MutableExtensionNode(
key_segment=nibble + child.key_segment,
child=child.child,
_dirty=True,
)
else:
# Child is a branch - create extension
return MutableExtensionNode(key_segment=nibble, child=child)
return MutableExtensionNode(
key_segment=nibble, child=child, _dirty=True
)

if len(non_empty) == 0 and node.value != b"":
# Only value at this branch - convert to leaf
return MutableLeafNode(rest_of_key=b"", value=node.value)
return MutableLeafNode(
rest_of_key=b"", value=node.value, _dirty=True
)

return node

Expand Down
2 changes: 1 addition & 1 deletion src/ethereum/forks/osaka/utils/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def prepare_message(
current_target = tx.to
msg_data = tx.data
code = get_account(block_env.state, tx.to).code
track_bytecode_access(block_env.state, code)
track_bytecode_access(block_env.state, code, tx.to)

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

track_bytecode_access now checks if the provided address in the new parameter existed in the pre-state of the block. If it doesn't exist, this bytecode won't be in the witness since the bytecode was generated during block execution.

Probably with the official state tracker we might approach it differently -- just explaining the code change here.

code_address = tx.to
else:
raise AssertionError("Target must be address or empty bytes")
Expand Down
Loading
Loading