Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions hathor/nanocontracts/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ class NCForbiddenAction(NCFail):
pass


class NCForbiddenReentrancy(NCFail):
"""Raised when a reentrancy is forbidden on a method."""
pass


class UnknownFieldType(NCError):
"""Raised when there is no field available for a given type."""
pass
Expand Down
31 changes: 26 additions & 5 deletions hathor/nanocontracts/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
NCAlreadyInitializedContractError,
NCFail,
NCForbiddenAction,
NCForbiddenReentrancy,
NCInvalidContext,
NCInvalidContractId,
NCInvalidInitializeMethodCall,
Expand Down Expand Up @@ -55,6 +56,7 @@
from hathor.nanocontracts.storage import NCBlockStorage, NCChangesTracker, NCContractStorage, NCStorageFactory
from hathor.nanocontracts.storage.contract_storage import Balance
from hathor.nanocontracts.types import (
NC_ALLOW_REENTRANCY,
NC_ALLOWED_ACTIONS_ATTR,
NC_FALLBACK_METHOD,
NC_INITIALIZE_METHOD,
Expand Down Expand Up @@ -371,15 +373,18 @@ def syscall_proxy_call_public_method_nc_args(
method_name=method_name,
actions=actions,
nc_args=nc_args,
skip_reentrancy_validation=True,
)

def _unsafe_call_another_contract_public_method(
self,
*,
contract_id: ContractId,
blueprint_id: BlueprintId,
method_name: str,
actions: Sequence[NCAction],
nc_args: NCArgs,
skip_reentrancy_validation: bool = False,
) -> Any:
"""Invoke another contract's public method without running the usual guard‑safety checks.

Expand Down Expand Up @@ -419,6 +424,7 @@ def _unsafe_call_another_contract_public_method(
method_name=method_name,
ctx=ctx,
nc_args=nc_args,
skip_reentrancy_validation=skip_reentrancy_validation,
)

def _reset_all_change_trackers(self) -> None:
Expand Down Expand Up @@ -527,6 +533,7 @@ def _execute_public_method_call(
method_name: str,
ctx: Context,
nc_args: NCArgs,
skip_reentrancy_validation: bool = False,
) -> Any:
"""An internal method that actually execute the public method call.
It is also used when a contract calls another contract.
Expand Down Expand Up @@ -558,6 +565,9 @@ def _execute_public_method_call(
parser = Method.from_callable(method)
args = self._validate_nc_args_for_method(parser, nc_args)

if not skip_reentrancy_validation:
self._validate_reentrancy(contract_id, called_method_name, method)

call_record = CallRecord(
type=CallType.PUBLIC,
depth=self._call_info.depth,
Expand Down Expand Up @@ -855,11 +865,11 @@ def syscall_create_another_contract(
self._internal_create_contract(child_id, blueprint_id)
nc_args = NCParsedArgs(args, kwargs)
ret = self._unsafe_call_another_contract_public_method(
child_id,
blueprint_id,
NC_INITIALIZE_METHOD,
actions,
nc_args,
contract_id=child_id,
blueprint_id=blueprint_id,
method_name=NC_INITIALIZE_METHOD,
actions=actions,
nc_args=nc_args,
)

assert last_call_record.index_updates is not None
Expand Down Expand Up @@ -973,6 +983,17 @@ def _validate_context(self, ctx: Context) -> None:
if isinstance(action, BaseTokenAction) and action.amount < 0:
raise NCInvalidContext('amount must be positive')

def _validate_reentrancy(self, contract_id: ContractId, method_name: str, method: Any) -> None:
"""Check whether a reentrancy is happening and whether it is allowed."""
assert self._call_info is not None
allow_reentrancy = getattr(method, NC_ALLOW_REENTRANCY, False)
if allow_reentrancy:
return

for call_record in self._call_info.stack:
Copy link
Member Author

@msbrogli msbrogli Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current code runs in O(n) time, where n is the maximum allowed depth (currently set to 100). We could optimize it to run in O(1) time by maintaining counters in a dict[tuple[ContractId, method_name], int]. Is this optimization worth implementing? Should we reduce the maximum allowed depth?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100 equality comparisons is too cheap, I don't think it's a problem.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can get up to 1+2+3+...+100 = 100*101/2 = 5050 comparisons, actually. But I guess it's still not much and shouldn't be a problem. We can always optimize it later. Maybe just leave a note about it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jansegre What do you think? This check runs in O(n) where n is the stack size. So, the overall algorithm runs in O(n^2) where n is the maximum stack size it reaches during execution. is it worth the effort to reduce the time complexity to O(n)? Any DoS attack vectors here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this can be optimized, but I don't think it should be a blocker. Do we have any idea about the order of magnitude of time it takes when n=100? Being proportional to ~10.000 times a single verification should be manageable I think, but certainly not ideal.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear, I don't think this is a blocker to merge this PR.

I think it's OK to add a task to optimize it later. I just think it's possible to avoid looking at the whole stack every time, maybe memoizing the last call on the same stack somehow so each check doesn't need to go up the stack again, only up to the last call.

if call_record.contract_id == contract_id:
raise NCForbiddenReentrancy(f'reentrancy is forbidden on method `{method_name}`')

def _validate_actions(self, method: Any, method_name: str, ctx: Context) -> None:
"""Check whether actions are allowed."""
allowed_actions: set[NCActionType] = getattr(method, NC_ALLOWED_ACTIONS_ATTR, set())
Expand Down
5 changes: 0 additions & 5 deletions hathor/nanocontracts/storage/changes_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,6 @@ def commit(self) -> None:

self.has_been_commited = True

def reset(self) -> None:
"""Discard all local changes without persisting."""
self.data = {}
self._balance_diff = {}

@override
def _get_mutable_balance(self, token_uid: bytes) -> MutableBalance:
internal_key = BalanceKey(self.nc_id, token_uid)
Expand Down
7 changes: 7 additions & 0 deletions hathor/nanocontracts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class ContractId(VertexId):
NC_FALLBACK_METHOD: str = 'fallback'

NC_ALLOWED_ACTIONS_ATTR = '__nc_allowed_actions'
NC_ALLOW_REENTRANCY = '__nc_allow_reentrancy'
NC_METHOD_TYPE_ATTR: str = '__nc_method_type'


Expand Down Expand Up @@ -163,6 +164,7 @@ def _create_decorator_with_allowed_actions(
allow_grant_authority: bool | None,
allow_acquire_authority: bool | None,
allow_actions: list[NCActionType] | None,
allow_reentrancy: bool,
) -> Callable:
"""Internal utility to create a decorator that sets allowed actions."""
flags = {
Expand All @@ -179,6 +181,7 @@ def decorator(fn: Callable) -> Callable:
allowed_actions = set(allow_actions) if allow_actions else set()
allowed_actions.update(action for action, flag in flags.items() if flag)
setattr(fn, NC_ALLOWED_ACTIONS_ATTR, allowed_actions)
setattr(fn, NC_ALLOW_REENTRANCY, allow_reentrancy)

decorator_body(fn)
return fn
Expand All @@ -197,6 +200,7 @@ def public(
allow_grant_authority: bool | None = None,
allow_acquire_authority: bool | None = None,
allow_actions: list[NCActionType] | None = None,
allow_reentrancy: bool = False,
) -> Callable:
"""Decorator to mark a blueprint method as public."""
def decorator(fn: Callable) -> None:
Expand All @@ -219,6 +223,7 @@ def decorator(fn: Callable) -> None:
allow_grant_authority=allow_grant_authority,
allow_acquire_authority=allow_acquire_authority,
allow_actions=allow_actions,
allow_reentrancy=allow_reentrancy,
)


Expand Down Expand Up @@ -246,6 +251,7 @@ def fallback(
allow_grant_authority: bool | None = None,
allow_acquire_authority: bool | None = None,
allow_actions: list[NCActionType] | None = None,
allow_reentrancy: bool = False,
) -> Callable:
"""Decorator to mark a blueprint method as fallback. The method must also be called `fallback`."""
def decorator(fn: Callable) -> None:
Expand Down Expand Up @@ -279,6 +285,7 @@ def decorator(fn: Callable) -> None:
allow_grant_authority=allow_grant_authority,
allow_acquire_authority=allow_acquire_authority,
allow_actions=allow_actions,
allow_reentrancy=allow_reentrancy,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/nanocontracts/test_authorities_call_another.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class CallerBlueprint(Blueprint):
def initialize(self, ctx: Context, other_id: ContractId) -> None:
self.other_id = other_id

@public(allow_grant_authority=True)
@public(allow_grant_authority=True, allow_reentrancy=True)
def nop(self, ctx: Context) -> None:
pass

Expand All @@ -60,7 +60,7 @@ def grant_to_other(self, ctx: Context, token_uid: TokenUid, mint: bool, melt: bo
action = NCGrantAuthorityAction(token_uid=token_uid, mint=mint, melt=melt)
self.syscall.call_public_method(self.other_id, 'nop', [action])

@public(allow_grant_authority=True)
@public(allow_grant_authority=True, allow_reentrancy=True)
def revoke_from_self(self, ctx: Context, token_uid: TokenUid, mint: bool, melt: bool) -> None:
self.syscall.revoke_authorities(token_uid, revoke_mint=mint, revoke_melt=melt)

Expand Down
2 changes: 1 addition & 1 deletion tests/nanocontracts/test_call_other_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_tokens_from_another_contract(self, ctx: Context) -> None:
if actions:
self.syscall.call_public_method(self.contract, 'get_tokens_from_another_contract', actions)

@public
@public(allow_reentrancy=True)
def dec(self, ctx: Context, fail_on_zero: bool) -> None:
if self.counter == 0:
if fail_on_zero:
Expand Down
2 changes: 1 addition & 1 deletion tests/nanocontracts/test_contract_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def initialize(self, ctx: Context) -> None:
def inc(self, ctx: Context) -> None:
self.counter += 3

@public
@public(allow_reentrancy=True)
def on_upgrade_inc(self, ctx: Context) -> None:
self.counter += 100

Expand Down
4 changes: 2 additions & 2 deletions tests/nanocontracts/test_execution_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def accept_deposit_from_another(self, ctx: Context, contract_id: ContractId) ->
self.syscall.call_public_method(contract_id, 'accept_deposit_from_another_callback', [action])
self.assert_token_balance(before=0, current=4)

@public(allow_deposit=True)
@public(allow_deposit=True, allow_reentrancy=True)
def accept_deposit_from_another_callback(self, ctx: Context) -> None:
self.assert_token_balance(before=3, current=6)

Expand All @@ -116,7 +116,7 @@ def accept_withdrawal_from_another(self, ctx: Context, contract_id: ContractId)
self.syscall.call_public_method(contract_id, 'accept_withdrawal_from_another_callback', [action])
self.assert_token_balance(before=4, current=3)

@public(allow_withdrawal=True)
@public(allow_withdrawal=True, allow_reentrancy=True)
def accept_withdrawal_from_another_callback(self, ctx: Context) -> None:
self.assert_token_balance(before=7, current=6)

Expand Down
61 changes: 45 additions & 16 deletions tests/nanocontracts/test_reentrancy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from hathor.nanocontracts import Blueprint, Context, NCFail, public
from hathor.nanocontracts.types import Amount, ContractId, NCAction, NCDepositAction, TokenUid
from hathor.nanocontracts.exception import NCForbiddenReentrancy
from hathor.nanocontracts.types import Amount, CallerId, ContractId, NCAction, NCDepositAction, TokenUid
from tests.nanocontracts.blueprints.unittest import BlueprintTestCase

HTR_TOKEN_UID = TokenUid(b'\0')
Expand All @@ -10,10 +11,8 @@ class InsufficientBalance(NCFail):


class MyBlueprint(Blueprint):
# I used dict[bytes, int] for two reasons:
# 1. `bytes` works for both Address and ContractId
# 2. int allows negative values
balances: dict[bytes, int]
# I used dict[CallerId, int] because int allows negative values.
balances: dict[CallerId, int]

@public
def initialize(self, ctx: Context) -> None:
Expand All @@ -31,7 +30,7 @@ def deposit(self, ctx: Context) -> None:
else:
self.balances[address] += amount

@public
@public(allow_reentrancy=True)
def transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, method: str) -> None:
address = ctx.caller_id
if amount > self.balances.get(address, 0):
Expand All @@ -43,7 +42,7 @@ def transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, method
self.syscall.call_public_method(contract, method, actions=actions)
self.balances[address] -= amount

@public
@public(allow_reentrancy=True)
def fixed_transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, method: str) -> None:
address = ctx.caller_id
if amount > self.balances.get(address, 0):
Expand All @@ -55,6 +54,16 @@ def fixed_transfer_to(self, ctx: Context, amount: Amount, contract: ContractId,
self.balances[address] -= amount
self.syscall.call_public_method(contract, method, actions=actions)

@public
def protected_transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, method: str) -> None:
address = ctx.caller_id
if amount > self.balances.get(address, 0):
raise InsufficientBalance('insufficient balance')

actions: list[NCAction] = [NCDepositAction(token_uid=HTR_TOKEN_UID, amount=amount)]
self.syscall.call_public_method(contract, method, actions=actions)
self.balances[address] -= amount


class AttackerBlueprint(Blueprint):
target: ContractId
Expand All @@ -79,15 +88,19 @@ def initialize(self, ctx: Context, target: ContractId, n_calls: int) -> None:
def nop(self, ctx: Context) -> None:
pass

@public(allow_deposit=True)
@public(allow_deposit=True, allow_reentrancy=True)
def attack(self, ctx: Context) -> None:
self._run_attack('transfer_to')
self._run_attack('transfer_to', 'attack')

@public(allow_deposit=True)
def attack_fail(self, ctx: Context) -> None:
self._run_attack('fixed_transfer_to')
@public(allow_deposit=True, allow_reentrancy=True)
def attack_fixed(self, ctx: Context) -> None:
self._run_attack('fixed_transfer_to', 'attack_fixed')

@public(allow_deposit=True, allow_reentrancy=True)
def attack_protected(self, ctx: Context) -> None:
self._run_attack('protected_transfer_to', 'attack_protected')

def _run_attack(self, method: str) -> None:
def _run_attack(self, method: str, callback: str) -> None:
if self.counter >= self.n_calls:
return

Expand All @@ -98,7 +111,7 @@ def _run_attack(self, method: str) -> None:
actions=[],
amount=self.amount,
contract=self.syscall.get_contract_id(),
method='attack',
method=callback,
)


Expand Down Expand Up @@ -197,7 +210,7 @@ def test_attack_succeed(self) -> None:
assert self.target_storage.get_balance(HTR_TOKEN_UID).value == 10_150 - self.n_calls * 50
assert self.attacker_storage.get_balance(HTR_TOKEN_UID).value == self.n_calls * 50

def test_attack_fail(self) -> None:
def test_attack_fail_fixed(self) -> None:
tx = self.get_genesis_tx()

# Attacker contract has a balance of 0.50 HTR in the target contract.
Expand All @@ -206,7 +219,23 @@ def test_attack_fail(self) -> None:
ctx = Context([], tx, self.address1, timestamp=0)
self.runner.call_public_method(
self.nc_attacker_id,
'attack_fail',
'attack_fixed',
ctx,
)

assert self.target_storage.get_balance(HTR_TOKEN_UID).value == 10_150
assert self.attacker_storage.get_balance(HTR_TOKEN_UID).value == 0

def test_attack_fail_protected(self) -> None:
tx = self.get_genesis_tx()

# Attacker contract has a balance of 0.50 HTR in the target contract.
# It tries to extract more than 0.50 HTR and fails.
with self.assertRaises(NCForbiddenReentrancy):
ctx = Context([], tx, self.address1, timestamp=0)
self.runner.call_public_method(
self.nc_attacker_id,
'attack_protected',
ctx,
)

Expand Down
Loading