From 05a7bb830761c8aeb7fd121cf14839350489e886 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 07:53:21 -0700 Subject: [PATCH 01/12] Add the concept of 'action scopes' --- chia/_tests/util/test_action_scope.py | 122 +++++++++++++++++++ chia/util/action_scope.py | 168 ++++++++++++++++++++++++++ 2 files changed, 290 insertions(+) create mode 100644 chia/_tests/util/test_action_scope.py create mode 100644 chia/util/action_scope.py diff --git a/chia/_tests/util/test_action_scope.py b/chia/_tests/util/test_action_scope.py new file mode 100644 index 000000000000..a00a5c4f710a --- /dev/null +++ b/chia/_tests/util/test_action_scope.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import AsyncIterator + +import pytest + +from chia.util.action_scope import ActionScope, StateInterface + + +@dataclass +class TestSideEffects: + buf: bytes = b"" + + def __bytes__(self) -> bytes: + return self.buf + + @classmethod + def from_bytes(cls, blob: bytes) -> TestSideEffects: + return cls(blob) + + +async def default_async_callback(interface: StateInterface[TestSideEffects]) -> None: + return None + + +async def default_async_commit(interface: TestSideEffects) -> None: + return None + + +# Test adding a callback +def test_add_callback() -> None: + state_interface = StateInterface({}, TestSideEffects(), True) + initial_count = len(state_interface._new_callbacks) + state_interface.add_callback(default_async_callback) + assert len(state_interface._new_callbacks) == initial_count + 1 + + +# Fixture to create an ActionScope with a mocked DBWrapper2 +@pytest.fixture +async def action_scope() -> AsyncIterator[ActionScope[TestSideEffects]]: + async with ActionScope.new_scope(TestSideEffects) as scope: + yield scope + + +# Test creating a new ActionScope and ensuring tables are created +@pytest.mark.anyio +async def test_new_action_scope(action_scope: ActionScope[TestSideEffects]) -> None: + async with action_scope.use() as interface: + assert interface == StateInterface({}, TestSideEffects(), True) + + +@pytest.mark.anyio +async def test_scope_persistence(action_scope: ActionScope[TestSideEffects]) -> None: + async with action_scope.use() as interface: + interface.memos[b"foo"] = b"bar" + interface.side_effects.buf = b"bar" + + async with action_scope.use() as interface: + assert interface.memos[b"foo"] == b"bar" + assert interface.side_effects.buf == b"bar" + + +@pytest.mark.anyio +async def test_transactionality(action_scope: ActionScope[TestSideEffects]) -> None: + async with action_scope.use() as interface: + interface.memos[b"foo"] = b"bar" + interface.side_effects.buf = b"bar" + + try: + async with action_scope.use() as interface: + interface.memos[b"foo"] = b"qux" + interface.side_effects.buf = b"qat" + raise RuntimeError("Going to be caught") + except RuntimeError: + pass + + async with action_scope.use() as interface: + assert interface.memos[b"foo"] == b"bar" + assert interface.side_effects.buf == b"bar" + + +@pytest.mark.anyio +async def test_callbacks() -> None: + async with ActionScope.new_scope(TestSideEffects) as action_scope: + async with action_scope.use() as interface: + + async def callback(interface: StateInterface[TestSideEffects]) -> None: + interface.side_effects.buf = b"bar" + + interface.add_callback(callback) + + assert action_scope.side_effects.buf == b"bar" + + +@pytest.mark.anyio +async def test_callback_in_callback_error() -> None: + with pytest.raises(ValueError, match="callback"): + async with ActionScope.new_scope(TestSideEffects) as action_scope: + async with action_scope.use() as interface: + + async def callback(interface: StateInterface[TestSideEffects]) -> None: + interface.add_callback(default_async_callback) + + interface.add_callback(callback) + + +@pytest.mark.anyio +async def test_no_callbacks_if_error() -> None: + try: + async with ActionScope.new_scope(TestSideEffects) as action_scope: + async with action_scope.use() as interface: + + async def callback(interface: StateInterface[TestSideEffects]) -> None: + raise NotImplementedError("Should not get here") # pragma: no cover + + interface.add_callback(callback) + + async with action_scope.use() as interface: + raise RuntimeError("This should prevent the callbacks from being called") + except RuntimeError: + pass diff --git a/chia/util/action_scope.py b/chia/util/action_scope.py new file mode 100644 index 000000000000..68a5126a1de4 --- /dev/null +++ b/chia/util/action_scope.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +from typing import AsyncIterator, Awaitable, Callable, Dict, Generic, List, Optional, Protocol, Type, TypeVar + +import aiosqlite + +from chia.util.db_wrapper import DBWrapper2, execute_fetchone + + +class ResourceManager(Protocol): + @classmethod + @contextlib.asynccontextmanager + async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: + # We have to put this yield here for mypy to recognize the function as a generator + yield # type: ignore[misc] + + @contextlib.asynccontextmanager + async def use(self) -> AsyncIterator[None]: + # We have to put this yield here for mypy to recognize the function as a generator + yield + + async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects: ... + + async def save_resource(self, resource: SideEffects) -> None: ... + + async def get_memos(self) -> Dict[bytes, bytes]: ... + + async def save_memos(self, memos: Dict[bytes, bytes]) -> None: ... + + +@dataclass +class SQLiteResourceManager: + + _db: DBWrapper2 + _active_writer: Optional[aiosqlite.Connection] = field(init=False, default=None) + + def get_active_writer(self) -> aiosqlite.Connection: + if self._active_writer is None: + raise RuntimeError("Can only access resources while under `use()` context manager") + + return self._active_writer + + @classmethod + @contextlib.asynccontextmanager + async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: + async with DBWrapper2.managed(":memory:", reader_count=0) as db: + self = cls(db) + async with self._db.writer() as conn: + await conn.execute("CREATE TABLE memos(" " key blob," " value blob" ")") + await conn.execute("CREATE TABLE side_effects(" " total blob" ")") + await conn.execute( + "INSERT INTO side_effects VALUES(?)", + (bytes(initial_resource),), + ) + yield self + + @contextlib.asynccontextmanager + async def use(self) -> AsyncIterator[None]: + async with self._db.writer() as conn: + self._active_writer = conn + yield + self._active_writer = None + + async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects: + row = await execute_fetchone(self.get_active_writer(), "SELECT total FROM side_effects") + assert row is not None + side_effects = resource_type.from_bytes(row[0]) + return side_effects + + async def save_resource(self, resource: SideEffects) -> None: + await self.get_active_writer().execute("DELETE FROM side_effects") + await self.get_active_writer().execute( + "INSERT INTO side_effects VALUES(?)", + (bytes(resource),), + ) + + async def get_memos(self) -> Dict[bytes, bytes]: + rows = await self.get_active_writer().execute_fetchall("SELECT key, value FROM memos") + memos = {row[0]: row[1] for row in rows} + return memos + + async def save_memos(self, memos: Dict[bytes, bytes]) -> None: + await self.get_active_writer().execute("DELETE FROM memos") + await self.get_active_writer().executemany( + "INSERT INTO memos VALUES(?, ?)", + memos.items(), + ) + + +class SideEffects(Protocol): + def __bytes__(self) -> bytes: ... + + @classmethod + def from_bytes(cls: Type[_T_SideEffects], blob: bytes) -> _T_SideEffects: ... + + +_T_SideEffects = TypeVar("_T_SideEffects", bound=SideEffects) + + +@dataclass +class ActionScope(Generic[_T_SideEffects]): + """ + The idea of a wallet action is to map a single user input to many potentially distributed wallet functions and side + effects. The eventual goal is to have this be the only connection between a wallet type and the WSM. + + Utilizes a "resource manager" to hold the state in order to take advantage of rollbacks and prevent concurrent tasks + from interferring with each other. + """ + + _resource_manager: ResourceManager + _side_effects_format: Type[_T_SideEffects] + _callbacks: List[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = field(default_factory=list) + _final_side_effects: Optional[_T_SideEffects] = field(init=False, default=None) + + @property + def side_effects(self) -> _T_SideEffects: + if self._final_side_effects is None: + raise RuntimeError( + "Can only request ActionScope.side_effects after exiting context manager. " + "While in context manager, use ActionScope.use()." + ) + + return self._final_side_effects + + @classmethod + @contextlib.asynccontextmanager + async def new_scope( + cls, + side_effects_format: Type[_T_SideEffects], + resource_manager_backend: Type[ResourceManager] = SQLiteResourceManager, + ) -> AsyncIterator[ActionScope[_T_SideEffects]]: + async with resource_manager_backend.managed(side_effects_format()) as resource_manager: + self = cls(_resource_manager=resource_manager, _side_effects_format=side_effects_format) + try: + yield self + except Exception: + raise + else: + async with self.use(_callbacks_allowed=False) as interface: + for callback in self._callbacks: + await callback(interface) + self._final_side_effects = interface.side_effects + + @contextlib.asynccontextmanager + async def use(self, _callbacks_allowed: bool = True) -> AsyncIterator[StateInterface[_T_SideEffects]]: + async with self._resource_manager.use(): + memos = await self._resource_manager.get_memos() + side_effects = await self._resource_manager.get_resource(self._side_effects_format) + interface = StateInterface(memos, side_effects, _callbacks_allowed) + yield interface + await self._resource_manager.save_memos(interface.memos) + await self._resource_manager.save_resource(interface.side_effects) + self._callbacks.extend(interface._new_callbacks) + + +@dataclass +class StateInterface(Generic[_T_SideEffects]): + memos: Dict[bytes, bytes] + side_effects: _T_SideEffects + _callbacks_allowed: bool + _new_callbacks: List[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = field(default_factory=list) + + def add_callback(self, callback: Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]) -> None: + if not self._callbacks_allowed: + raise ValueError("Cannot create a new callback from within another callback") + self._new_callbacks.append(callback) From 719b880f8637edbec2afaccaa4811a3798817e75 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 07:54:58 -0700 Subject: [PATCH 02/12] Add `WalletActionScope` --- .../_tests/wallet/test_wallet_action_scope.py | 61 +++++++++++++++ chia/wallet/wallet_action_scope.py | 76 +++++++++++++++++++ chia/wallet/wallet_state_manager.py | 53 +++++++++---- 3 files changed, 174 insertions(+), 16 deletions(-) create mode 100644 chia/_tests/wallet/test_wallet_action_scope.py create mode 100644 chia/wallet/wallet_action_scope.py diff --git a/chia/_tests/wallet/test_wallet_action_scope.py b/chia/_tests/wallet/test_wallet_action_scope.py new file mode 100644 index 000000000000..2313bef23bdd --- /dev/null +++ b/chia/_tests/wallet/test_wallet_action_scope.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import pytest + +from chia._tests.cmds.wallet.test_consts import STD_TX +from chia.wallet.signer_protocol import SigningResponse +from chia.wallet.transaction_record import TransactionRecord +from chia.wallet.wallet_action_scope import WalletActionScope, WalletSideEffects + + +def test_back_and_forth_serialization() -> None: + assert bytes(WalletSideEffects()) == b"\x00\x00\x00\x00\x00\x00\x00\x00" + assert WalletSideEffects.from_bytes(bytes(WalletSideEffects())) == WalletSideEffects() + assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX]))) == WalletSideEffects([STD_TX]) + assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX, STD_TX]))) == WalletSideEffects( + [STD_TX, STD_TX] + ) + + +@dataclass +class MockWalletStateManager: + most_recent_call: Optional[Tuple[List[TransactionRecord], bool, bool, bool, List[SigningResponse]]] = None + + async def add_pending_transactions( + self, + txs: List[TransactionRecord], + push: bool, + merge_spends: bool, + sign: bool, + additional_signing_responses: List[SigningResponse], + ) -> List[TransactionRecord]: + self.most_recent_call = (txs, push, merge_spends, sign, additional_signing_responses) + return txs + + +@pytest.mark.anyio +async def test_wallet_action_scope() -> None: + wsm = MockWalletStateManager() + async with WalletActionScope.new( + wsm, push=True, merge_spends=False, sign=True, additional_signing_responses=[] # type: ignore[arg-type] + ) as action_scope: + async with action_scope.use() as interface: + interface.side_effects.transactions = [STD_TX] + + with pytest.raises(RuntimeError): + action_scope.side_effects + + assert action_scope.side_effects.transactions == [STD_TX] + assert wsm.most_recent_call == ([STD_TX], True, False, True, []) + + async with WalletActionScope.new( + wsm, push=False, merge_spends=True, sign=True, additional_signing_responses=[] # type: ignore[arg-type] + ) as action_scope: + async with action_scope.use() as interface: + interface.side_effects.transactions = [] + + assert action_scope.side_effects.transactions == [] + assert wsm.most_recent_call == ([], False, True, True, []) diff --git a/chia/wallet/wallet_action_scope.py b/chia/wallet/wallet_action_scope.py new file mode 100644 index 000000000000..e435c1d90e24 --- /dev/null +++ b/chia/wallet/wallet_action_scope.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +from typing import AsyncIterator, List, Optional, cast + +from chia.util.action_scope import ActionScope +from chia.wallet.signer_protocol import SigningResponse +from chia.wallet.transaction_record import TransactionRecord +from chia.wallet.wallet_state_manager import WalletStateManager + + +@dataclass +class WalletSideEffects: + transactions: List[TransactionRecord] = field(default_factory=list) + signing_responses: List[SigningResponse] = field(default_factory=list) + + def __bytes__(self) -> bytes: + blob = b"" + blob += len(self.transactions).to_bytes(4, "big") + for tx in self.transactions: + tx_bytes = bytes(tx) + blob += len(tx_bytes).to_bytes(4, "big") + tx_bytes + blob += len(self.signing_responses).to_bytes(4, "big") + for sr in self.signing_responses: + sr_bytes = bytes(sr) + blob += len(sr_bytes).to_bytes(4, "big") + sr_bytes + return blob + + @classmethod + def from_bytes(cls, blob: bytes) -> WalletSideEffects: + instance = cls() + while blob != b"": + tx_len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + for _ in range(0, tx_len_prefix): + len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + instance.transactions.append(TransactionRecord.from_bytes(blob[:len_prefix])) + blob = blob[len_prefix:] + sr_len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + for _ in range(0, sr_len_prefix): + len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + instance.signing_responses.append(SigningResponse.from_bytes(blob[:len_prefix])) + blob = blob[len_prefix:] + + return instance + + +@dataclass +class WalletActionScope(ActionScope[WalletSideEffects]): + @classmethod + @contextlib.asynccontextmanager + async def new( + cls, + wallet_state_manager: WalletStateManager, + push: bool = False, + merge_spends: bool = True, + sign: Optional[bool] = None, + additional_signing_responses: List[SigningResponse] = [], + ) -> AsyncIterator[WalletActionScope]: + async with cls.new_scope(WalletSideEffects) as self: + self = cast(WalletActionScope, self) + async with self.use() as interface: + interface.side_effects.signing_responses = additional_signing_responses.copy() + yield self + + self.side_effects.transactions = await wallet_state_manager.add_pending_transactions( + self.side_effects.transactions, + push=push, + merge_spends=merge_spends, + sign=sign, + additional_signing_responses=self.side_effects.signing_responses, + ) diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index 9d2a0963322f..190d0a9c0519 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import dataclasses import logging import multiprocessing.context @@ -144,6 +145,7 @@ from chia.wallet.vc_wallet.vc_store import VCStore from chia.wallet.vc_wallet.vc_wallet import VCWallet from chia.wallet.wallet import Wallet +from chia.wallet.wallet_action_scope import WalletActionScope from chia.wallet.wallet_blockchain import WalletBlockchain from chia.wallet.wallet_coin_record import MetadataTypes, WalletCoinRecord from chia.wallet.wallet_coin_store import WalletCoinStore @@ -2255,9 +2257,10 @@ async def coin_added( async def add_pending_transactions( self, tx_records: List[TransactionRecord], + push: bool = True, merge_spends: bool = True, sign: Optional[bool] = None, - additional_signing_responses: List[SigningResponse] = [], + additional_signing_responses: Optional[List[SigningResponse]] = None, ) -> List[TransactionRecord]: """ Add a list of transactions to be submitted to the full node. @@ -2279,26 +2282,27 @@ async def add_pending_transactions( for i, tx in enumerate(tx_records) ] if sign: - tx_records, _ = await self.sign_transactions( + tx_records, signing_responses = await self.sign_transactions( tx_records, - additional_signing_responses, + [] if additional_signing_responses is None else additional_signing_responses, additional_signing_responses != [], ) - all_coins_names = [] - async with self.db_wrapper.writer_maybe_transaction(): - for tx_record in tx_records: - # Wallet node will use this queue to retry sending this transaction until full nodes receives it - await self.tx_store.add_transaction_record(tx_record) - all_coins_names.extend([coin.name() for coin in tx_record.additions]) - all_coins_names.extend([coin.name() for coin in tx_record.removals]) + if push: + all_coins_names = [] + async with self.db_wrapper.writer_maybe_transaction(): + for tx_record in tx_records: + # Wallet node will use this queue to retry sending this transaction until full nodes receives it + await self.tx_store.add_transaction_record(tx_record) + all_coins_names.extend([coin.name() for coin in tx_record.additions]) + all_coins_names.extend([coin.name() for coin in tx_record.removals]) - await self.add_interested_coin_ids(all_coins_names) + await self.add_interested_coin_ids(all_coins_names) - if actual_spend_involved: - self.tx_pending_changed() - for wallet_id in {tx.wallet_id for tx in tx_records}: - self.state_changed("pending_transaction", wallet_id) - await self.wallet_node.update_ui() + if actual_spend_involved: + self.tx_pending_changed() + for wallet_id in {tx.wallet_id for tx in tx_records}: + self.state_changed("pending_transaction", wallet_id) + await self.wallet_node.update_ui() return tx_records @@ -2739,3 +2743,20 @@ async def submit_transactions(self, signed_txs: List[SignedTransaction]) -> List for bundle in bundles: await self.wallet_node.push_tx(bundle) return [bundle.name() for bundle in bundles] + + @contextlib.asynccontextmanager + async def new_action_scope( + self, + push: bool = False, + merge_spends: bool = True, + sign: Optional[bool] = None, + additional_signing_responses: List[SigningResponse] = [], + ) -> AsyncIterator[WalletActionScope]: + async with WalletActionScope.new( + self, + push=push, + merge_spends=merge_spends, + sign=sign, + additional_signing_responses=additional_signing_responses, + ) as action_scope: + yield action_scope From 3abb7bac435aecaac41840db7e10f6d83ab9767e Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 07:53:21 -0700 Subject: [PATCH 03/12] Add the concept of 'action scopes' --- chia/_tests/util/test_action_scope.py | 122 +++++++++++++++++++ chia/util/action_scope.py | 168 ++++++++++++++++++++++++++ 2 files changed, 290 insertions(+) create mode 100644 chia/_tests/util/test_action_scope.py create mode 100644 chia/util/action_scope.py diff --git a/chia/_tests/util/test_action_scope.py b/chia/_tests/util/test_action_scope.py new file mode 100644 index 000000000000..a00a5c4f710a --- /dev/null +++ b/chia/_tests/util/test_action_scope.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import AsyncIterator + +import pytest + +from chia.util.action_scope import ActionScope, StateInterface + + +@dataclass +class TestSideEffects: + buf: bytes = b"" + + def __bytes__(self) -> bytes: + return self.buf + + @classmethod + def from_bytes(cls, blob: bytes) -> TestSideEffects: + return cls(blob) + + +async def default_async_callback(interface: StateInterface[TestSideEffects]) -> None: + return None + + +async def default_async_commit(interface: TestSideEffects) -> None: + return None + + +# Test adding a callback +def test_add_callback() -> None: + state_interface = StateInterface({}, TestSideEffects(), True) + initial_count = len(state_interface._new_callbacks) + state_interface.add_callback(default_async_callback) + assert len(state_interface._new_callbacks) == initial_count + 1 + + +# Fixture to create an ActionScope with a mocked DBWrapper2 +@pytest.fixture +async def action_scope() -> AsyncIterator[ActionScope[TestSideEffects]]: + async with ActionScope.new_scope(TestSideEffects) as scope: + yield scope + + +# Test creating a new ActionScope and ensuring tables are created +@pytest.mark.anyio +async def test_new_action_scope(action_scope: ActionScope[TestSideEffects]) -> None: + async with action_scope.use() as interface: + assert interface == StateInterface({}, TestSideEffects(), True) + + +@pytest.mark.anyio +async def test_scope_persistence(action_scope: ActionScope[TestSideEffects]) -> None: + async with action_scope.use() as interface: + interface.memos[b"foo"] = b"bar" + interface.side_effects.buf = b"bar" + + async with action_scope.use() as interface: + assert interface.memos[b"foo"] == b"bar" + assert interface.side_effects.buf == b"bar" + + +@pytest.mark.anyio +async def test_transactionality(action_scope: ActionScope[TestSideEffects]) -> None: + async with action_scope.use() as interface: + interface.memos[b"foo"] = b"bar" + interface.side_effects.buf = b"bar" + + try: + async with action_scope.use() as interface: + interface.memos[b"foo"] = b"qux" + interface.side_effects.buf = b"qat" + raise RuntimeError("Going to be caught") + except RuntimeError: + pass + + async with action_scope.use() as interface: + assert interface.memos[b"foo"] == b"bar" + assert interface.side_effects.buf == b"bar" + + +@pytest.mark.anyio +async def test_callbacks() -> None: + async with ActionScope.new_scope(TestSideEffects) as action_scope: + async with action_scope.use() as interface: + + async def callback(interface: StateInterface[TestSideEffects]) -> None: + interface.side_effects.buf = b"bar" + + interface.add_callback(callback) + + assert action_scope.side_effects.buf == b"bar" + + +@pytest.mark.anyio +async def test_callback_in_callback_error() -> None: + with pytest.raises(ValueError, match="callback"): + async with ActionScope.new_scope(TestSideEffects) as action_scope: + async with action_scope.use() as interface: + + async def callback(interface: StateInterface[TestSideEffects]) -> None: + interface.add_callback(default_async_callback) + + interface.add_callback(callback) + + +@pytest.mark.anyio +async def test_no_callbacks_if_error() -> None: + try: + async with ActionScope.new_scope(TestSideEffects) as action_scope: + async with action_scope.use() as interface: + + async def callback(interface: StateInterface[TestSideEffects]) -> None: + raise NotImplementedError("Should not get here") # pragma: no cover + + interface.add_callback(callback) + + async with action_scope.use() as interface: + raise RuntimeError("This should prevent the callbacks from being called") + except RuntimeError: + pass diff --git a/chia/util/action_scope.py b/chia/util/action_scope.py new file mode 100644 index 000000000000..68a5126a1de4 --- /dev/null +++ b/chia/util/action_scope.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +from typing import AsyncIterator, Awaitable, Callable, Dict, Generic, List, Optional, Protocol, Type, TypeVar + +import aiosqlite + +from chia.util.db_wrapper import DBWrapper2, execute_fetchone + + +class ResourceManager(Protocol): + @classmethod + @contextlib.asynccontextmanager + async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: + # We have to put this yield here for mypy to recognize the function as a generator + yield # type: ignore[misc] + + @contextlib.asynccontextmanager + async def use(self) -> AsyncIterator[None]: + # We have to put this yield here for mypy to recognize the function as a generator + yield + + async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects: ... + + async def save_resource(self, resource: SideEffects) -> None: ... + + async def get_memos(self) -> Dict[bytes, bytes]: ... + + async def save_memos(self, memos: Dict[bytes, bytes]) -> None: ... + + +@dataclass +class SQLiteResourceManager: + + _db: DBWrapper2 + _active_writer: Optional[aiosqlite.Connection] = field(init=False, default=None) + + def get_active_writer(self) -> aiosqlite.Connection: + if self._active_writer is None: + raise RuntimeError("Can only access resources while under `use()` context manager") + + return self._active_writer + + @classmethod + @contextlib.asynccontextmanager + async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: + async with DBWrapper2.managed(":memory:", reader_count=0) as db: + self = cls(db) + async with self._db.writer() as conn: + await conn.execute("CREATE TABLE memos(" " key blob," " value blob" ")") + await conn.execute("CREATE TABLE side_effects(" " total blob" ")") + await conn.execute( + "INSERT INTO side_effects VALUES(?)", + (bytes(initial_resource),), + ) + yield self + + @contextlib.asynccontextmanager + async def use(self) -> AsyncIterator[None]: + async with self._db.writer() as conn: + self._active_writer = conn + yield + self._active_writer = None + + async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects: + row = await execute_fetchone(self.get_active_writer(), "SELECT total FROM side_effects") + assert row is not None + side_effects = resource_type.from_bytes(row[0]) + return side_effects + + async def save_resource(self, resource: SideEffects) -> None: + await self.get_active_writer().execute("DELETE FROM side_effects") + await self.get_active_writer().execute( + "INSERT INTO side_effects VALUES(?)", + (bytes(resource),), + ) + + async def get_memos(self) -> Dict[bytes, bytes]: + rows = await self.get_active_writer().execute_fetchall("SELECT key, value FROM memos") + memos = {row[0]: row[1] for row in rows} + return memos + + async def save_memos(self, memos: Dict[bytes, bytes]) -> None: + await self.get_active_writer().execute("DELETE FROM memos") + await self.get_active_writer().executemany( + "INSERT INTO memos VALUES(?, ?)", + memos.items(), + ) + + +class SideEffects(Protocol): + def __bytes__(self) -> bytes: ... + + @classmethod + def from_bytes(cls: Type[_T_SideEffects], blob: bytes) -> _T_SideEffects: ... + + +_T_SideEffects = TypeVar("_T_SideEffects", bound=SideEffects) + + +@dataclass +class ActionScope(Generic[_T_SideEffects]): + """ + The idea of a wallet action is to map a single user input to many potentially distributed wallet functions and side + effects. The eventual goal is to have this be the only connection between a wallet type and the WSM. + + Utilizes a "resource manager" to hold the state in order to take advantage of rollbacks and prevent concurrent tasks + from interferring with each other. + """ + + _resource_manager: ResourceManager + _side_effects_format: Type[_T_SideEffects] + _callbacks: List[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = field(default_factory=list) + _final_side_effects: Optional[_T_SideEffects] = field(init=False, default=None) + + @property + def side_effects(self) -> _T_SideEffects: + if self._final_side_effects is None: + raise RuntimeError( + "Can only request ActionScope.side_effects after exiting context manager. " + "While in context manager, use ActionScope.use()." + ) + + return self._final_side_effects + + @classmethod + @contextlib.asynccontextmanager + async def new_scope( + cls, + side_effects_format: Type[_T_SideEffects], + resource_manager_backend: Type[ResourceManager] = SQLiteResourceManager, + ) -> AsyncIterator[ActionScope[_T_SideEffects]]: + async with resource_manager_backend.managed(side_effects_format()) as resource_manager: + self = cls(_resource_manager=resource_manager, _side_effects_format=side_effects_format) + try: + yield self + except Exception: + raise + else: + async with self.use(_callbacks_allowed=False) as interface: + for callback in self._callbacks: + await callback(interface) + self._final_side_effects = interface.side_effects + + @contextlib.asynccontextmanager + async def use(self, _callbacks_allowed: bool = True) -> AsyncIterator[StateInterface[_T_SideEffects]]: + async with self._resource_manager.use(): + memos = await self._resource_manager.get_memos() + side_effects = await self._resource_manager.get_resource(self._side_effects_format) + interface = StateInterface(memos, side_effects, _callbacks_allowed) + yield interface + await self._resource_manager.save_memos(interface.memos) + await self._resource_manager.save_resource(interface.side_effects) + self._callbacks.extend(interface._new_callbacks) + + +@dataclass +class StateInterface(Generic[_T_SideEffects]): + memos: Dict[bytes, bytes] + side_effects: _T_SideEffects + _callbacks_allowed: bool + _new_callbacks: List[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = field(default_factory=list) + + def add_callback(self, callback: Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]) -> None: + if not self._callbacks_allowed: + raise ValueError("Cannot create a new callback from within another callback") + self._new_callbacks.append(callback) From 56230c696c37eb3890cbbc28bcb923a5558c5dbc Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 10:06:44 -0700 Subject: [PATCH 04/12] pylint and test coverage --- chia/_tests/util/test_action_scope.py | 6 +----- chia/util/action_scope.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/chia/_tests/util/test_action_scope.py b/chia/_tests/util/test_action_scope.py index a00a5c4f710a..9f44c4a05cf9 100644 --- a/chia/_tests/util/test_action_scope.py +++ b/chia/_tests/util/test_action_scope.py @@ -21,11 +21,7 @@ def from_bytes(cls, blob: bytes) -> TestSideEffects: async def default_async_callback(interface: StateInterface[TestSideEffects]) -> None: - return None - - -async def default_async_commit(interface: TestSideEffects) -> None: - return None + return None # pragma: no cover # Test adding a callback diff --git a/chia/util/action_scope.py b/chia/util/action_scope.py index 68a5126a1de4..8e15f4046bea 100644 --- a/chia/util/action_scope.py +++ b/chia/util/action_scope.py @@ -12,12 +12,12 @@ class ResourceManager(Protocol): @classmethod @contextlib.asynccontextmanager - async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: + async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: # pragma: no cover # We have to put this yield here for mypy to recognize the function as a generator yield # type: ignore[misc] @contextlib.asynccontextmanager - async def use(self) -> AsyncIterator[None]: + async def use(self) -> AsyncIterator[None]: # pragma: no cover # We have to put this yield here for mypy to recognize the function as a generator yield @@ -149,10 +149,14 @@ async def use(self, _callbacks_allowed: bool = True) -> AsyncIterator[StateInter memos = await self._resource_manager.get_memos() side_effects = await self._resource_manager.get_resource(self._side_effects_format) interface = StateInterface(memos, side_effects, _callbacks_allowed) - yield interface - await self._resource_manager.save_memos(interface.memos) - await self._resource_manager.save_resource(interface.side_effects) - self._callbacks.extend(interface._new_callbacks) + try: + yield interface + except Exception: + raise + else: + await self._resource_manager.save_memos(interface.memos) + await self._resource_manager.save_resource(interface.side_effects) + self._callbacks.extend(interface._new_callbacks) @dataclass From da718ec56663fff934c50944d9d107681aa9b19c Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 10:14:28 -0700 Subject: [PATCH 05/12] add try/finally --- chia/util/action_scope.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chia/util/action_scope.py b/chia/util/action_scope.py index 8e15f4046bea..6d23852efe02 100644 --- a/chia/util/action_scope.py +++ b/chia/util/action_scope.py @@ -60,8 +60,10 @@ async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceM async def use(self) -> AsyncIterator[None]: async with self._db.writer() as conn: self._active_writer = conn - yield - self._active_writer = None + try: + yield + finally: + self._active_writer = None async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects: row = await execute_fetchone(self.get_active_writer(), "SELECT total FROM side_effects") From 8acb86c7c9edc47bf0dd898993e6cad5e05fd228 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 12:08:53 -0700 Subject: [PATCH 06/12] add try/except --- chia/wallet/wallet_action_scope.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chia/wallet/wallet_action_scope.py b/chia/wallet/wallet_action_scope.py index 47ffe22b5878..1d7f770be887 100644 --- a/chia/wallet/wallet_action_scope.py +++ b/chia/wallet/wallet_action_scope.py @@ -68,7 +68,10 @@ async def new( self = cast(WalletActionScope, self) async with self.use() as interface: interface.side_effects.signing_responses = additional_signing_responses.copy() - yield self + try: + yield self + except Exception: + raise self.side_effects.transactions = await wallet_state_manager.add_pending_transactions( self.side_effects.transactions, From 2952287939aabcb29cec5980723b4a41648aa374 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 12:13:33 -0700 Subject: [PATCH 07/12] Undo giving a variable a name --- chia/wallet/wallet_state_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index 190d0a9c0519..c4c6c307bd9a 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -2282,7 +2282,7 @@ async def add_pending_transactions( for i, tx in enumerate(tx_records) ] if sign: - tx_records, signing_responses = await self.sign_transactions( + tx_records, _ = await self.sign_transactions( tx_records, [] if additional_signing_responses is None else additional_signing_responses, additional_signing_responses != [], From f3482b5f5b1514d4df9611178e6be26cd758c463 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 13:26:07 -0700 Subject: [PATCH 08/12] Test coverage --- chia/_tests/wallet/test_wallet_action_scope.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/chia/_tests/wallet/test_wallet_action_scope.py b/chia/_tests/wallet/test_wallet_action_scope.py index 2313bef23bdd..144e9c857a39 100644 --- a/chia/_tests/wallet/test_wallet_action_scope.py +++ b/chia/_tests/wallet/test_wallet_action_scope.py @@ -6,18 +6,23 @@ import pytest from chia._tests.cmds.wallet.test_consts import STD_TX +from chia.types.blockchain_format.sized_bytes import bytes32 from chia.wallet.signer_protocol import SigningResponse from chia.wallet.transaction_record import TransactionRecord from chia.wallet.wallet_action_scope import WalletActionScope, WalletSideEffects +MOCK_SR = SigningResponse(b"hey", bytes32([0] * 32)) + def test_back_and_forth_serialization() -> None: assert bytes(WalletSideEffects()) == b"\x00\x00\x00\x00\x00\x00\x00\x00" assert WalletSideEffects.from_bytes(bytes(WalletSideEffects())) == WalletSideEffects() - assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX]))) == WalletSideEffects([STD_TX]) - assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX, STD_TX]))) == WalletSideEffects( - [STD_TX, STD_TX] + assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX], [MOCK_SR]))) == WalletSideEffects( + [STD_TX], [MOCK_SR] ) + assert WalletSideEffects.from_bytes( + bytes(WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR])) + ) == WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR]) @dataclass From 7d4873d8e20c9430586cfa0b805eb3b5fd4791cb Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 10 Jun 2024 08:53:33 -0700 Subject: [PATCH 09/12] Ban partial sigining in another scenario --- chia/wallet/wallet_state_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index c4c6c307bd9a..bac992e02808 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -2285,7 +2285,7 @@ async def add_pending_transactions( tx_records, _ = await self.sign_transactions( tx_records, [] if additional_signing_responses is None else additional_signing_responses, - additional_signing_responses != [], + additional_signing_responses != [] and additional_signing_responses is not None, ) if push: all_coins_names = [] From 2d8aa490fea4842b182d9cef1f34d92d7dbe6419 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 11 Jun 2024 08:47:45 -0700 Subject: [PATCH 10/12] Make WalletActionScope an alias instead --- .../_tests/wallet/test_wallet_action_scope.py | 14 +++-- chia/wallet/wallet_action_scope.py | 51 +++++++++---------- chia/wallet/wallet_state_manager.py | 4 +- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/chia/_tests/wallet/test_wallet_action_scope.py b/chia/_tests/wallet/test_wallet_action_scope.py index 144e9c857a39..a1099c3f6d66 100644 --- a/chia/_tests/wallet/test_wallet_action_scope.py +++ b/chia/_tests/wallet/test_wallet_action_scope.py @@ -9,7 +9,8 @@ from chia.types.blockchain_format.sized_bytes import bytes32 from chia.wallet.signer_protocol import SigningResponse from chia.wallet.transaction_record import TransactionRecord -from chia.wallet.wallet_action_scope import WalletActionScope, WalletSideEffects +from chia.wallet.wallet_action_scope import WalletSideEffects +from chia.wallet.wallet_state_manager import WalletStateManager MOCK_SR = SigningResponse(b"hey", bytes32([0] * 32)) @@ -41,11 +42,14 @@ async def add_pending_transactions( return txs +MockWalletStateManager.new_action_scope = WalletStateManager.new_action_scope # type: ignore[attr-defined] + + @pytest.mark.anyio async def test_wallet_action_scope() -> None: wsm = MockWalletStateManager() - async with WalletActionScope.new( - wsm, push=True, merge_spends=False, sign=True, additional_signing_responses=[] # type: ignore[arg-type] + async with wsm.new_action_scope( # type: ignore[attr-defined] + push=True, merge_spends=False, sign=True, additional_signing_responses=[] ) as action_scope: async with action_scope.use() as interface: interface.side_effects.transactions = [STD_TX] @@ -56,8 +60,8 @@ async def test_wallet_action_scope() -> None: assert action_scope.side_effects.transactions == [STD_TX] assert wsm.most_recent_call == ([STD_TX], True, False, True, []) - async with WalletActionScope.new( - wsm, push=False, merge_spends=True, sign=True, additional_signing_responses=[] # type: ignore[arg-type] + async with wsm.new_action_scope( # type: ignore[attr-defined] + push=False, merge_spends=True, sign=True, additional_signing_responses=[] ) as action_scope: async with action_scope.use() as interface: interface.side_effects.transactions = [] diff --git a/chia/wallet/wallet_action_scope.py b/chia/wallet/wallet_action_scope.py index 1d7f770be887..68db2a39932c 100644 --- a/chia/wallet/wallet_action_scope.py +++ b/chia/wallet/wallet_action_scope.py @@ -52,31 +52,28 @@ def from_bytes(cls, blob: bytes) -> WalletSideEffects: return instance -@dataclass -class WalletActionScope(ActionScope[WalletSideEffects]): - @classmethod - @contextlib.asynccontextmanager - async def new( - cls, - wallet_state_manager: WalletStateManager, - push: bool = False, - merge_spends: bool = True, - sign: Optional[bool] = None, - additional_signing_responses: List[SigningResponse] = [], - ) -> AsyncIterator[WalletActionScope]: - async with cls.new_scope(WalletSideEffects) as self: - self = cast(WalletActionScope, self) - async with self.use() as interface: - interface.side_effects.signing_responses = additional_signing_responses.copy() - try: - yield self - except Exception: - raise +WalletActionScope = ActionScope[WalletSideEffects] + + +@contextlib.asynccontextmanager +async def new_wallet_action_scope( + wallet_state_manager: WalletStateManager, + push: bool = False, + merge_spends: bool = True, + sign: Optional[bool] = None, + additional_signing_responses: List[SigningResponse] = [], +) -> AsyncIterator[WalletActionScope]: + async with ActionScope.new_scope(WalletSideEffects) as self: + self = cast(WalletActionScope, self) + async with self.use() as interface: + interface.side_effects.signing_responses = additional_signing_responses.copy() + + yield self - self.side_effects.transactions = await wallet_state_manager.add_pending_transactions( - self.side_effects.transactions, - push=push, - merge_spends=merge_spends, - sign=sign, - additional_signing_responses=self.side_effects.signing_responses, - ) + self.side_effects.transactions = await wallet_state_manager.add_pending_transactions( + self.side_effects.transactions, + push=push, + merge_spends=merge_spends, + sign=sign, + additional_signing_responses=self.side_effects.signing_responses, + ) diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index bac992e02808..4f571121037a 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -145,7 +145,7 @@ from chia.wallet.vc_wallet.vc_store import VCStore from chia.wallet.vc_wallet.vc_wallet import VCWallet from chia.wallet.wallet import Wallet -from chia.wallet.wallet_action_scope import WalletActionScope +from chia.wallet.wallet_action_scope import WalletActionScope, new_wallet_action_scope from chia.wallet.wallet_blockchain import WalletBlockchain from chia.wallet.wallet_coin_record import MetadataTypes, WalletCoinRecord from chia.wallet.wallet_coin_store import WalletCoinStore @@ -2752,7 +2752,7 @@ async def new_action_scope( sign: Optional[bool] = None, additional_signing_responses: List[SigningResponse] = [], ) -> AsyncIterator[WalletActionScope]: - async with WalletActionScope.new( + async with new_wallet_action_scope( self, push=push, merge_spends=merge_spends, From 802f2e742c645cf9579d2bc41ad41b7fb789a827 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 11 Jun 2024 10:46:43 -0700 Subject: [PATCH 11/12] Add extra_spends to the action scope flow --- .../_tests/wallet/test_wallet_action_scope.py | 30 ++++++++++++------- chia/wallet/wallet_action_scope.py | 16 ++++++++++ chia/wallet/wallet_state_manager.py | 16 ++++++++++ 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/chia/_tests/wallet/test_wallet_action_scope.py b/chia/_tests/wallet/test_wallet_action_scope.py index a1099c3f6d66..54583e96bc9c 100644 --- a/chia/_tests/wallet/test_wallet_action_scope.py +++ b/chia/_tests/wallet/test_wallet_action_scope.py @@ -4,31 +4,36 @@ from typing import List, Optional, Tuple import pytest +from chia_rs import G2Element from chia._tests.cmds.wallet.test_consts import STD_TX from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.spend_bundle import SpendBundle from chia.wallet.signer_protocol import SigningResponse from chia.wallet.transaction_record import TransactionRecord from chia.wallet.wallet_action_scope import WalletSideEffects from chia.wallet.wallet_state_manager import WalletStateManager MOCK_SR = SigningResponse(b"hey", bytes32([0] * 32)) +MOCK_SB = SpendBundle([], G2Element()) def test_back_and_forth_serialization() -> None: - assert bytes(WalletSideEffects()) == b"\x00\x00\x00\x00\x00\x00\x00\x00" + assert bytes(WalletSideEffects()) == b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" assert WalletSideEffects.from_bytes(bytes(WalletSideEffects())) == WalletSideEffects() - assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX], [MOCK_SR]))) == WalletSideEffects( - [STD_TX], [MOCK_SR] + assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX], [MOCK_SR], [MOCK_SB]))) == WalletSideEffects( + [STD_TX], [MOCK_SR], [MOCK_SB] ) assert WalletSideEffects.from_bytes( - bytes(WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR])) - ) == WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR]) + bytes(WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR], [MOCK_SB, MOCK_SB])) + ) == WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR], [MOCK_SB, MOCK_SB]) @dataclass class MockWalletStateManager: - most_recent_call: Optional[Tuple[List[TransactionRecord], bool, bool, bool, List[SigningResponse]]] = None + most_recent_call: Optional[ + Tuple[List[TransactionRecord], bool, bool, bool, List[SigningResponse], List[SpendBundle]] + ] = None async def add_pending_transactions( self, @@ -37,8 +42,9 @@ async def add_pending_transactions( merge_spends: bool, sign: bool, additional_signing_responses: List[SigningResponse], + extra_spends: List[SpendBundle], ) -> List[TransactionRecord]: - self.most_recent_call = (txs, push, merge_spends, sign, additional_signing_responses) + self.most_recent_call = (txs, push, merge_spends, sign, additional_signing_responses, extra_spends) return txs @@ -49,7 +55,11 @@ async def add_pending_transactions( async def test_wallet_action_scope() -> None: wsm = MockWalletStateManager() async with wsm.new_action_scope( # type: ignore[attr-defined] - push=True, merge_spends=False, sign=True, additional_signing_responses=[] + push=True, + merge_spends=False, + sign=True, + additional_signing_responses=[], + extra_spends=[], ) as action_scope: async with action_scope.use() as interface: interface.side_effects.transactions = [STD_TX] @@ -58,7 +68,7 @@ async def test_wallet_action_scope() -> None: action_scope.side_effects assert action_scope.side_effects.transactions == [STD_TX] - assert wsm.most_recent_call == ([STD_TX], True, False, True, []) + assert wsm.most_recent_call == ([STD_TX], True, False, True, [], []) async with wsm.new_action_scope( # type: ignore[attr-defined] push=False, merge_spends=True, sign=True, additional_signing_responses=[] @@ -67,4 +77,4 @@ async def test_wallet_action_scope() -> None: interface.side_effects.transactions = [] assert action_scope.side_effects.transactions == [] - assert wsm.most_recent_call == ([], False, True, True, []) + assert wsm.most_recent_call == ([], False, True, True, [], []) diff --git a/chia/wallet/wallet_action_scope.py b/chia/wallet/wallet_action_scope.py index 68db2a39932c..85f4cb759b8f 100644 --- a/chia/wallet/wallet_action_scope.py +++ b/chia/wallet/wallet_action_scope.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, AsyncIterator, List, Optional, cast +from chia.types.spend_bundle import SpendBundle from chia.util.action_scope import ActionScope from chia.wallet.signer_protocol import SigningResponse from chia.wallet.transaction_record import TransactionRecord @@ -17,6 +18,7 @@ class WalletSideEffects: transactions: List[TransactionRecord] = field(default_factory=list) signing_responses: List[SigningResponse] = field(default_factory=list) + extra_spends: List[SpendBundle] = field(default_factory=list) def __bytes__(self) -> bytes: blob = b"" @@ -28,6 +30,10 @@ def __bytes__(self) -> bytes: for sr in self.signing_responses: sr_bytes = bytes(sr) blob += len(sr_bytes).to_bytes(4, "big") + sr_bytes + blob += len(self.extra_spends).to_bytes(4, "big") + for sb in self.extra_spends: + sb_bytes = bytes(sb) + blob += len(sb_bytes).to_bytes(4, "big") + sb_bytes return blob @classmethod @@ -48,6 +54,13 @@ def from_bytes(cls, blob: bytes) -> WalletSideEffects: blob = blob[4:] instance.signing_responses.append(SigningResponse.from_bytes(blob[:len_prefix])) blob = blob[len_prefix:] + sb_len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + for _ in range(0, sb_len_prefix): + len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + instance.extra_spends.append(SpendBundle.from_bytes(blob[:len_prefix])) + blob = blob[len_prefix:] return instance @@ -62,11 +75,13 @@ async def new_wallet_action_scope( merge_spends: bool = True, sign: Optional[bool] = None, additional_signing_responses: List[SigningResponse] = [], + extra_spends: List[SpendBundle] = [], ) -> AsyncIterator[WalletActionScope]: async with ActionScope.new_scope(WalletSideEffects) as self: self = cast(WalletActionScope, self) async with self.use() as interface: interface.side_effects.signing_responses = additional_signing_responses.copy() + interface.side_effects.extra_spends = extra_spends.copy() yield self @@ -76,4 +91,5 @@ async def new_wallet_action_scope( merge_spends=merge_spends, sign=sign, additional_signing_responses=self.side_effects.signing_responses, + extra_spends=self.side_effects.extra_spends, ) diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index 4f571121037a..dcfd4c4971db 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -2261,6 +2261,7 @@ async def add_pending_transactions( merge_spends: bool = True, sign: Optional[bool] = None, additional_signing_responses: Optional[List[SigningResponse]] = None, + extra_spends: Optional[List[SpendBundle]] = None, ) -> List[TransactionRecord]: """ Add a list of transactions to be submitted to the full node. @@ -2271,6 +2272,8 @@ async def add_pending_transactions( agg_spend: SpendBundle = SpendBundle.aggregate( [tx.spend_bundle for tx in tx_records if tx.spend_bundle is not None] ) + if extra_spends is not None: + agg_spend = SpendBundle.aggregate([agg_spend, *extra_spends]) actual_spend_involved: bool = agg_spend != SpendBundle([], G2Element()) if merge_spends and actual_spend_involved: tx_records = [ @@ -2281,6 +2284,17 @@ async def add_pending_transactions( ) for i, tx in enumerate(tx_records) ] + elif extra_spends is not None and extra_spends != []: + extra_spends.extend([] if tx_records[0].spend_bundle is None else [tx_records[0].spend_bundle]) + extra_spend_bundle = SpendBundle.aggregate(extra_spends) + tx_records = [ + dataclasses.replace( + tx, + spend_bundle=extra_spend_bundle if i == 0 else None, + name=extra_spend_bundle.name() if i == 0 else bytes32.secret(), + ) + for i, tx in enumerate(tx_records) + ] if sign: tx_records, _ = await self.sign_transactions( tx_records, @@ -2751,6 +2765,7 @@ async def new_action_scope( merge_spends: bool = True, sign: Optional[bool] = None, additional_signing_responses: List[SigningResponse] = [], + extra_spends: List[SpendBundle] = [], ) -> AsyncIterator[WalletActionScope]: async with new_wallet_action_scope( self, @@ -2758,5 +2773,6 @@ async def new_action_scope( merge_spends=merge_spends, sign=sign, additional_signing_responses=additional_signing_responses, + extra_spends=extra_spends, ) as action_scope: yield action_scope From 933e81406d633a2fb0bf062fc3a85b2fad8f9dbd Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 18 Jun 2024 13:00:27 -0700 Subject: [PATCH 12/12] Add test for .add_pending_transactions --- .../wallet/test_wallet_state_manager.py | 111 +++++++++++++++++- chia/wallet/wallet_state_manager.py | 2 +- 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/chia/_tests/wallet/test_wallet_state_manager.py b/chia/_tests/wallet/test_wallet_state_manager.py index 06ee7bf518e5..826db346e8d3 100644 --- a/chia/_tests/wallet/test_wallet_state_manager.py +++ b/chia/_tests/wallet/test_wallet_state_manager.py @@ -1,19 +1,26 @@ from __future__ import annotations from contextlib import asynccontextmanager -from typing import AsyncIterator +from typing import AsyncIterator, List import pytest +from chia_rs import G2Element +from chia._tests.environments.wallet import WalletTestFramework from chia._tests.util.setup_nodes import OldSimulatorsAndWallets from chia.protocols.wallet_protocol import CoinState from chia.server.outbound_message import NodeType from chia.types.blockchain_format.coin import Coin +from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.coin_spend import make_spend from chia.types.peer_info import PeerInfo +from chia.types.spend_bundle import SpendBundle from chia.util.ints import uint32, uint64 from chia.wallet.derivation_record import DerivationRecord from chia.wallet.derive_keys import master_sk_to_wallet_sk, master_sk_to_wallet_sk_unhardened +from chia.wallet.transaction_record import TransactionRecord +from chia.wallet.util.transaction_type import TransactionType from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_state_manager import WalletStateManager @@ -95,3 +102,105 @@ async def test_determine_coin_type(simulator_and_wallet: OldSimulatorsAndWallets assert (None, None) == await wallet_state_manager.determine_coin_type( peer, CoinState(Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), uint64(0)), uint32(0), uint32(0)), None ) + + +@pytest.mark.parametrize( + "wallet_environments", + [{"num_environments": 1, "blocks_needed": [1], "trusted": True, "reuse_puzhash": True}], + indirect=True, +) +@pytest.mark.limit_consensus_modes(reason="irrelevant") +@pytest.mark.anyio +async def test_commit_transactions_to_db(wallet_environments: WalletTestFramework) -> None: + env = wallet_environments.environments[0] + wsm = env.wallet_state_manager + + coins = list( + await wsm.main_wallet.select_coins( + uint64(2_000_000_000_000), coin_selection_config=wallet_environments.tx_config.coin_selection_config + ) + ) + [tx1] = await wsm.main_wallet.generate_signed_transaction( + uint64(0), + bytes32([0] * 32), + wallet_environments.tx_config, + coins={coins[0]}, + ) + [tx2] = await wsm.main_wallet.generate_signed_transaction( + uint64(0), + bytes32([0] * 32), + wallet_environments.tx_config, + coins={coins[1]}, + ) + + def flatten_spend_bundles(txs: List[TransactionRecord]) -> List[SpendBundle]: + return [tx.spend_bundle for tx in txs if tx.spend_bundle is not None] + + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 0 + ) + new_txs = await wsm.add_pending_transactions( + [tx1, tx2], + push=False, + merge_spends=False, + sign=False, + extra_spends=[], + ) + bundles = flatten_spend_bundles(new_txs) + assert len(bundles) == 2 + for bundle in bundles: + assert bundle.aggregated_signature == G2Element() + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 0 + ) + + extra_coin_spend = make_spend( + Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), uint64(0)), Program.to(1), Program.to([None]) + ) + extra_spend = SpendBundle([extra_coin_spend], G2Element()) + + new_txs = await wsm.add_pending_transactions( + [tx1, tx2], + push=False, + merge_spends=False, + sign=False, + extra_spends=[extra_spend], + ) + bundles = flatten_spend_bundles(new_txs) + assert len(bundles) == 2 + for bundle in bundles: + assert bundle.aggregated_signature == G2Element() + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 0 + ) + assert extra_coin_spend in [spend for bundle in bundles for spend in bundle.coin_spends] + + new_txs = await wsm.add_pending_transactions( + [tx1, tx2], + push=False, + merge_spends=True, + sign=False, + extra_spends=[extra_spend], + ) + bundles = flatten_spend_bundles(new_txs) + assert len(bundles) == 1 + for bundle in bundles: + assert bundle.aggregated_signature == G2Element() + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 0 + ) + assert extra_coin_spend in [spend for bundle in bundles for spend in bundle.coin_spends] + + [tx1, tx2] = await wsm.add_pending_transactions([tx1, tx2], push=True, merge_spends=True, sign=True) + bundles = flatten_spend_bundles(new_txs) + assert len(bundles) == 1 + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 2 + ) + + await wallet_environments.full_node.wait_transaction_records_entered_mempool([tx1, tx2]) diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index ba4c861d7dd2..44067d446e16 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -2289,7 +2289,7 @@ async def add_pending_transactions( tx_records = [ dataclasses.replace( tx, - spend_bundle=extra_spend_bundle if i == 0 else None, + spend_bundle=extra_spend_bundle if i == 0 else tx.spend_bundle, name=extra_spend_bundle.name() if i == 0 else bytes32.secret(), ) for i, tx in enumerate(tx_records)