Skip to content

Commit

Permalink
Fix Protostar freezing (#311)
Browse files Browse the repository at this point in the history
* fix freezing

* handle cheatcode errors

* test freezing
  • Loading branch information
kasperski95 authored May 27, 2022
1 parent c793d4f commit 0e58dfa
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 40 deletions.
2 changes: 1 addition & 1 deletion protostar/commands/test/starkware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from protostar.commands.test.starkware.cheatable_syscall_handler import (
CheatableHintsWhitelist,
CheatableSysCallHandler,
CheatcodeException,
CheatableSysCallHandlerException,
)
from protostar.commands.test.starkware.forkable_starknet import ForkableStarknet
14 changes: 8 additions & 6 deletions protostar/commands/test/starkware/cheatable_syscall_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
SelectorType = int


class CheatcodeException(BaseException):
pass
class CheatableSysCallHandlerException(BaseException):
def __init__(self, message: str):
self.message = message
super().__init__(message)


class CheatableSysCallHandler(BusinessLogicSysCallHandler):
Expand Down Expand Up @@ -52,7 +54,7 @@ def set_caller_address(
else self.contract_address
)
if target in self.state.pranked_contracts_map:
raise CheatcodeException(
raise CheatableSysCallHandlerException(
f"Contract with address {target} has been already pranked"
)
self.state.pranked_contracts_map[target] = addr
Expand All @@ -64,7 +66,7 @@ def reset_caller_address(self, target_contract_address: Optional[int] = None):
else self.contract_address
)
if target not in self.state.pranked_contracts_map:
raise CheatcodeException(
raise CheatableSysCallHandlerException(
f"Contract with address {target} has not been pranked"
)
del self.state.pranked_contracts_map[target]
Expand Down Expand Up @@ -95,11 +97,11 @@ def register_mock_call(

def unregister_mock_call(self, contract_address: AddressType, selector: int):
if contract_address not in self.mocked_calls:
raise CheatcodeException(
raise CheatableSysCallHandlerException(
f"Contract {contract_address} doesn't have mocked selectors."
)
if selector not in self.mocked_calls[contract_address]:
raise CheatcodeException(
raise CheatableSysCallHandlerException(
f"Couldn't find mocked selector {selector} for an address {contract_address}."
)
del self.mocked_calls[contract_address][selector]
Expand Down
28 changes: 22 additions & 6 deletions protostar/commands/test/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from attr import dataclass
from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo
from starkware.starkware_utils.error_handling import StarkException

from protostar.commands.test.test_environment_exceptions import ReportedException
from protostar.protostar_exception import UNEXPECTED_PROTOSTAR_ERROR_MSG
from protostar.utils.log_color_provider import log_color_provider


Expand Down Expand Up @@ -47,10 +47,26 @@ def __str__(self) -> str:
@dataclass(frozen=True)
class BrokenTestSuite(TestCaseResult):
test_case_names: List[str]
exception: StarkException
exception: BaseException

def __str__(self) -> str:
result: List[str] = []
result.append(f"[{log_color_provider.colorize('RED', 'BROKEN')}]")
result.append(f"{self.get_formatted_file_path()}")
return " ".join(result)
first_line: List[str] = []
first_line.append(f"[{log_color_provider.colorize('RED', 'BROKEN')}]")
first_line.append(f"{self.get_formatted_file_path()}")
result = [" ".join(first_line)]
result.append(str(self.exception))
return "\n".join(result)


@dataclass(frozen=True)
class UnexpectedExceptionTestSuiteResult(BrokenTestSuite):
def __str__(self) -> str:
first_line: List[str] = []
first_line.append(
f"[{log_color_provider.colorize('RED', 'UNEXPECTED_EXCEPTION')}]"
)
first_line.append(self.get_formatted_file_path())

result = [" ".join(first_line), UNEXPECTED_PROTOSTAR_ERROR_MSG]
result.append(str(self.exception))
return "\n".join(result)
16 changes: 16 additions & 0 deletions protostar/commands/test/test_environment_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@ def __str__(self) -> str:
return str(super().__repr__())


class CheatcodeException(ReportedException):
def __init__(self, cheatcode_name: str, message: str):
self.cheatcode_name = cheatcode_name
self.message = message
super().__init__(message)

def __str__(self):
lines: List[str] = []
lines.append(f"Incorrect usage of `{self.cheatcode_name}` cheatcode")
lines.append(self.message)
return "\n".join(lines)

def __reduce__(self):
return type(self), (self.cheatcode_name, self.message)


class RevertableException(ReportedException):
"""
This exception is used by `except_revert` logic.
Expand Down
36 changes: 26 additions & 10 deletions protostar/commands/test/test_execution_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from protostar.commands.test.expected_event import ExpectedEvent
from protostar.commands.test.starkware.cheatable_syscall_handler import (
CheatableSysCallHandler,
CheatableSysCallHandlerException,
)
from protostar.commands.test.starkware.forkable_starknet import ForkableStarknet
from protostar.commands.test.test_environment_exceptions import (
CheatcodeException,
ExpectedRevertException,
ExpectedRevertMismatchException,
ReportedException,
Expand Down Expand Up @@ -178,14 +180,20 @@ def warp(blk_timestamp: int):
def start_prank(
caller_address: int, target_contract_address: Optional[int] = None
):
cheatable_syscall_handler.set_caller_address(
caller_address, target_contract_address=target_contract_address
)
try:
cheatable_syscall_handler.set_caller_address(
caller_address, target_contract_address=target_contract_address
)
except CheatableSysCallHandlerException as err:
raise CheatcodeException("start_prank", err.message) from err

def stop_started_prank():
cheatable_syscall_handler.reset_caller_address(
target_contract_address=target_contract_address
)
try:
cheatable_syscall_handler.reset_caller_address(
target_contract_address=target_contract_address
)
except CheatableSysCallHandlerException as err:
raise CheatcodeException("start_prank", err.message) from err

return stop_started_prank

Expand All @@ -194,9 +202,12 @@ def stop_prank(target_contract_address: Optional[int] = None):
logger.warning(
"Using stop_prank() is deprecated, instead use a function returned by start_prank()"
)
cheatable_syscall_handler.reset_caller_address(
target_contract_address=target_contract_address
)
try:
cheatable_syscall_handler.reset_caller_address(
target_contract_address=target_contract_address
)
except CheatableSysCallHandlerException as err:
raise CheatcodeException("stop_prank", err.message) from err

@register_cheatcode
def mock_call(contract_address: int, fn_name: str, ret_data: List[int]):
Expand All @@ -208,7 +219,12 @@ def mock_call(contract_address: int, fn_name: str, ret_data: List[int]):
@register_cheatcode
def clear_mock_call(contract_address: int, fn_name: str):
selector = get_selector_from_name(fn_name)
cheatable_syscall_handler.unregister_mock_call(contract_address, selector)
try:
cheatable_syscall_handler.unregister_mock_call(
contract_address, selector
)
except CheatableSysCallHandlerException as err:
raise CheatcodeException("stop_prank", err.message) from err

@register_cheatcode
def expect_events(
Expand Down
37 changes: 26 additions & 11 deletions protostar/commands/test/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BrokenTestSuite,
FailedTestCase,
PassedTestCase,
UnexpectedExceptionTestSuiteResult,
)
from protostar.commands.test.test_environment_exceptions import ReportedException
from protostar.commands.test.test_execution_environment import TestExecutionEnvironment
Expand Down Expand Up @@ -50,19 +51,33 @@ def worker(cls, args: "TestRunner.WorkerArgs"):
)

async def run_test_suite(self, test_suite: TestSuite):
assert self.include_paths is not None, "Uninitialized paths list in test runner"
try:
assert (
self.include_paths is not None
), "Uninitialized paths list in test runner"

compiled_test = StarknetCompiler(
include_paths=self.include_paths,
disable_hint_validation=True,
).compile_preprocessed_contract(
test_suite.preprocessed_contract, add_debug_info=True
)

compiled_test = StarknetCompiler(
include_paths=self.include_paths,
disable_hint_validation=True,
).compile_preprocessed_contract(
test_suite.preprocessed_contract, add_debug_info=True
)
await self._run_test_suite(
test_contract=compiled_test,
test_suite=test_suite,
)

await self._run_test_suite(
test_contract=compiled_test,
test_suite=test_suite,
)
# An unexpected exception in a worker should crash nor freeze the whole application
# pylint: disable=broad-except
except BaseException as ex:
self.queue.put(
UnexpectedExceptionTestSuiteResult(
file_path=test_suite.test_path,
test_case_names=test_suite.test_case_names,
exception=ex,
)
)

async def _run_test_suite(
self,
Expand Down
12 changes: 10 additions & 2 deletions protostar/protostar_exception.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
UNEXPECTED_PROTOSTAR_ERROR_MSG = (
"Unexpected Protostar error. Report it here:\n"
"https://github.com/software-mansion/protostar/issues\n"
)


class ProtostarException(Exception):
"This exception is nicely printed by protostar and results in non-zero exit code"
"""This exception is nicely printed by protostar and results in non-zero exit code"""

# Disabling pylint to narrow down types
# pylint: disable=useless-super-delegation
def __init__(self, message: str):
Expand All @@ -8,5 +15,6 @@ def __init__(self, message: str):


class ProtostarExceptionSilent(ProtostarException):
"This exception isn't printed but results in non-zero exit code"
"""This exception isn't printed but results in non-zero exit code"""

...
5 changes: 2 additions & 3 deletions protostar/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from protostar.cli import ArgumentParserFacade, ArgumentValueFromConfigProvider
from protostar.cli.cli_app import CLIApp
from protostar.protostar_cli import ConfigurationProfileCLISchema, ProtostarCLI
from protostar.protostar_exception import UNEXPECTED_PROTOSTAR_ERROR_MSG


def main(script_root: Path):
Expand All @@ -27,7 +28,5 @@ def main(script_root: Path):
except CLIApp.CommandNotFoundError:
parser.print_help()
except Exception as err:
print(
"Unexpected Protostar error. Report it here:\nhttps://github.com/software-mansion/protostar/issues\n"
)
print(UNEXPECTED_PROTOSTAR_ERROR_MSG)
raise err
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ func test_clearing_mocks{syscall_ptr : felt*, range_check_ptr}():
return ()
end

@view
func test_cannot_freeze_when_cheatcode_exception_is_raised{syscall_ptr : felt*, range_check_ptr}():
tempvar external_contract_address = EXTERNAL_CONTRACT_ADDRESS
%{ clear_mock_call(ids.external_contract_address, "get_felt") %}

return ()
end

# deploy_contract
@contract_interface
namespace BasicContract:
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/cheatcodes/other/test_other_cheatcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ async def test_other_cheatcodes(mocker):
"test_call_not_existing_contract",
"test_call_not_existing_contract_specific_error",
],
expected_failed_test_cases_names=[],
expected_failed_test_cases_names=[
"test_cannot_freeze_when_cheatcode_exception_is_raised"
],
)

0 comments on commit 0e58dfa

Please sign in to comment.