Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions hathor/nanocontracts/custom_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __call__(
...


def _generate_restriced_import_function(allowed_imports: dict[str, set[str]]) -> ImportFunction:
def _generate_restricted_import_function(allowed_imports: dict[str, set[str]]) -> ImportFunction:
"""Returns a function equivalent to builtins.__import__ but that will only import `allowed_imports`"""
@_wraps(builtins.__import__)
def __import__(
Expand All @@ -228,7 +228,7 @@ def __import__(
fromlist: Sequence[str] = (),
level: int = 0,
) -> types.ModuleType:
if level > 0:
if level != 0:
raise ImportError('Relative imports are not allowed')
if not fromlist and name != 'typing':
# XXX: typing is allowed here because Foo[T] triggers a __import__('typing', fromlist=None) for some reason
Expand Down Expand Up @@ -329,7 +329,7 @@ def filter(function: None | Callable[[T], object], iterable: Iterable[T]) -> Ite
# XXX: will trigger the execution of the imported module
# (name: str, globals: Mapping[str, object] | None = None, locals: Mapping[str, object] | None = None,
# fromlist: Sequence[str] = (), level: int = 0) -> types.ModuleType
'__import__': _generate_restriced_import_function(ALLOWED_IMPORTS),
'__import__': _generate_restricted_import_function(ALLOWED_IMPORTS),

# XXX: also required to declare classes
# XXX: this would be '__main__' for a module that is loaded as the main entrypoint, and the module name otherwise,
Expand Down
10 changes: 9 additions & 1 deletion tests/nanocontracts/blueprints/unittest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from io import TextIOWrapper
from os import PathLike

from hathor.conf.settings import HATHOR_TOKEN_UID
Expand Down Expand Up @@ -88,8 +89,15 @@ def _register_blueprint_class(
def register_blueprint_file(self, path: PathLike[str], blueprint_id: BlueprintId | None = None) -> BlueprintId:
"""Register a blueprint file with an optional id, allowing contracts to be created from it."""
with open(path, 'r') as f:
code = Code.from_python_code(f.read(), self._settings)
return self.register_blueprint_contents(f, blueprint_id)

def register_blueprint_contents(
self,
contents: TextIOWrapper,
blueprint_id: BlueprintId | None = None,
) -> BlueprintId:
"""Register blueprint contents with an optional id, allowing contracts to be created from it."""
code = Code.from_python_code(contents.read(), self._settings)
verifier = OnChainBlueprintVerifier(settings=self._settings)
ocb = OnChainBlueprint(hash=b'', code=code)
verifier.verify_code(ocb)
Expand Down
71 changes: 71 additions & 0 deletions tests/nanocontracts/test_custom_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2025 Hathor Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from io import StringIO
from textwrap import dedent
from unittest.mock import ANY, Mock, call

from hathor.nanocontracts.custom_builtins import EXEC_BUILTINS
from tests.nanocontracts.blueprints.unittest import BlueprintTestCase


class TestCustomImport(BlueprintTestCase):
def test_custom_import(self) -> None:
"""Guarantee our custom import function is being called, instead of the builtin one."""
contract_id = self.gen_random_contract_id()
blueprint = '''
from hathor.nanocontracts import Blueprint
from hathor.nanocontracts.context import Context
from hathor.nanocontracts.types import public

class MyBlueprint(Blueprint):
@public
def initialize(self, ctx: Context) -> None:
from math import ceil, floor
from collections import OrderedDict
from hathor.nanocontracts.exception import NCFail
from hathor.nanocontracts.types import NCAction, NCActionType

__blueprint__ = MyBlueprint
'''

# Wrap our custom builtin so we can spy its calls
wrapped_import_function = Mock(wraps=EXEC_BUILTINS['__import__'])
EXEC_BUILTINS['__import__'] = wrapped_import_function

# Before being used, the function is uncalled
wrapped_import_function.assert_not_called()

# During blueprint registration, the function is called for each import at the module level.
# This happens twice, once during verification and once during the actual registration.
blueprint_id = self.register_blueprint_contents(StringIO(dedent(blueprint)))
module_level_calls = [
call('hathor.nanocontracts', ANY, ANY, ('Blueprint',), 0),
call('hathor.nanocontracts.context', ANY, ANY, ('Context',), 0),
call('hathor.nanocontracts.types', ANY, ANY, ('public',), 0),
]
assert wrapped_import_function.call_count == 2 * len(module_level_calls)
wrapped_import_function.assert_has_calls(2 * module_level_calls)
wrapped_import_function.reset_mock()

# During the call to initialize(), the function is called for each import on that method.
self.runner.create_contract(contract_id, blueprint_id, self.create_context())
method_level_imports = [
call('math', ANY, ANY, ('ceil', 'floor'), 0),
call('collections', ANY, ANY, ('OrderedDict',), 0),
call('hathor.nanocontracts.exception', ANY, ANY, ('NCFail',), 0),
call('hathor.nanocontracts.types', ANY, ANY, ('NCAction', 'NCActionType'), 0),
]
assert wrapped_import_function.call_count == len(method_level_imports)
wrapped_import_function.assert_has_calls(method_level_imports)
4 changes: 2 additions & 2 deletions tests/nanocontracts/test_exposed_properties.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Iterator
from importlib import import_module
from sys import version_info
from types import MethodType
from typing import Any
Expand Down Expand Up @@ -314,11 +313,12 @@ def check(self, ctx: Context) -> list[str]:
mutable_props.extend(search_writeable_properties(MyBlueprint, 'MyBlueprint'))
mutable_props.extend(search_writeable_properties(self, 'self'))
mutable_props.extend(search_writeable_properties(ctx, 'ctx'))
custom_import = EXEC_BUILTINS['__import__']
for module_name, import_names in ALLOWED_IMPORTS.items():
if module_name == 'typing':
# FIXME: typing module causes problems for some reason
continue
module = import_module(module_name)
module = custom_import(module_name, fromlist=list(import_names))
for import_name in import_names:
obj = getattr(module, import_name)
obj_name = f'{module_name}.{import_name}'
Expand Down