Skip to content

Commit

Permalink
Add message to l2 cheatcode (#940)
Browse files Browse the repository at this point in the history
Docs in a separate PR
  • Loading branch information
Arcticae authored Oct 7, 2022
1 parent d4d1991 commit c05ad44
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 3 deletions.
85 changes: 85 additions & 0 deletions protostar/testing/cheatcodes/send_message_to_l2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Callable, Any, Mapping, Optional, List

from starkware.starknet.business_logic.execution.execute_entry_point import (
ExecuteEntryPoint,
)
from starkware.starknet.business_logic.execution.objects import CallType
from starkware.starknet.public.abi import get_selector_from_name, AbiType
from starkware.starknet.services.api.contract_class import EntryPointType

from protostar.starknet import Cheatcode, CheatcodeException
from protostar.starknet.data_transformer import (
CairoOrPythonData,
from_python_transformer,
)


def get_calldata_for_execution(
payload: CairoOrPythonData,
l1_sender_address: int,
abi: AbiType,
fn_name: str,
) -> List[int]:
if isinstance(payload, Mapping):
transformer = from_python_transformer(abi, fn_name, "inputs")
return transformer(
{
**payload,
"from_address": l1_sender_address,
}
)
return [l1_sender_address, *payload]


def get_contract_l1_handlers_names(abi: AbiType) -> List[str]:
return [fn["name"] for fn in abi if fn["type"] == "l1_handler"]


class SendMessageToL2Cheatcode(Cheatcode):
@property
def name(self) -> str:
return "send_message_to_l2"

def build(self) -> Callable[..., Any]:
return self.send_message_to_l2

def send_message_to_l2(
self,
fn_name: str,
from_address: int = 0,
to_address: Optional[int] = None,
payload: Optional[CairoOrPythonData] = None,
) -> None:
to_address = to_address if to_address else self.contract_address

class_hash = self.state.get_class_hash_at(to_address)
contract_class = self.state.get_contract_class(class_hash)

if not contract_class.abi:
raise CheatcodeException(
self,
"Contract (address: {hex(contract_address)}) doesn't have any entrypoints",
)

if fn_name not in get_contract_l1_handlers_names(contract_class.abi):
raise CheatcodeException(
self,
f"L1 handler {fn_name} was not found in contract (address: {hex(to_address)}) ABI",
)

calldata = get_calldata_for_execution(
payload or [], from_address, contract_class.abi, fn_name
)

self.execute_entry_point(
ExecuteEntryPoint.create(
contract_address=to_address,
calldata=calldata,
entry_point_selector=get_selector_from_name(fn_name),
# FIXME(arcticae): This might be wrong, since the caller might be some starknet OS specific address
caller_address=from_address,
entry_point_type=EntryPointType.L1_HANDLER,
call_type=CallType.DELEGATE,
class_hash=class_hash,
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
WarpCheatcode,
)
from protostar.testing.cheatcodes.reflect.cairo_struct import CairoStructHintLocal
from protostar.testing.cheatcodes.send_message_to_l2 import SendMessageToL2Cheatcode
from protostar.testing.starkware.test_execution_state import TestExecutionState
from protostar.testing.test_context import TestContextHintLocal

Expand Down Expand Up @@ -54,6 +55,7 @@ def build_cheatcodes(
StoreCheatcode(syscall_dependencies),
LoadCheatcode(syscall_dependencies),
ReflectCheatcode(syscall_dependencies),
SendMessageToL2Cheatcode(syscall_dependencies),
]

def build_hint_locals(self) -> List[HintLocal]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
%lang starknet
from starkware.cairo.common.cairo_builtins import HashBuiltin


@storage_var
func state() -> (res: felt) {
}

@l1_handler
func existing_handler{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(from_address: felt, value: felt){
state.write(value);
return ();
}

@view
func get_state{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}() -> (res: felt) {
let (res) = state.read();
return (res=res);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio
from pathlib import Path

from tests.integration.conftest import (
RunCairoTestRunnerFixture,
assert_cairo_test_cases,
CreateProtostarProjectFixture,
)


def test_l1_to_l2_message_cheatcode(
run_cairo_test_runner: RunCairoTestRunnerFixture,
create_protostar_project: CreateProtostarProjectFixture,
):
with create_protostar_project() as protostar:
contracts_sources_path = Path(__file__).parent

protostar.create_files(
{
"tests/test_main.cairo": contracts_sources_path / "simple_l1_handler_test.cairo",
"src/main.cairo": contracts_sources_path / "external_contract_with_l1_handler.cairo",
}
)
protostar.build_sync()

testing_summary = asyncio.run(run_cairo_test_runner(Path(".")))

assert_cairo_test_cases(
testing_summary,
expected_passed_test_cases_names=[
"test_existing_self_l1_handle_call",
"test_existing_self_l1_handle_call_w_transformer",
"test_existing_self_l1_handle_call_no_calldata",
"test_existing_self_l1_handle_call_custom_l1_sender_address",
"test_existing_external_contract_l1_handle_call",
],
expected_broken_test_cases_names=["test_non_existing_self_l1_handle_call"],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
%lang starknet
from starkware.cairo.common.cairo_builtins import HashBuiltin


@storage_var
func state() -> (res: felt) {
}

@l1_handler
func existing_handler{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(from_address: felt, value: felt){
state.write(value);
return ();
}

const ALLOWED_L1_SENDER_ADDRESS = 123;

@l1_handler
func existing_handler_verifying_sender_address{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(from_address: felt, value: felt){
assert from_address = ALLOWED_L1_SENDER_ADDRESS;
state.write(value);
return ();
}

const PREDEFINED_VALUE = 'somevalue';

@l1_handler
func existing_handler_no_calldata{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(from_address: felt){
state.write(PREDEFINED_VALUE);
return ();
}


@external
func test_existing_self_l1_handle_call{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(){
let STATE_AFTER = 'self_l1_handle_call';

%{ send_message_to_l2("existing_handler", payload=[ids.STATE_AFTER]) %}

let (state_after_l1_msg) = state.read();

assert state_after_l1_msg = STATE_AFTER;
return ();
}

@external
func test_existing_self_l1_handle_call_no_calldata{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(){
%{ send_message_to_l2("existing_handler_no_calldata") %}

let (state_after_l1_msg) = state.read();

assert state_after_l1_msg = PREDEFINED_VALUE;
return ();
}

@external
func test_existing_self_l1_handle_call_w_transformer{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(){
let STATE_AFTER = 'self_l1_handle_call';

%{ send_message_to_l2("existing_handler", payload={"value": ids.STATE_AFTER}) %}

let (state_after_l1_msg) = state.read();

assert state_after_l1_msg = STATE_AFTER;
return ();
}

@external
func test_non_existing_self_l1_handle_call{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(){
let STATE_AFTER = 'self_l1_handle_call';

%{ send_message_to_l2("non_existing_handler", payload={"value": ids.STATE_AFTER}) %}
return ();
}

@external
func test_existing_self_l1_handle_call_custom_l1_sender_address{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(){
let STATE_AFTER = 'self_l1_handle_call';

%{
send_message_to_l2(
fn_name="existing_handler_verifying_sender_address",
payload={"value": ids.STATE_AFTER},
from_address=ids.ALLOWED_L1_SENDER_ADDRESS,
)
%}

let (state_after_l1_msg) = state.read();
assert state_after_l1_msg = STATE_AFTER;
return ();
}



@contract_interface
namespace ExternalContractInterface {
func get_state() -> (res: felt) {
}
}

@external
func test_existing_external_contract_l1_handle_call{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
}(){
alloc_locals;
local external_contract_address: felt;
let secret_value = 's3cr3t';
%{ ids.external_contract_address = deploy_contract("src/main.cairo").contract_address %}
%{
send_message_to_l2(
fn_name="existing_handler",
from_address=123,
payload=[ids.secret_value],
to_address=ids.external_contract_address,
)
%}

let (state) = ExternalContractInterface.get_state(contract_address=external_contract_address);

assert state = secret_value;
return ();
}
10 changes: 7 additions & 3 deletions tests/integration/protostar_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from argparse import Namespace
from logging import getLogger
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, cast, Union

from pytest_mock import MockerFixture
from starknet_py.net import KeyPair
Expand Down Expand Up @@ -208,8 +208,12 @@ def format_with_output(

return summary, output

def create_files(self, relative_path_str_to_content: Dict[str, str]) -> None:
for relative_path_str, content in relative_path_str_to_content.items():
def create_files(self, relative_path_str_to_file: Dict[str, Union[str, Path]]) -> None:
for relative_path_str, file in relative_path_str_to_file.items():
if isinstance(file, Path):
content = file.read_text("utf-8")
else:
content = file
self._save_file(self._project_root_path / relative_path_str, content)

def create_migration_file(
Expand Down

0 comments on commit c05ad44

Please sign in to comment.