diff --git a/hathor/nanocontracts/blueprint_env.py b/hathor/nanocontracts/blueprint_env.py index 3c41b0c41..5fb4bde5d 100644 --- a/hathor/nanocontracts/blueprint_env.py +++ b/hathor/nanocontracts/blueprint_env.py @@ -58,14 +58,28 @@ def rng(self) -> NanoRNG: return self.__runner.syscall_get_rng() def get_contract_id(self) -> ContractId: - """Return the contract id of this nano contract.""" + """Return the ContractId of the current nano contract.""" return self.__runner.get_current_contract_id() def get_blueprint_id(self) -> BlueprintId: - """Return the blueprint id of this nano contract.""" + """ + Return the BlueprintId of the current nano contract. + + This means that during a proxy call, this method will return the BlueprintId of the caller's blueprint, + NOT the BlueprintId of the Blueprint that owns the running code. + """ contract_id = self.get_contract_id() return self.__runner.get_blueprint_id(contract_id) + def get_current_code_blueprint_id(self) -> BlueprintId: + """ + Return the BlueprintId of the Blueprint that owns the currently running code. + + This means that during a proxy call, this method will return the BlueprintId of the Blueprint that owns the + running code, NOT the BlueprintId of the current nano contract. + """ + return self.__runner.get_current_code_blueprint_id() + def get_balance_before_current_call(self, token_uid: TokenUid | None = None) -> Amount: """ Return the balance for a given token before the current call, that is, diff --git a/hathor/nanocontracts/runner/runner.py b/hathor/nanocontracts/runner/runner.py index dcdcea100..7106a1d19 100644 --- a/hathor/nanocontracts/runner/runner.py +++ b/hathor/nanocontracts/runner/runner.py @@ -229,6 +229,11 @@ def get_blueprint_id(self, contract_id: ContractId) -> BlueprintId: nc_storage = self.get_current_changes_tracker_or_storage(contract_id) return nc_storage.get_blueprint_id() + def get_current_code_blueprint_id(self) -> BlueprintId: + """Return the blueprint id of the blueprint that owns the executing code.""" + current_call_record = self.get_current_call_record() + return current_call_record.blueprint_id + def _build_call_info(self, contract_id: ContractId) -> CallInfo: from hathor.nanocontracts.nc_exec_logs import NCLogger return CallInfo( @@ -374,9 +379,11 @@ def syscall_proxy_call_public_method( raise NCInvalidInitializeMethodCall('cannot call initialize from another contract') contract_id = self.get_current_contract_id() - if blueprint_id == self.get_blueprint_id(contract_id): - raise NCInvalidSyscall('cannot call the same blueprint') + raise NCInvalidSyscall('cannot call the same blueprint of the running contract') + + if blueprint_id == self.get_current_code_blueprint_id(): + raise NCInvalidSyscall('cannot call the same blueprint of the running blueprint') return self._unsafe_call_another_contract_public_method( contract_id=contract_id, diff --git a/tests/nanocontracts/test_proxy_accessor.py b/tests/nanocontracts/test_proxy_accessor.py index 3fcffb5dd..1f1036b37 100644 --- a/tests/nanocontracts/test_proxy_accessor.py +++ b/tests/nanocontracts/test_proxy_accessor.py @@ -28,6 +28,7 @@ fallback, public, ) +from hathor.nanocontracts.exception import NCInvalidSyscall from tests.nanocontracts.blueprints.unittest import BlueprintTestCase @@ -91,6 +92,34 @@ def test_fallback_forbidden(self, ctx: Context) -> str: proxy = self.syscall.get_proxy(self.other_blueprint_id) return proxy.public(forbid_fallback=True).unknown() + @public + def test_get_blueprint_id_through_proxy(self, ctx: Context) -> BlueprintId: + proxy = self.syscall.get_proxy(self.other_blueprint_id) + return proxy.public().get_blueprint_id() + + @public + def test_get_current_code_blueprint_id(self, ctx: Context) -> BlueprintId: + current_code_blueprint_id = self.syscall.get_current_code_blueprint_id() + assert self.syscall.get_blueprint_id() == current_code_blueprint_id, ( + "should be the same BlueprintId when we're not in a proxy call" + ) + proxy = self.syscall.get_proxy(self.other_blueprint_id) + return proxy.public().get_current_code_blueprint_id() + + @public + def nop(self, ctx: Context) -> None: + pass + + @public + def call_itself_through_double_proxy_other(self, ctx: Context) -> None: + proxy = self.syscall.get_proxy(self.other_blueprint_id) + proxy.public().call_itself_through_proxy(self.other_blueprint_id) + + @public + def call_itself_through_double_proxy_same(self, ctx: Context) -> None: + proxy = self.syscall.get_proxy(self.other_blueprint_id) + proxy.public().call_itself_through_proxy(self.syscall.get_blueprint_id()) + class MyBlueprint2(Blueprint): @public @@ -105,6 +134,23 @@ def hello(self, ctx: Context, name: str) -> str: def fallback(self, ctx: Context, method_name: str, nc_args: NCArgs) -> str: return f'fallback called for method `{method_name}`' + @public + def get_blueprint_id(self, ctx: Context) -> BlueprintId: + return self.syscall.get_blueprint_id() + + @public + def get_current_code_blueprint_id(self, ctx: Context) -> BlueprintId: + return self.syscall.get_current_code_blueprint_id() + + @public + def nop(self, ctx: Context) -> None: + pass + + @public + def call_itself_through_proxy(self, ctx: Context, blueprint_id: BlueprintId) -> None: + proxy = self.syscall.get_proxy(blueprint_id) + proxy.public().nop() + class TestProxyAccessor(BlueprintTestCase): def setUp(self) -> None: @@ -112,14 +158,16 @@ def setUp(self) -> None: self.blueprint_id1 = self._register_blueprint_class(MyBlueprint1) self.blueprint_id2 = self._register_blueprint_class(MyBlueprint2) - self.contract_id = self.gen_random_contract_id() + self.contract_id1 = self.gen_random_contract_id() + self.contract_id2 = self.gen_random_contract_id() ctx = self.create_context([NCDepositAction(amount=123, token_uid=HATHOR_TOKEN_UID)]) - self.runner.create_contract(self.contract_id, self.blueprint_id1, ctx, self.blueprint_id2) + self.runner.create_contract(self.contract_id1, self.blueprint_id1, ctx, self.blueprint_id2) + self.runner.create_contract(self.contract_id2, self.blueprint_id2, self.create_context()) def test_get_blueprint_id(self) -> None: ret = self.runner.call_public_method( - self.contract_id, + self.contract_id1, 'test_get_blueprint_id', self.create_context(), ) @@ -127,7 +175,7 @@ def test_get_blueprint_id(self) -> None: def test_public_method(self) -> None: ret = self.runner.call_public_method( - self.contract_id, + self.contract_id1, 'test_public_method', self.create_context(), 'alice', @@ -141,7 +189,7 @@ def test_multiple_public_calls_on_prepared_call(self) -> None: ) with pytest.raises(NCFail, match=re.escape(msg)): self.runner.call_public_method( - self.contract_id, + self.contract_id1, 'test_multiple_public_calls_on_prepared_call', self.create_context(), ) @@ -153,14 +201,14 @@ def test_multiple_public_calls_on_method(self) -> None: ) with pytest.raises(NCFail, match=re.escape(msg)): self.runner.call_public_method( - self.contract_id, + self.contract_id1, 'test_multiple_public_calls_on_method', self.create_context(), ) def test_fallback_allowed(self) -> None: ret = self.runner.call_public_method( - self.contract_id, + self.contract_id1, 'test_fallback_allowed', self.create_context(), ) @@ -170,7 +218,48 @@ def test_fallback_forbidden(self) -> None: msg = 'method `unknown` not found and fallback is forbidden' with pytest.raises(NCFail, match=re.escape(msg)): self.runner.call_public_method( - self.contract_id, + self.contract_id1, 'test_fallback_forbidden', self.create_context(), ) + + def test_get_blueprint_id_through_proxy(self) -> None: + ret = self.runner.call_public_method( + self.contract_id1, + 'test_get_blueprint_id_through_proxy', + self.create_context(), + ) + assert ret == self.blueprint_id1 + + def test_get_current_code_blueprint_id(self) -> None: + ret = self.runner.call_public_method( + self.contract_id1, + 'test_get_current_code_blueprint_id', + self.create_context(), + ) + assert ret == self.blueprint_id2 + + def test_call_itself_through_proxy(self) -> None: + with pytest.raises(NCInvalidSyscall, match='cannot call the same blueprint of the running contract'): + self.runner.call_public_method( + self.contract_id2, + 'call_itself_through_proxy', + self.create_context(), + self.blueprint_id2, + ) + + def test_call_itself_through_double_proxy_other(self) -> None: + with pytest.raises(NCInvalidSyscall, match='cannot call the same blueprint of the running blueprint'): + self.runner.call_public_method( + self.contract_id1, + 'call_itself_through_double_proxy_other', + self.create_context(), + ) + + def test_call_itself_through_double_proxy_same(self) -> None: + with pytest.raises(NCInvalidSyscall, match='cannot call the same blueprint of the running contract'): + self.runner.call_public_method( + self.contract_id1, + 'call_itself_through_double_proxy_same', + self.create_context(), + ) diff --git a/tests/nanocontracts/test_syscalls_in_view.py b/tests/nanocontracts/test_syscalls_in_view.py index e5d2ec180..62b158c42 100644 --- a/tests/nanocontracts/test_syscalls_in_view.py +++ b/tests/nanocontracts/test_syscalls_in_view.py @@ -44,6 +44,10 @@ def get_contract_id(self) -> None: def get_blueprint_id(self) -> None: self.syscall.get_blueprint_id() + @view + def get_current_code_blueprint_id(self) -> None: + self.syscall.get_current_code_blueprint_id() + @view def get_balance_before_current_call(self) -> None: self.syscall.get_balance_before_current_call() @@ -168,6 +172,7 @@ def test_syscalls(self) -> None: 'can_melt_before_current_call', 'call_view_method', 'get_contract', + 'get_current_code_blueprint_id', } for method_name, method in BlueprintEnvironment.__dict__.items():