Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHIA-711] Add WalletActionScope #18125

Merged
merged 14 commits into from
Jun 24, 2024
80 changes: 80 additions & 0 deletions chia/_tests/wallet/test_wallet_action_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Tuple

import pytest
from chia_rs import G2Element

from chia._tests.cmds.wallet.test_consts import STD_TX
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.spend_bundle import SpendBundle
from chia.wallet.signer_protocol import SigningResponse
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.wallet_action_scope import WalletSideEffects
from chia.wallet.wallet_state_manager import WalletStateManager

MOCK_SR = SigningResponse(b"hey", bytes32([0] * 32))
MOCK_SB = SpendBundle([], G2Element())


def test_back_and_forth_serialization() -> None:
assert bytes(WalletSideEffects()) == b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
assert WalletSideEffects.from_bytes(bytes(WalletSideEffects())) == WalletSideEffects()
assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX], [MOCK_SR], [MOCK_SB]))) == WalletSideEffects(
[STD_TX], [MOCK_SR], [MOCK_SB]
)
assert WalletSideEffects.from_bytes(
bytes(WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR], [MOCK_SB, MOCK_SB]))
) == WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR], [MOCK_SB, MOCK_SB])


@dataclass
class MockWalletStateManager:
most_recent_call: Optional[
Tuple[List[TransactionRecord], bool, bool, bool, List[SigningResponse], List[SpendBundle]]
] = None

async def add_pending_transactions(
self,
txs: List[TransactionRecord],
push: bool,
merge_spends: bool,
sign: bool,
additional_signing_responses: List[SigningResponse],
extra_spends: List[SpendBundle],
) -> List[TransactionRecord]:
self.most_recent_call = (txs, push, merge_spends, sign, additional_signing_responses, extra_spends)
return txs


MockWalletStateManager.new_action_scope = WalletStateManager.new_action_scope # type: ignore[attr-defined]


@pytest.mark.anyio
async def test_wallet_action_scope() -> None:
wsm = MockWalletStateManager()
async with wsm.new_action_scope( # type: ignore[attr-defined]
push=True,
merge_spends=False,
sign=True,
additional_signing_responses=[],
extra_spends=[],
) as action_scope:
async with action_scope.use() as interface:
interface.side_effects.transactions = [STD_TX]

with pytest.raises(RuntimeError):
action_scope.side_effects

assert action_scope.side_effects.transactions == [STD_TX]
assert wsm.most_recent_call == ([STD_TX], True, False, True, [], [])

async with wsm.new_action_scope( # type: ignore[attr-defined]
push=False, merge_spends=True, sign=True, additional_signing_responses=[]
) as action_scope:
async with action_scope.use() as interface:
interface.side_effects.transactions = []

assert action_scope.side_effects.transactions == []
assert wsm.most_recent_call == ([], False, True, True, [], [])
111 changes: 110 additions & 1 deletion chia/_tests/wallet/test_wallet_state_manager.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from __future__ import annotations

from contextlib import asynccontextmanager
from typing import AsyncIterator
from typing import AsyncIterator, List

import pytest
from chia_rs import G2Element

from chia._tests.environments.wallet import WalletTestFramework
from chia._tests.util.setup_nodes import OldSimulatorsAndWallets
from chia.protocols.wallet_protocol import CoinState
from chia.server.outbound_message import NodeType
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_spend import make_spend
from chia.types.peer_info import PeerInfo
from chia.types.spend_bundle import SpendBundle
from chia.util.ints import uint32, uint64
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.derive_keys import master_sk_to_wallet_sk, master_sk_to_wallet_sk_unhardened
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.transaction_type import TransactionType
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_state_manager import WalletStateManager

Expand Down Expand Up @@ -95,3 +102,105 @@ async def test_determine_coin_type(simulator_and_wallet: OldSimulatorsAndWallets
assert (None, None) == await wallet_state_manager.determine_coin_type(
peer, CoinState(Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), uint64(0)), uint32(0), uint32(0)), None
)


@pytest.mark.parametrize(
"wallet_environments",
[{"num_environments": 1, "blocks_needed": [1], "trusted": True, "reuse_puzhash": True}],
indirect=True,
)
@pytest.mark.limit_consensus_modes(reason="irrelevant")
@pytest.mark.anyio
async def test_commit_transactions_to_db(wallet_environments: WalletTestFramework) -> None:
env = wallet_environments.environments[0]
wsm = env.wallet_state_manager

coins = list(
await wsm.main_wallet.select_coins(
uint64(2_000_000_000_000), coin_selection_config=wallet_environments.tx_config.coin_selection_config
)
)
[tx1] = await wsm.main_wallet.generate_signed_transaction(
uint64(0),
bytes32([0] * 32),
wallet_environments.tx_config,
coins={coins[0]},
)
[tx2] = await wsm.main_wallet.generate_signed_transaction(
uint64(0),
bytes32([0] * 32),
wallet_environments.tx_config,
coins={coins[1]},
)

def flatten_spend_bundles(txs: List[TransactionRecord]) -> List[SpendBundle]:
return [tx.spend_bundle for tx in txs if tx.spend_bundle is not None]

assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 0
)
new_txs = await wsm.add_pending_transactions(
[tx1, tx2],
push=False,
merge_spends=False,
sign=False,
extra_spends=[],
)
bundles = flatten_spend_bundles(new_txs)
assert len(bundles) == 2
for bundle in bundles:
assert bundle.aggregated_signature == G2Element()
assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 0
)

extra_coin_spend = make_spend(
Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), uint64(0)), Program.to(1), Program.to([None])
)
extra_spend = SpendBundle([extra_coin_spend], G2Element())

new_txs = await wsm.add_pending_transactions(
[tx1, tx2],
push=False,
merge_spends=False,
sign=False,
extra_spends=[extra_spend],
)
bundles = flatten_spend_bundles(new_txs)
assert len(bundles) == 2
for bundle in bundles:
assert bundle.aggregated_signature == G2Element()
assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 0
)
assert extra_coin_spend in [spend for bundle in bundles for spend in bundle.coin_spends]

new_txs = await wsm.add_pending_transactions(
[tx1, tx2],
push=False,
merge_spends=True,
sign=False,
extra_spends=[extra_spend],
)
bundles = flatten_spend_bundles(new_txs)
assert len(bundles) == 1
for bundle in bundles:
assert bundle.aggregated_signature == G2Element()
assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 0
)
assert extra_coin_spend in [spend for bundle in bundles for spend in bundle.coin_spends]

[tx1, tx2] = await wsm.add_pending_transactions([tx1, tx2], push=True, merge_spends=True, sign=True)
bundles = flatten_spend_bundles(new_txs)
assert len(bundles) == 1
assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 2
)

await wallet_environments.full_node.wait_transaction_records_entered_mempool([tx1, tx2])
95 changes: 95 additions & 0 deletions chia/wallet/wallet_action_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

import contextlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, AsyncIterator, List, Optional, cast

from chia.types.spend_bundle import SpendBundle
from chia.util.action_scope import ActionScope
from chia.wallet.signer_protocol import SigningResponse
from chia.wallet.transaction_record import TransactionRecord

if TYPE_CHECKING:
# Avoid a circular import here
from chia.wallet.wallet_state_manager import WalletStateManager


@dataclass
class WalletSideEffects:
transactions: List[TransactionRecord] = field(default_factory=list)
signing_responses: List[SigningResponse] = field(default_factory=list)
extra_spends: List[SpendBundle] = field(default_factory=list)

def __bytes__(self) -> bytes:
blob = b""
blob += len(self.transactions).to_bytes(4, "big")
for tx in self.transactions:
tx_bytes = bytes(tx)
blob += len(tx_bytes).to_bytes(4, "big") + tx_bytes
blob += len(self.signing_responses).to_bytes(4, "big")
for sr in self.signing_responses:
sr_bytes = bytes(sr)
blob += len(sr_bytes).to_bytes(4, "big") + sr_bytes
blob += len(self.extra_spends).to_bytes(4, "big")
for sb in self.extra_spends:
sb_bytes = bytes(sb)
blob += len(sb_bytes).to_bytes(4, "big") + sb_bytes
return blob

@classmethod
def from_bytes(cls, blob: bytes) -> WalletSideEffects:
instance = cls()
while blob != b"":
tx_len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
for _ in range(0, tx_len_prefix):
len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
instance.transactions.append(TransactionRecord.from_bytes(blob[:len_prefix]))
blob = blob[len_prefix:]
sr_len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
for _ in range(0, sr_len_prefix):
len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
instance.signing_responses.append(SigningResponse.from_bytes(blob[:len_prefix]))
blob = blob[len_prefix:]
sb_len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
for _ in range(0, sb_len_prefix):
len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
instance.extra_spends.append(SpendBundle.from_bytes(blob[:len_prefix]))
blob = blob[len_prefix:]

return instance


WalletActionScope = ActionScope[WalletSideEffects]


@contextlib.asynccontextmanager
async def new_wallet_action_scope(
wallet_state_manager: WalletStateManager,
push: bool = False,
merge_spends: bool = True,
sign: Optional[bool] = None,
additional_signing_responses: List[SigningResponse] = [],
extra_spends: List[SpendBundle] = [],
) -> AsyncIterator[WalletActionScope]:
async with ActionScope.new_scope(WalletSideEffects) as self:
self = cast(WalletActionScope, self)
async with self.use() as interface:
interface.side_effects.signing_responses = additional_signing_responses.copy()
interface.side_effects.extra_spends = extra_spends.copy()

yield self

self.side_effects.transactions = await wallet_state_manager.add_pending_transactions(
self.side_effects.transactions,
push=push,
merge_spends=merge_spends,
sign=sign,
additional_signing_responses=self.side_effects.signing_responses,
extra_spends=self.side_effects.extra_spends,
)
Loading
Loading