diff --git a/hathor/nanocontracts/exception.py b/hathor/nanocontracts/exception.py index ac23763a4..5cb45cb01 100644 --- a/hathor/nanocontracts/exception.py +++ b/hathor/nanocontracts/exception.py @@ -137,6 +137,11 @@ class NCForbiddenAction(NCFail): pass +class NCForbiddenReentrancy(NCFail): + """Raised when a reentrancy is forbidden on a method.""" + pass + + class UnknownFieldType(NCError): """Raised when there is no field available for a given type.""" pass diff --git a/hathor/nanocontracts/runner/runner.py b/hathor/nanocontracts/runner/runner.py index ebccae7a5..fc8699be4 100644 --- a/hathor/nanocontracts/runner/runner.py +++ b/hathor/nanocontracts/runner/runner.py @@ -28,6 +28,7 @@ NCAlreadyInitializedContractError, NCFail, NCForbiddenAction, + NCForbiddenReentrancy, NCInvalidContext, NCInvalidContractId, NCInvalidInitializeMethodCall, @@ -55,6 +56,7 @@ from hathor.nanocontracts.storage import NCBlockStorage, NCChangesTracker, NCContractStorage, NCStorageFactory from hathor.nanocontracts.storage.contract_storage import Balance from hathor.nanocontracts.types import ( + NC_ALLOW_REENTRANCY, NC_ALLOWED_ACTIONS_ATTR, NC_FALLBACK_METHOD, NC_INITIALIZE_METHOD, @@ -371,15 +373,18 @@ def syscall_proxy_call_public_method_nc_args( method_name=method_name, actions=actions, nc_args=nc_args, + skip_reentrancy_validation=True, ) def _unsafe_call_another_contract_public_method( self, + *, contract_id: ContractId, blueprint_id: BlueprintId, method_name: str, actions: Sequence[NCAction], nc_args: NCArgs, + skip_reentrancy_validation: bool = False, ) -> Any: """Invoke another contract's public method without running the usual guard‑safety checks. @@ -419,6 +424,7 @@ def _unsafe_call_another_contract_public_method( method_name=method_name, ctx=ctx, nc_args=nc_args, + skip_reentrancy_validation=skip_reentrancy_validation, ) def _reset_all_change_trackers(self) -> None: @@ -527,6 +533,7 @@ def _execute_public_method_call( method_name: str, ctx: Context, nc_args: NCArgs, + skip_reentrancy_validation: bool = False, ) -> Any: """An internal method that actually execute the public method call. It is also used when a contract calls another contract. @@ -558,6 +565,9 @@ def _execute_public_method_call( parser = Method.from_callable(method) args = self._validate_nc_args_for_method(parser, nc_args) + if not skip_reentrancy_validation: + self._validate_reentrancy(contract_id, called_method_name, method) + call_record = CallRecord( type=CallType.PUBLIC, depth=self._call_info.depth, @@ -855,11 +865,11 @@ def syscall_create_another_contract( self._internal_create_contract(child_id, blueprint_id) nc_args = NCParsedArgs(args, kwargs) ret = self._unsafe_call_another_contract_public_method( - child_id, - blueprint_id, - NC_INITIALIZE_METHOD, - actions, - nc_args, + contract_id=child_id, + blueprint_id=blueprint_id, + method_name=NC_INITIALIZE_METHOD, + actions=actions, + nc_args=nc_args, ) assert last_call_record.index_updates is not None @@ -973,6 +983,17 @@ def _validate_context(self, ctx: Context) -> None: if isinstance(action, BaseTokenAction) and action.amount < 0: raise NCInvalidContext('amount must be positive') + def _validate_reentrancy(self, contract_id: ContractId, method_name: str, method: Any) -> None: + """Check whether a reentrancy is happening and whether it is allowed.""" + assert self._call_info is not None + allow_reentrancy = getattr(method, NC_ALLOW_REENTRANCY, False) + if allow_reentrancy: + return + + for call_record in self._call_info.stack: + if call_record.contract_id == contract_id: + raise NCForbiddenReentrancy(f'reentrancy is forbidden on method `{method_name}`') + def _validate_actions(self, method: Any, method_name: str, ctx: Context) -> None: """Check whether actions are allowed.""" allowed_actions: set[NCActionType] = getattr(method, NC_ALLOWED_ACTIONS_ATTR, set()) diff --git a/hathor/nanocontracts/storage/changes_tracker.py b/hathor/nanocontracts/storage/changes_tracker.py index f4353a35e..f902b6026 100644 --- a/hathor/nanocontracts/storage/changes_tracker.py +++ b/hathor/nanocontracts/storage/changes_tracker.py @@ -190,11 +190,6 @@ def commit(self) -> None: self.has_been_commited = True - def reset(self) -> None: - """Discard all local changes without persisting.""" - self.data = {} - self._balance_diff = {} - @override def _get_mutable_balance(self, token_uid: bytes) -> MutableBalance: internal_key = BalanceKey(self.nc_id, token_uid) diff --git a/hathor/nanocontracts/types.py b/hathor/nanocontracts/types.py index 115c68c5f..16a64b399 100644 --- a/hathor/nanocontracts/types.py +++ b/hathor/nanocontracts/types.py @@ -58,6 +58,7 @@ class ContractId(VertexId): NC_FALLBACK_METHOD: str = 'fallback' NC_ALLOWED_ACTIONS_ATTR = '__nc_allowed_actions' +NC_ALLOW_REENTRANCY = '__nc_allow_reentrancy' NC_METHOD_TYPE_ATTR: str = '__nc_method_type' @@ -163,6 +164,7 @@ def _create_decorator_with_allowed_actions( allow_grant_authority: bool | None, allow_acquire_authority: bool | None, allow_actions: list[NCActionType] | None, + allow_reentrancy: bool, ) -> Callable: """Internal utility to create a decorator that sets allowed actions.""" flags = { @@ -179,6 +181,7 @@ def decorator(fn: Callable) -> Callable: allowed_actions = set(allow_actions) if allow_actions else set() allowed_actions.update(action for action, flag in flags.items() if flag) setattr(fn, NC_ALLOWED_ACTIONS_ATTR, allowed_actions) + setattr(fn, NC_ALLOW_REENTRANCY, allow_reentrancy) decorator_body(fn) return fn @@ -197,6 +200,7 @@ def public( allow_grant_authority: bool | None = None, allow_acquire_authority: bool | None = None, allow_actions: list[NCActionType] | None = None, + allow_reentrancy: bool = False, ) -> Callable: """Decorator to mark a blueprint method as public.""" def decorator(fn: Callable) -> None: @@ -219,6 +223,7 @@ def decorator(fn: Callable) -> None: allow_grant_authority=allow_grant_authority, allow_acquire_authority=allow_acquire_authority, allow_actions=allow_actions, + allow_reentrancy=allow_reentrancy, ) @@ -246,6 +251,7 @@ def fallback( allow_grant_authority: bool | None = None, allow_acquire_authority: bool | None = None, allow_actions: list[NCActionType] | None = None, + allow_reentrancy: bool = False, ) -> Callable: """Decorator to mark a blueprint method as fallback. The method must also be called `fallback`.""" def decorator(fn: Callable) -> None: @@ -279,6 +285,7 @@ def decorator(fn: Callable) -> None: allow_grant_authority=allow_grant_authority, allow_acquire_authority=allow_acquire_authority, allow_actions=allow_actions, + allow_reentrancy=allow_reentrancy, ) diff --git a/tests/nanocontracts/test_authorities_call_another.py b/tests/nanocontracts/test_authorities_call_another.py index 27525993a..7b4053f1a 100644 --- a/tests/nanocontracts/test_authorities_call_another.py +++ b/tests/nanocontracts/test_authorities_call_another.py @@ -51,7 +51,7 @@ class CallerBlueprint(Blueprint): def initialize(self, ctx: Context, other_id: ContractId) -> None: self.other_id = other_id - @public(allow_grant_authority=True) + @public(allow_grant_authority=True, allow_reentrancy=True) def nop(self, ctx: Context) -> None: pass @@ -60,7 +60,7 @@ def grant_to_other(self, ctx: Context, token_uid: TokenUid, mint: bool, melt: bo action = NCGrantAuthorityAction(token_uid=token_uid, mint=mint, melt=melt) self.syscall.call_public_method(self.other_id, 'nop', [action]) - @public(allow_grant_authority=True) + @public(allow_grant_authority=True, allow_reentrancy=True) def revoke_from_self(self, ctx: Context, token_uid: TokenUid, mint: bool, melt: bool) -> None: self.syscall.revoke_authorities(token_uid, revoke_mint=mint, revoke_melt=melt) diff --git a/tests/nanocontracts/test_call_other_contract.py b/tests/nanocontracts/test_call_other_contract.py index c01cb1b15..7aa3569f6 100644 --- a/tests/nanocontracts/test_call_other_contract.py +++ b/tests/nanocontracts/test_call_other_contract.py @@ -80,7 +80,7 @@ def get_tokens_from_another_contract(self, ctx: Context) -> None: if actions: self.syscall.call_public_method(self.contract, 'get_tokens_from_another_contract', actions) - @public + @public(allow_reentrancy=True) def dec(self, ctx: Context, fail_on_zero: bool) -> None: if self.counter == 0: if fail_on_zero: diff --git a/tests/nanocontracts/test_contract_upgrade.py b/tests/nanocontracts/test_contract_upgrade.py index c06b43f88..7747a5dfd 100644 --- a/tests/nanocontracts/test_contract_upgrade.py +++ b/tests/nanocontracts/test_contract_upgrade.py @@ -88,7 +88,7 @@ def initialize(self, ctx: Context) -> None: def inc(self, ctx: Context) -> None: self.counter += 3 - @public + @public(allow_reentrancy=True) def on_upgrade_inc(self, ctx: Context) -> None: self.counter += 100 diff --git a/tests/nanocontracts/test_execution_order.py b/tests/nanocontracts/test_execution_order.py index 0252d0429..c56a9c9fc 100644 --- a/tests/nanocontracts/test_execution_order.py +++ b/tests/nanocontracts/test_execution_order.py @@ -96,7 +96,7 @@ def accept_deposit_from_another(self, ctx: Context, contract_id: ContractId) -> self.syscall.call_public_method(contract_id, 'accept_deposit_from_another_callback', [action]) self.assert_token_balance(before=0, current=4) - @public(allow_deposit=True) + @public(allow_deposit=True, allow_reentrancy=True) def accept_deposit_from_another_callback(self, ctx: Context) -> None: self.assert_token_balance(before=3, current=6) @@ -116,7 +116,7 @@ def accept_withdrawal_from_another(self, ctx: Context, contract_id: ContractId) self.syscall.call_public_method(contract_id, 'accept_withdrawal_from_another_callback', [action]) self.assert_token_balance(before=4, current=3) - @public(allow_withdrawal=True) + @public(allow_withdrawal=True, allow_reentrancy=True) def accept_withdrawal_from_another_callback(self, ctx: Context) -> None: self.assert_token_balance(before=7, current=6) diff --git a/tests/nanocontracts/test_reentrancy.py b/tests/nanocontracts/test_reentrancy.py index 434f23b59..90194813c 100644 --- a/tests/nanocontracts/test_reentrancy.py +++ b/tests/nanocontracts/test_reentrancy.py @@ -1,5 +1,6 @@ from hathor.nanocontracts import Blueprint, Context, NCFail, public -from hathor.nanocontracts.types import Amount, ContractId, NCAction, NCDepositAction, TokenUid +from hathor.nanocontracts.exception import NCForbiddenReentrancy +from hathor.nanocontracts.types import Amount, CallerId, ContractId, NCAction, NCDepositAction, TokenUid from tests.nanocontracts.blueprints.unittest import BlueprintTestCase HTR_TOKEN_UID = TokenUid(b'\0') @@ -10,10 +11,8 @@ class InsufficientBalance(NCFail): class MyBlueprint(Blueprint): - # I used dict[bytes, int] for two reasons: - # 1. `bytes` works for both Address and ContractId - # 2. int allows negative values - balances: dict[bytes, int] + # I used dict[CallerId, int] because int allows negative values. + balances: dict[CallerId, int] @public def initialize(self, ctx: Context) -> None: @@ -31,7 +30,7 @@ def deposit(self, ctx: Context) -> None: else: self.balances[address] += amount - @public + @public(allow_reentrancy=True) def transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, method: str) -> None: address = ctx.caller_id if amount > self.balances.get(address, 0): @@ -43,7 +42,7 @@ def transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, method self.syscall.call_public_method(contract, method, actions=actions) self.balances[address] -= amount - @public + @public(allow_reentrancy=True) def fixed_transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, method: str) -> None: address = ctx.caller_id if amount > self.balances.get(address, 0): @@ -55,6 +54,16 @@ def fixed_transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, self.balances[address] -= amount self.syscall.call_public_method(contract, method, actions=actions) + @public + def protected_transfer_to(self, ctx: Context, amount: Amount, contract: ContractId, method: str) -> None: + address = ctx.caller_id + if amount > self.balances.get(address, 0): + raise InsufficientBalance('insufficient balance') + + actions: list[NCAction] = [NCDepositAction(token_uid=HTR_TOKEN_UID, amount=amount)] + self.syscall.call_public_method(contract, method, actions=actions) + self.balances[address] -= amount + class AttackerBlueprint(Blueprint): target: ContractId @@ -79,15 +88,19 @@ def initialize(self, ctx: Context, target: ContractId, n_calls: int) -> None: def nop(self, ctx: Context) -> None: pass - @public(allow_deposit=True) + @public(allow_deposit=True, allow_reentrancy=True) def attack(self, ctx: Context) -> None: - self._run_attack('transfer_to') + self._run_attack('transfer_to', 'attack') - @public(allow_deposit=True) - def attack_fail(self, ctx: Context) -> None: - self._run_attack('fixed_transfer_to') + @public(allow_deposit=True, allow_reentrancy=True) + def attack_fixed(self, ctx: Context) -> None: + self._run_attack('fixed_transfer_to', 'attack_fixed') + + @public(allow_deposit=True, allow_reentrancy=True) + def attack_protected(self, ctx: Context) -> None: + self._run_attack('protected_transfer_to', 'attack_protected') - def _run_attack(self, method: str) -> None: + def _run_attack(self, method: str, callback: str) -> None: if self.counter >= self.n_calls: return @@ -98,7 +111,7 @@ def _run_attack(self, method: str) -> None: actions=[], amount=self.amount, contract=self.syscall.get_contract_id(), - method='attack', + method=callback, ) @@ -197,7 +210,7 @@ def test_attack_succeed(self) -> None: assert self.target_storage.get_balance(HTR_TOKEN_UID).value == 10_150 - self.n_calls * 50 assert self.attacker_storage.get_balance(HTR_TOKEN_UID).value == self.n_calls * 50 - def test_attack_fail(self) -> None: + def test_attack_fail_fixed(self) -> None: tx = self.get_genesis_tx() # Attacker contract has a balance of 0.50 HTR in the target contract. @@ -206,7 +219,23 @@ def test_attack_fail(self) -> None: ctx = Context([], tx, self.address1, timestamp=0) self.runner.call_public_method( self.nc_attacker_id, - 'attack_fail', + 'attack_fixed', + ctx, + ) + + assert self.target_storage.get_balance(HTR_TOKEN_UID).value == 10_150 + assert self.attacker_storage.get_balance(HTR_TOKEN_UID).value == 0 + + def test_attack_fail_protected(self) -> None: + tx = self.get_genesis_tx() + + # Attacker contract has a balance of 0.50 HTR in the target contract. + # It tries to extract more than 0.50 HTR and fails. + with self.assertRaises(NCForbiddenReentrancy): + ctx = Context([], tx, self.address1, timestamp=0) + self.runner.call_public_method( + self.nc_attacker_id, + 'attack_protected', ctx, )