Skip to content

Commit

Permalink
[setup_state::01] Detect and try executing hook (#307)
Browse files Browse the repository at this point in the history
* add integration tests

* find setup_state function

* throw an exception if setup state is defined

* use state dict in tests
  • Loading branch information
kasperski95 authored May 27, 2022
1 parent 216514f commit c793d4f
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 6 deletions.
14 changes: 13 additions & 1 deletion protostar/commands/test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,12 @@ def _build_test_suite(
for test_case_name in test_case_names
if test_case_name == target_test_case_name
]

return TestSuite(
test_path=file_path,
test_case_names=test_case_names,
preprocessed_contract=preprocessed,
setup_state_fn_name=self._find_setup_state_hook_name(preprocessed),
)

def _get_test_suite_paths(self, target: Path) -> Generator[Path, None, None]:
Expand All @@ -131,7 +133,17 @@ def _get_test_suite_paths(self, target: Path) -> Generator[Path, None, None]:
def _collect_test_case_names(
self, preprocessed: StarknetPreprocessedProgram
) -> List[str]:
return self._starknet_compiler.get_function_names(preprocessed, prefix="test_")
return self._starknet_compiler.get_function_names(
preprocessed, predicate=lambda fn_name: fn_name.startswith("test_")
)

def _find_setup_state_hook_name(
self, preprocessed: StarknetPreprocessedProgram
) -> Optional[str]:
function_names = self._starknet_compiler.get_function_names(
preprocessed, predicate=lambda fn_name: fn_name == "setup_state"
)
return function_names[0] if len(function_names) > 0 else None

def _preprocess_contract(self, file_path: Path) -> StarknetPreprocessedProgram:
try:
Expand Down
21 changes: 20 additions & 1 deletion protostar/commands/test/test_collector_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from pathlib import Path
from typing import List, cast
from typing import Callable, List, cast
from unittest.mock import MagicMock

import pytest
Expand All @@ -14,6 +14,7 @@
TestCollector,
)
from protostar.commands.test.test_suite import TestSuite
from protostar.utils.starknet_compilation import StarknetCompiler


@pytest.fixture(name="project_root")
Expand Down Expand Up @@ -120,6 +121,24 @@ def test_collector_preprocess_contracts(
assert suite.preprocessed_contract == preprocessed_contract


def test_finding_setup_state_function(
starknet_compiler: StarknetCompiler, project_root: Path
):
def get_function_names(_, predicate: Callable[[str], bool]) -> List[str]:
return list(filter(predicate, ["test_main", "setup_state"]))

cast(
MagicMock, starknet_compiler.get_function_names
).side_effect = get_function_names
test_collector = TestCollector(starknet_compiler)

[suite] = test_collector.collect(
project_root / "foo" / "test_foo.cairo"
).test_suites

assert suite.setup_state_fn_name == "setup_state"


def test_logging_collected_one_test_suite_and_one_test_case(mocker: MockerFixture):
logger_mock = mocker.MagicMock()

Expand Down
4 changes: 4 additions & 0 deletions protostar/commands/test/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ async def _run_test_suite(
env_base = await TestExecutionEnvironment.empty(
test_contract, self.include_paths
)

if test_suite.setup_state_fn_name:
raise NotImplementedError()

except StarkException as err:
self.queue.put(
BrokenTestSuite(
Expand Down
4 changes: 3 additions & 1 deletion protostar/commands/test/test_suite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List
from typing import List, Optional

from starkware.starknet.compiler.starknet_preprocessor import (
StarknetPreprocessedProgram,
)
Expand All @@ -11,3 +12,4 @@ class TestSuite:
test_path: Path
preprocessed_contract: StarknetPreprocessedProgram
test_case_names: List[str]
setup_state_fn_name: Optional[str] = None
6 changes: 3 additions & 3 deletions protostar/utils/starknet_compilation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List
from typing import Callable, List

from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME
from starkware.cairo.lang.compiler.cairo_compile import get_module_reader
Expand Down Expand Up @@ -82,10 +82,10 @@ def compile_contract(

@staticmethod
def get_function_names(
preprocessed: StarknetPreprocessedProgram, prefix: str
preprocessed: StarknetPreprocessedProgram, predicate: Callable[[str], bool]
) -> List[str]:
return [
el["name"]
for el in preprocessed.abi
if el["type"] == "function" and el["name"].startswith(prefix)
if el["type"] == "function" and predicate(el["name"])
]
17 changes: 17 additions & 0 deletions tests/integration/testing_hooks/testing_hooks_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
%lang starknet
from starkware.cairo.common.cairo_builtins import HashBuiltin

@view
func setup_state():
%{ state["contract"] = deploy_contract("./src/main.cairo") %}
return ()
end

@view
func test_contract_was_deployed_in_setup_state():
tempvar contract_address
%{ assert state["contract"].contract_address is not None %}

return ()
end
21 changes: 21 additions & 0 deletions tests/integration/testing_hooks/testing_hooks_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pathlib import Path

import pytest

from protostar.commands.test.test_command import TestCommand
from tests.integration.conftest import assert_cairo_test_cases


@pytest.mark.skip
@pytest.mark.asyncio
async def test_testing_hooks(mocker):
testing_summary = await TestCommand(
project=mocker.MagicMock(),
protostar_directory=mocker.MagicMock(),
).test(target=Path(__file__).parent / "testing_hooks_test.cairo")

assert_cairo_test_cases(
testing_summary,
expected_passed_test_cases_names=["test_contract_was_deployed_in_setup_state"],
expected_failed_test_cases_names=[],
)

0 comments on commit c793d4f

Please sign in to comment.