Skip to content

Commit

Permalink
[data transformer::01] Adapt transformer and use it in `deploy_contra…
Browse files Browse the repository at this point in the history
…ct` (#353)

* add starknet_py

* add data transformer facade

* connect data transformer without testing

* use data transformer in deploy contract

* lint

* update docs

* fix test
  • Loading branch information
kasperski95 authored Jun 10, 2022
1 parent aeb0a86 commit e9f701d
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 25 deletions.
40 changes: 39 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 23 additions & 3 deletions protostar/commands/test/test_execution_environment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
from collections.abc import Mapping
from copy import deepcopy
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Set
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Union

from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner
from starkware.starknet.public.abi import get_selector_from_name
Expand Down Expand Up @@ -33,7 +35,9 @@
SimpleReportedException,
StarknetRevertableException,
)
from protostar.utils.data_transformer_facade import DataTransformerFacade
from protostar.utils.modules import replace_class
from protostar.utils.starknet_compilation import StarknetCompiler

logger = getLogger()

Expand All @@ -57,33 +61,40 @@ def class_hash(self):


class TestExecutionEnvironment:
# pylint: disable=too-many-arguments
def __init__(
self,
include_paths: List[str],
forkable_starknet: ForkableStarknet,
test_contract: StarknetContract,
test_context: TestContext,
starknet_compiler: StarknetCompiler,
):
self.starknet = forkable_starknet
self.test_contract: StarknetContract = test_contract
self.test_context = test_context
self._expected_error: Optional[RevertableException] = None
self._include_paths = include_paths
self._test_finish_hooks: Set[Callable[[], None]] = set()
self._starknet_compiler = starknet_compiler

@classmethod
async def from_test_suite_definition(
cls,
starknet_compiler: StarknetCompiler,
test_suite_definition: ContractClass,
include_paths: Optional[List[str]] = None,
):
starknet = await ForkableStarknet.empty()

starknet_contract = await starknet.deploy(contract_class=test_suite_definition)

return cls(
include_paths or [],
forkable_starknet=starknet,
test_contract=await starknet.deploy(contract_class=test_suite_definition),
test_contract=starknet_contract,
test_context=TestContext(),
starknet_compiler=starknet_compiler,
)

def fork(self):
Expand All @@ -93,12 +104,14 @@ def fork(self):
forkable_starknet=starknet_fork,
test_contract=starknet_fork.copy_and_adapt_contract(self.test_contract),
test_context=deepcopy(self.test_context),
starknet_compiler=self._starknet_compiler,
)
return new_env

def deploy_in_env(
self, contract_path: str, constructor_calldata: Optional[List[int]] = None
):

contract = DeployedContract(
asyncio.run(
self.starknet.deploy(
Expand Down Expand Up @@ -310,8 +323,15 @@ def compare_expected_and_emitted_events():

@register_cheatcode
def deploy_contract(
contract_path: str, constructor_calldata: Optional[List[int]] = None
contract_path: str,
constructor_calldata: Optional[Union[List[int], Dict]] = None,
):
if isinstance(constructor_calldata, Mapping):
fn_name = "constructor"
constructor_calldata = DataTransformerFacade.from_contract_path(
Path(contract_path), self._starknet_compiler
).from_python(fn_name, **constructor_calldata)

return self.deploy_in_env(contract_path, constructor_calldata)

@register_cheatcode
Expand Down
12 changes: 7 additions & 5 deletions protostar/commands/test/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def __init__(
if include_paths:
self.include_paths.extend(include_paths)

self.starknet_compiler = StarknetCompiler(
include_paths=self.include_paths,
disable_hint_validation=True,
)

@dataclass
class WorkerArgs:
test_suite: TestSuite
Expand All @@ -56,10 +61,7 @@ async def run_test_suite(self, test_suite: TestSuite):
self.include_paths is not None
), "Uninitialized paths list in test runner"

compiled_test = StarknetCompiler(
include_paths=self.include_paths,
disable_hint_validation=True,
).compile_preprocessed_contract(
compiled_test = self.starknet_compiler.compile_preprocessed_contract(
test_suite.preprocessed_contract, add_debug_info=True
)

Expand Down Expand Up @@ -95,7 +97,7 @@ async def _run_test_suite(

try:
env_base = await TestExecutionEnvironment.from_test_suite_definition(
test_contract, self.include_paths
self.starknet_compiler, test_contract, self.include_paths
)

if test_suite.setup_fn_name:
Expand Down
1 change: 1 addition & 0 deletions protostar/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from protostar.utils.config import Project
from protostar.utils.create_and_commit_sample_file import create_and_commit_sample_file
from protostar.utils.data_transformer_facade import DataTransformerFacade
from protostar.utils.log_color_provider import log_color_provider
from protostar.utils.package_info import (
extract_info_from_repo_id,
Expand Down
43 changes: 43 additions & 0 deletions protostar/utils/data_transformer_facade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from pathlib import Path
from typing import List

from starknet_py.utils.data_transformer.data_transformer import (
ABIFunctionEntry,
DataTransformer,
)
from starkware.starknet.public.abi import AbiType
from starkware.starknet.public.abi_structs import identifier_manager_from_abi

from protostar.utils.starknet_compilation import StarknetCompiler


class FunctionNotFoundException(BaseException):
pass


class DataTransformerFacade:
@classmethod
def from_contract_path(
cls, path: Path, starknet_compiler: StarknetCompiler
) -> "DataTransformerFacade":
preprocessed = starknet_compiler.preprocess_contract(path)
return cls(preprocessed.abi)

def __init__(self, contract_abi: AbiType) -> None:
self._contract_abi = contract_abi
self._identifier_manager = identifier_manager_from_abi(contract_abi)

def _get_function_abi(self, fn_name: str) -> ABIFunctionEntry:
for item in self._contract_abi:
if (item["type"] == "function" or item["type"] == "constructor") and item[
"name"
] == fn_name:
return item
raise FunctionNotFoundException(f"Couldn't find a function '{fn_name}'")

def from_python(self, fn_name: str, *args, **kwargs) -> List[int]:
data_transformer = DataTransformer(
self._get_function_abi(fn_name),
identifier_manager_from_abi(self._contract_abi),
)
return data_transformer.from_python(*args, **kwargs)[0]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ tomli = "<2.0.0"
tomli-w = "^1.0.0"
tqdm = "^4.64.0"
typing-extensions = "^4.0.1"
"starknet.py" = "^0.3.0-alpha.0"

[tool.poetry.dev-dependencies]
black = "^22.1.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,9 @@ func get_balance{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_p
let (res) = balance.read()
return (res)
end

@constructor
func constructor{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}():
balance.write(0)
return ()
end
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
%lang starknet
from starkware.cairo.common.uint256 import Uint256, uint256_add
from starkware.starknet.common.syscalls import get_contract_address
from starkware.cairo.common.uint256 import Uint256

# Check if importing from root directory is possible

Expand All @@ -17,9 +18,6 @@ end
@contract_interface
namespace BasicWithConstructor:
func increase_balance(amount : Uint256):
end
func get_balance() -> (res : Uint256):
end
Expand Down Expand Up @@ -69,17 +67,45 @@ func test_missing_logic_contract{syscall_ptr : felt*, range_check_ptr}():
end

@external
func test_deploy_contract_with_args_in_constructor{syscall_ptr : felt*, range_check_ptr}():
func test_passing_constructor_data_as_list{syscall_ptr : felt*, range_check_ptr}():
alloc_locals
local deployed_contract_address : felt
let (contract_address) = get_contract_address()

%{
ids.deployed_contract_address = deploy_contract("./tests/integration/cheatcodes/deploy_contract/basic_with_constructor.cairo",
[42, 0, ids.contract_address]
).contract_address
%}

let (balance) = BasicWithConstructor.get_balance(deployed_contract_address)
let (id) = BasicWithConstructor.get_id(deployed_contract_address)

assert balance.low = 42
assert balance.high = 0
assert id = contract_address

return ()
end

@external
func test_data_transformation{syscall_ptr : felt*, range_check_ptr}():
alloc_locals
local deployed_contract_address : felt
let (contract_address) = get_contract_address()

%{
ids.deployed_contract_address = deploy_contract("./tests/integration/cheatcodes/deploy_contract/basic_with_constructor.cairo",
{ "initial_balance": 42, "contract_id": ids.contract_address }
).contract_address
%}

local contract_a_address : felt
%{ ids.contract_a_address = deploy_contract("./tests/integration/cheatcodes/deploy_contract/basic_with_constructor.cairo", [100, 0, 1]).contract_address %}
let (balance) = BasicWithConstructor.get_balance(deployed_contract_address)
let (id) = BasicWithConstructor.get_id(deployed_contract_address)

let (res) = BasicWithConstructor.get_balance(contract_address=contract_a_address)
assert res.low = 100
assert res.high = 0
assert balance.low = 42
assert balance.high = 0
assert id = contract_address

let (id) = BasicWithConstructor.get_id(contract_address=contract_a_address)
assert id = 1
return ()
end
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,40 @@
from tests.integration.conftest import assert_cairo_test_cases


@pytest.fixture(name="target")
def target_fixture() -> str:
return f"{Path(__file__).parent}/deploy_contract_test.cairo"


@pytest.mark.asyncio
async def test_deploy_contract(mocker):
async def test_deploy_contract(mocker, target: str):
testing_summary = await TestCommand(
project=mocker.MagicMock(),
protostar_directory=mocker.MagicMock(),
).test(targets=[f"{Path(__file__).parent}/deploy_contract_test.cairo"])
).test(targets=[target], ignored_targets=[f"{target}::test_data_transformation"])

assert_cairo_test_cases(
testing_summary,
expected_passed_test_cases_names=[
"test_proxy_contract",
"test_missing_logic_contract",
"test_deploy_contract_with_args_in_constructor",
"test_passing_constructor_data_as_list",
],
expected_failed_test_cases_names=[],
)


@pytest.mark.asyncio
async def test_data_transformation(mocker, target):
testing_summary = await TestCommand(
project=mocker.MagicMock(),
protostar_directory=mocker.MagicMock(),
).test(targets=[f"{target}::test_data_transformation"])

assert_cairo_test_cases(
testing_summary,
expected_passed_test_cases_names=[
"test_data_transformation",
],
expected_failed_test_cases_names=[],
)
Loading

0 comments on commit e9f701d

Please sign in to comment.