diff --git a/hathor/builder/resources_builder.py b/hathor/builder/resources_builder.py index 88c38ff98..470f0c613 100644 --- a/hathor/builder/resources_builder.py +++ b/hathor/builder/resources_builder.py @@ -261,8 +261,11 @@ def create_resources(self) -> server.Site: # Websocket resource assert self.manager.tx_storage.indexes is not None - ws_factory = HathorAdminWebsocketFactory(metrics=self.manager.metrics, + ws_factory = HathorAdminWebsocketFactory(manager=self.manager, + metrics=self.manager.metrics, address_index=self.manager.tx_storage.indexes.addresses) + if self._args.disable_ws_history_streaming: + ws_factory.disable_history_streaming() ws_factory.start() root.putChild(b'ws', WebSocketResource(ws_factory)) diff --git a/hathor/cli/run_node.py b/hathor/cli/run_node.py index f5e85ee80..1362df6b4 100644 --- a/hathor/cli/run_node.py +++ b/hathor/cli/run_node.py @@ -155,6 +155,8 @@ def create_parser(cls) -> ArgumentParser: help='Launch embedded IPython kernel for remote debugging') parser.add_argument('--log-vertex-bytes', action='store_true', help='Log tx bytes for debugging') + parser.add_argument('--disable-ws-history-streaming', action='store_true', + help='Disable websocket history streaming API') return parser def prepare(self, *, register_resources: bool = True) -> None: diff --git a/hathor/cli/run_node_args.py b/hathor/cli/run_node_args.py index c67aaeebb..76994b7cb 100644 --- a/hathor/cli/run_node_args.py +++ b/hathor/cli/run_node_args.py @@ -80,3 +80,4 @@ class RunNodeArgs(BaseModel, extra=Extra.allow): x_ipython_kernel: bool nano_testnet: bool log_vertex_bytes: bool + disable_ws_history_streaming: bool diff --git a/hathor/wallet/hd_wallet.py b/hathor/wallet/hd_wallet.py index 3773ed4df..3c64d0c1f 100644 --- a/hathor/wallet/hd_wallet.py +++ b/hathor/wallet/hd_wallet.py @@ -173,6 +173,10 @@ def generate_new_key(self, index): new_key = self.chain_key.subkey(index) self._key_generated(new_key, index) + def get_xpub(self) -> str: + """Return wallet xpub 
after derivation.""" + return self.chain_key.as_text(as_private=False) + def _key_generated(self, key, index): """ Add generated key to self.keys and set last_generated_index diff --git a/hathor/websocket/exception.py b/hathor/websocket/exception.py new file mode 100644 index 000000000..20f83a0bf --- /dev/null +++ b/hathor/websocket/exception.py @@ -0,0 +1,27 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hathor.exception import HathorError + + +class InvalidXPub(HathorError): + """Raised when an invalid xpub is provided.""" + + +class LimitExceeded(HathorError): + """Raised when a limit is exceeded.""" + + +class InvalidAddress(HathorError): + """Raised when an invalid address is provided.""" diff --git a/hathor/websocket/factory.py b/hathor/websocket/factory.py index 2c7aa2d16..b96dcbdff 100644 --- a/hathor/websocket/factory.py +++ b/hathor/websocket/factory.py @@ -13,7 +13,7 @@ # limitations under the License. 
from collections import defaultdict, deque -from typing import Any, Optional, Union +from typing import Any, Optional from autobahn.exception import Disconnected from autobahn.twisted.websocket import WebSocketServerFactory @@ -22,11 +22,12 @@ from hathor.conf import HathorSettings from hathor.indexes import AddressIndex +from hathor.manager import HathorManager from hathor.metrics import Metrics from hathor.p2p.rate_limiter import RateLimiter from hathor.pubsub import EventArguments, HathorEvents from hathor.reactor import get_global_reactor -from hathor.util import json_dumpb, json_loadb, json_loads +from hathor.util import json_dumpb from hathor.websocket.protocol import HathorAdminWebsocketProtocol settings = HathorSettings() @@ -83,18 +84,25 @@ class HathorAdminWebsocketFactory(WebSocketServerFactory): max_subs_addrs_empty: Optional[int] = settings.WS_MAX_SUBS_ADDRS_EMPTY def buildProtocol(self, addr): - return self.protocol(self) + return self.protocol(self, is_history_streaming_enabled=self.is_history_streaming_enabled) - def __init__(self, metrics: Optional[Metrics] = None, address_index: Optional[AddressIndex] = None): + def __init__(self, + manager: HathorManager, + metrics: Optional[Metrics] = None, + address_index: Optional[AddressIndex] = None): """ :param metrics: If not given, a new one is created. :type metrics: :py:class:`hathor.metrics.Metrics` """ + self.manager = manager self.reactor = get_global_reactor() # Opened websocket connections so I can broadcast messages later # It contains only connections that have finished handshaking. self.connections: set[HathorAdminWebsocketProtocol] = set() + # Enable/disable history streaming over the websocket connection. 
+        self.is_history_streaming_enabled: bool = True
+
         # Websocket connection for each address
         self.address_connections: defaultdict[str, set[HathorAdminWebsocketProtocol]] = defaultdict(set)
         super().__init__()
@@ -129,6 +137,12 @@ def stop(self):
             self._lc_send_metrics.stop()
         self.is_running = False
 
+    def disable_history_streaming(self) -> None:
+        """Disable history streaming for all connections."""
+        self.is_history_streaming_enabled = False
+        for conn in self.connections:
+            conn.disable_history_streaming()
+
     def _setup_rate_limit(self):
         """ Set the limit of the RateLimiter and start the buffer deques with BUFFER_SIZE """
@@ -300,44 +314,33 @@ def process_deque(self, data_type):
                                       data_type=data_type)
                 break
 
-    def handle_message(self, connection: HathorAdminWebsocketProtocol, data: Union[bytes, str]) -> None:
-        """ General message handler, detects type and deletages to specific handler."""
-        if isinstance(data, bytes):
-            message = json_loadb(data)
-        else:
-            message = json_loads(data)
-        # we only handle ping messages for now
-        if message['type'] == 'ping':
-            self._handle_ping(connection, message)
-        elif message['type'] == 'subscribe_address':
-            self._handle_subscribe_address(connection, message)
-        elif message['type'] == 'unsubscribe_address':
-            self._handle_unsubscribe_address(connection, message)
-
-    def _handle_ping(self, connection: HathorAdminWebsocketProtocol, message: dict[Any, Any]) -> None:
-        """ Handler for ping message, should respond with a simple {"type": "pong"}"""
-        payload = json_dumpb({'type': 'pong'})
-        connection.sendMessage(payload, False)
-
     def _handle_subscribe_address(self, connection: HathorAdminWebsocketProtocol, message: dict[Any, Any]) -> None:
         """ Handler for subscription to an address, consideirs subscription limits."""
-        addr: str = message['address']
+        address: str = message['address']
+        success, errmsg = self.subscribe_address(connection, address)
+        response = {
+            'type': 'subscribe_address',
+            'address': address,
+            'success': success,
+        }
+ if not success: + response['message'] = errmsg + connection.sendMessage(json_dumpb(response), False) + + def subscribe_address(self, connection: HathorAdminWebsocketProtocol, address: str) -> tuple[bool, str]: + """Subscribe an address to send real time updates to a websocket connection.""" subs: set[str] = connection.subscribed_to if self.max_subs_addrs_conn is not None and len(subs) >= self.max_subs_addrs_conn: - payload = json_dumpb({'message': 'Reached maximum number of subscribed ' - f'addresses ({self.max_subs_addrs_conn}).', - 'type': 'subscribe_address', 'success': False}) + return False, f'Reached maximum number of subscribed addresses ({self.max_subs_addrs_conn}).' + elif self.max_subs_addrs_empty is not None and ( self.address_index and _count_empty(subs, self.address_index) >= self.max_subs_addrs_empty ): - payload = json_dumpb({'message': 'Reached maximum number of subscribed ' - f'addresses without output ({self.max_subs_addrs_empty}).', - 'type': 'subscribe_address', 'success': False}) - else: - self.address_connections[addr].add(connection) - connection.subscribed_to.add(addr) - payload = json_dumpb({'type': 'subscribe_address', 'success': True}) - connection.sendMessage(payload, False) + return False, f'Reached maximum number of subscribed empty addresses ({self.max_subs_addrs_empty}).' 
+ + self.address_connections[address].add(connection) + connection.subscribed_to.add(address) + return True, '' def _handle_unsubscribe_address(self, connection: HathorAdminWebsocketProtocol, message: dict[Any, Any]) -> None: """ Handler for unsubscribing from an address, also removes address connection set if it ends up empty.""" diff --git a/hathor/websocket/iterators.py b/hathor/websocket/iterators.py new file mode 100644 index 000000000..41f6e7298 --- /dev/null +++ b/hathor/websocket/iterators.py @@ -0,0 +1,151 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from collections.abc import AsyncIterable +from dataclasses import dataclass +from typing import AsyncIterator, Iterator, TypeAlias + +from twisted.internet.defer import Deferred + +from hathor.manager import HathorManager +from hathor.transaction import BaseTransaction +from hathor.types import AddressB58 +from hathor.websocket.exception import InvalidAddress, InvalidXPub, LimitExceeded + + +@dataclass(frozen=True, slots=True) +class AddressItem: + index: int + address: AddressB58 + + +@dataclass(frozen=True, slots=True) +class VertexItem: + vertex: BaseTransaction + + +class ManualAddressSequencer(AsyncIterable[AddressItem]): + """An async iterable that yields addresses from a list. More addresses + can be added while the iterator is being consumed. 
+ """ + + ADDRESS_SIZE: int = 34 + MAX_PENDING_ADDRESSES_SIZE: int = 5_000 + + def __init__(self) -> None: + self.max_pending_addresses_size: int = self.MAX_PENDING_ADDRESSES_SIZE + self.pending_addresses: deque[AddressItem] = deque() + self.await_items: Deferred | None = None + + # Flag to mark when all addresses have been received so the iterator + # can stop yielding after the pending list of addresses is empty. + self._stop = False + + def _resume_iter(self) -> None: + """Resume yield addresses.""" + if self.await_items is None: + return + if not self.await_items.called: + self.await_items.callback(None) + + def add_addresses(self, addresses: list[AddressItem], last: bool) -> None: + """Add more addresses to be yielded. If `last` is true, the iterator + will stop when the pending list of items gets empty.""" + if len(self.pending_addresses) + len(addresses) > self.max_pending_addresses_size: + raise LimitExceeded + + # Validate addresses. + for item in addresses: + if len(item.address) != self.ADDRESS_SIZE: + raise InvalidAddress(item) + + self.pending_addresses.extend(addresses) + if last: + self._stop = True + self._resume_iter() + + def __aiter__(self) -> AsyncIterator[AddressItem]: + """Return an async iterator.""" + return self._async_iter() + + async def _async_iter(self) -> AsyncIterator[AddressItem]: + """Internal method that implements the async iterator.""" + while True: + while self.pending_addresses: + item = self.pending_addresses.popleft() + yield item + + if self._stop: + break + + self.await_items = Deferred() + await self.await_items + + +def iter_xpub_addresses(xpub_str: str, *, first_index: int = 0) -> Iterator[AddressItem]: + """An iterator that yields addresses derived from an xpub.""" + from pycoin.networks.registry import network_for_netcode + + from hathor.wallet.hd_wallet import _register_pycoin_networks + _register_pycoin_networks() + network = network_for_netcode('htr') + + xpub = network.parse.bip32(xpub_str) + if xpub is None: + 
raise InvalidXPub(xpub_str) + + idx = first_index + while True: + key = xpub.subkey(idx) + yield AddressItem(idx, AddressB58(key.address())) + idx += 1 + + +async def aiter_xpub_addresses(xpub: str, *, first_index: int = 0) -> AsyncIterator[AddressItem]: + """An async iterator that yields addresses derived from an xpub.""" + it = iter_xpub_addresses(xpub, first_index=first_index) + for item in it: + yield item + + +AddressSearch: TypeAlias = AsyncIterator[AddressItem | VertexItem] + + +async def gap_limit_search( + manager: HathorManager, + address_iter: AsyncIterable[AddressItem], + gap_limit: int +) -> AddressSearch: + """An async iterator that yields addresses and vertices, stopping when the gap limit is reached. + """ + assert manager.tx_storage.indexes is not None + assert manager.tx_storage.indexes.addresses is not None + addresses_index = manager.tx_storage.indexes.addresses + empty_addresses_counter = 0 + async for item in address_iter: + yield item # AddressItem + + vertex_counter = 0 + for vertex_id in addresses_index.get_sorted_from_address(item.address): + tx = manager.tx_storage.get_transaction(vertex_id) + yield VertexItem(tx) + vertex_counter += 1 + + if vertex_counter == 0: + empty_addresses_counter += 1 + if empty_addresses_counter >= gap_limit: + break + else: + empty_addresses_counter = 0 diff --git a/hathor/websocket/messages.py b/hathor/websocket/messages.py new file mode 100644 index 000000000..f8d2e6c9a --- /dev/null +++ b/hathor/websocket/messages.py @@ -0,0 +1,62 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from pydantic import Field + +from hathor.utils.pydantic import BaseModel + + +class WebSocketMessage(BaseModel): + pass + + +class CapabilitiesMessage(WebSocketMessage): + type: str = Field('capabilities', const=True) + capabilities: list[str] + + +class StreamBase(WebSocketMessage): + pass + + +class StreamErrorMessage(StreamBase): + type: str = Field('stream:history:error', const=True) + id: str + errmsg: str + + +class StreamBeginMessage(StreamBase): + type: str = Field('stream:history:begin', const=True) + id: str + + +class StreamEndMessage(StreamBase): + type: str = Field('stream:history:end', const=True) + id: str + + +class StreamVertexMessage(StreamBase): + type: str = Field('stream:history:vertex', const=True) + id: str + data: dict[str, Any] + + +class StreamAddressMessage(StreamBase): + type: str = Field('stream:history:address', const=True) + id: str + index: int + address: str + subscribed: bool diff --git a/hathor/websocket/protocol.py b/hathor/websocket/protocol.py index 5429b9506..1b3bde0bb 100644 --- a/hathor/websocket/protocol.py +++ b/hathor/websocket/protocol.py @@ -12,11 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any, Union from autobahn.twisted.websocket import WebSocketServerProtocol from structlog import get_logger +from hathor.p2p.utils import format_address +from hathor.util import json_dumpb, json_loadb, json_loads +from hathor.websocket.exception import InvalidAddress, InvalidXPub, LimitExceeded +from hathor.websocket.iterators import ( + AddressItem, + AddressSearch, + ManualAddressSequencer, + aiter_xpub_addresses, + gap_limit_search, +) +from hathor.websocket.messages import CapabilitiesMessage, StreamErrorMessage, WebSocketMessage +from hathor.websocket.streamer import HistoryStreamer + if TYPE_CHECKING: from hathor.websocket.factory import HathorAdminWebsocketFactory @@ -30,23 +43,269 @@ class HathorAdminWebsocketProtocol(WebSocketServerProtocol): can send the data update to the clients """ - def __init__(self, factory: 'HathorAdminWebsocketFactory') -> None: + MAX_GAP_LIMIT: int = 10_000 + HISTORY_STREAMING_CAPABILITY: str = 'history-streaming' + + def __init__(self, + factory: 'HathorAdminWebsocketFactory', + is_history_streaming_enabled: bool) -> None: self.log = logger.new() self.factory = factory - self.subscribed_to: set[str] = set() super().__init__() + self.subscribed_to: set[str] = set() + + # Enable/disable history streaming for this connection. 
+ self.is_history_streaming_enabled = is_history_streaming_enabled + self._history_streamer: HistoryStreamer | None = None + self._manual_address_iter: ManualAddressSequencer | None = None + + def get_capabilities(self) -> list[str]: + """Get a list of websocket capabilities.""" + capabilities = [] + if self.is_history_streaming_enabled: + capabilities.append(self.HISTORY_STREAMING_CAPABILITY) + return capabilities + + def send_capabilities(self) -> None: + """Send a capabilities message.""" + self.send_message(CapabilitiesMessage(capabilities=self.get_capabilities())) + + def disable_history_streaming(self) -> None: + """Disable history streaming in this connection.""" + self.is_history_streaming_enabled = False + if self._history_streamer: + self._history_streamer.stop(success=False) + self.log.info('websocket history streaming disabled') + + def get_short_remote(self) -> str: + """Get remote for logging.""" + assert self.transport is not None + return format_address(self.transport.getPeer()) + def onConnect(self, request): - self.log.info('connection opened, starting handshake...', request=request) + """Called by the websocket protocol when the connection is opened but it is still pending handshaking.""" + self.log = logger.new(remote=self.get_short_remote()) + self.log.info('websocket connection opened, starting handshake...') def onOpen(self) -> None: + """Called by the websocket protocol when the connection is established.""" self.factory.on_client_open(self) - self.log.info('connection established') + self.log.info('websocket connection established') + self.send_capabilities() def onClose(self, wasClean, code, reason): + """Called by the websocket protocol when the connection is closed.""" self.factory.on_client_close(self) - self.log.info('connection closed', reason=reason) + self.log.info('websocket connection closed', reason=reason) def onMessage(self, payload: Union[bytes, str], isBinary: bool) -> None: + """Called by the websocket protocol when a new 
message is received.""" self.log.debug('new message', payload=payload.hex() if isinstance(payload, bytes) else payload) - self.factory.handle_message(self, payload) + if isinstance(payload, bytes): + message = json_loadb(payload) + else: + message = json_loads(payload) + + _type = message.get('type') + + if _type == 'ping': + self._handle_ping(message) + elif _type == 'subscribe_address': + self.factory._handle_subscribe_address(self, message) + elif _type == 'unsubscribe_address': + self.factory._handle_unsubscribe_address(self, message) + elif _type == 'request:history:xpub': + self._open_history_xpub_streamer(message) + elif _type == 'request:history:manual': + self._handle_history_manual_streamer(message) + elif _type == 'request:history:stop': + self._stop_streamer(message) + + def _handle_ping(self, message: dict[Any, Any]) -> None: + """Handle ping message, should respond with a simple {"type": "pong"}""" + payload = json_dumpb({'type': 'pong'}) + self.sendMessage(payload, False) + + def fail_if_history_streaming_is_disabled(self) -> bool: + """Return false if the history streamer is enabled. Otherwise, it sends an + error message and returns true.""" + if self.is_history_streaming_enabled: + return False + + self.send_message(StreamErrorMessage( + id='', + errmsg='Streaming history is disabled.' 
+ )) + return True + + def _create_streamer(self, stream_id: str, search: AddressSearch) -> None: + """Create the streamer and handle its callbacks.""" + self._history_streamer = HistoryStreamer(protocol=self, stream_id=stream_id, search=search) + deferred = self._history_streamer.start() + deferred.addBoth(self._streamer_callback) + return + + def _open_history_xpub_streamer(self, message: dict[Any, Any]) -> None: + """Handle request to stream transactions using an xpub.""" + if self.fail_if_history_streaming_is_disabled(): + return + + stream_id = message['id'] + + if self._history_streamer is not None: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg='Streaming is already opened.' + )) + return + + xpub = message['xpub'] + gap_limit = message.get('gap-limit', 20) + first_index = message.get('first-index', 0) + if gap_limit > self.MAX_GAP_LIMIT: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg=f'GAP limit is too big. Maximum: {self.MAX_GAP_LIMIT}' + )) + return + + try: + address_iter = aiter_xpub_addresses(xpub, first_index=first_index) + except InvalidXPub: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg=f'Invalid XPub: {xpub}' + )) + return + + search = gap_limit_search(self.factory.manager, address_iter, gap_limit) + self._create_streamer(stream_id, search) + self.log.info('opening a websocket xpub streaming', + stream_id=stream_id, + xpub=xpub, + gap_limit=gap_limit, + first_index=first_index) + + def _handle_history_manual_streamer(self, message: dict[Any, Any]) -> None: + """Handle request to stream transactions using a list of addresses.""" + if self.fail_if_history_streaming_is_disabled(): + return + + stream_id = message['id'] + addresses: list[AddressItem] = [AddressItem(idx, address) for idx, address in message.get('addresses', [])] + first = message.get('first', False) + last = message.get('last', False) + + if self._history_streamer is not None: + if first or self._history_streamer.stream_id != 
stream_id: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg='Streaming is already opened.' + )) + return + + if not self._add_addresses_to_manual_iter(stream_id, addresses, last): + return + + self.log.info('Adding addresses to a websocket manual streaming', + stream_id=stream_id, + addresses=addresses, + last=last) + return + + gap_limit = message.get('gap-limit', 20) + if gap_limit > self.MAX_GAP_LIMIT: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg=f'GAP limit is too big. Maximum: {self.MAX_GAP_LIMIT}' + )) + return + + if not first: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg='Streaming not found. You must send first=true in your first message.' + )) + return + + address_iter = ManualAddressSequencer() + self._manual_address_iter = address_iter + if not self._add_addresses_to_manual_iter(stream_id, addresses, last): + self._manual_address_iter = None + return + + search = gap_limit_search(self.factory.manager, address_iter, gap_limit) + self._create_streamer(stream_id, search) + self.log.info('opening a websocket manual streaming', + stream_id=stream_id, + addresses=addresses, + gap_limit=gap_limit, + last=last) + + def _streamer_callback(self, success: bool) -> None: + """Callback used to identify when the streamer has ended.""" + assert self._history_streamer is not None + self.log.info('websocket xpub streaming has been finished', + stream_id=self._history_streamer.stream_id, + success=success, + sent_addresses=self._history_streamer.stats_sent_addresses, + sent_vertices=self._history_streamer.stats_sent_vertices) + self._history_streamer = None + self._manual_address_iter = None + + def _stop_streamer(self, message: dict[Any, Any]) -> None: + """Handle request to stop the current streamer.""" + stream_id: str = message.get('id', '') + + if self._history_streamer is None: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg='No streaming opened.' 
+ )) + return + + assert self._history_streamer is not None + + if self._history_streamer.stream_id != stream_id: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg='Current stream has a different id.' + )) + return + + self._history_streamer.stop(success=False) + self.log.info('stopping a websocket xpub streaming', stream_id=stream_id) + + def send_message(self, message: WebSocketMessage) -> None: + """Send a typed message.""" + payload = message.json_dumpb() + self.sendMessage(payload) + + def subscribe_address(self, address: str) -> tuple[bool, str]: + """Subscribe to receive real-time messages for all vertices related to an address.""" + return self.factory.subscribe_address(self, address) + + def _add_addresses_to_manual_iter(self, stream_id: str, addresses: list[AddressItem], last: bool) -> bool: + """Add addresses to manual address iter and returns true if it succeeds.""" + assert self._manual_address_iter is not None + try: + self._manual_address_iter.add_addresses(addresses, last) + except LimitExceeded: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg='List of addresses is too long.' + )) + return False + except InvalidAddress as exc: + self.send_message(StreamErrorMessage( + id=stream_id, + errmsg=f'Invalid address: {exc}' + )) + return False + + self.log.info('Adding addresses to a websocket manual streaming', + stream_id=stream_id, + addresses=addresses, + last=last) + return True diff --git a/hathor/websocket/streamer.py b/hathor/websocket/streamer.py new file mode 100644 index 000000000..f116fc36e --- /dev/null +++ b/hathor/websocket/streamer.py @@ -0,0 +1,185 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IPushProducer +from twisted.internet.task import deferLater +from zope.interface import implementer + +from hathor.websocket.iterators import AddressItem, AddressSearch, VertexItem +from hathor.websocket.messages import ( + StreamAddressMessage, + StreamBase, + StreamBeginMessage, + StreamEndMessage, + StreamErrorMessage, + StreamVertexMessage, +) + +if TYPE_CHECKING: + from hathor.websocket.protocol import HathorAdminWebsocketProtocol + + +@implementer(IPushProducer) +class HistoryStreamer: + """A producer that pushes addresses and transactions to a websocket connection. + Each pushed address is automatically subscribed for real-time updates. + + Streaming messages: + + 1. `stream:history:begin`: mark the beginning of a streaming. + 2. `stream:history:address`: mark the beginning of a new address. + 3. `stream:history:vertex`: vertex information in JSON format. + 4. `stream:history:vertex`: vertex information in JSON format. + 5. `stream:history:vertex`: vertex information in JSON format. + 6. `stream:history:vertex`: vertex information in JSON format. + 7. `stream:history:address`: mark the beginning of another address, so the previous address has been finished. + 8. `stream:history:address`: mark the beginning of another address, so the previous address has been finished. + 9. `stream:history:address`: mark the beginning of another address, so the previous address has been finished. + 10. 
`stream:history:end`: mark the end of the streaming. + + Notice that the streaming might send two or more `address` messages in a row if there are empty addresses. + """ + + STATS_LOG_INTERVAL = 10_000 + + def __init__(self, + *, + protocol: 'HathorAdminWebsocketProtocol', + stream_id: str, + search: AddressSearch) -> None: + self.protocol = protocol + self.stream_id = stream_id + self.search_iter = aiter(search) + + self.reactor = self.protocol.factory.manager.reactor + + self.max_seconds_locking_event_loop = 1 + + self.stats_log_interval = self.STATS_LOG_INTERVAL + self.stats_total_messages: int = 0 + self.stats_sent_addresses: int = 0 + self.stats_sent_vertices: int = 0 + + self._paused = False + self._stop = False + + def start(self) -> Deferred[bool]: + """Start streaming items.""" + self.send_message(StreamBeginMessage(id=self.stream_id)) + + # The websocket connection somehow instantiates an twisted.web.http.HTTPChannel object + # which register a producer. It seems the HTTPChannel is not used anymore after switching + # to websocket but it keep registered. So we have to unregister before registering a new + # producer. + if self.protocol.transport.producer: + self.protocol.unregisterProducer() + + self.protocol.registerProducer(self, True) + self.deferred: Deferred[bool] = Deferred() + self.resumeProducing() + return self.deferred + + def stop(self, success: bool) -> None: + """Stop streaming items.""" + self._stop = True + self.protocol.unregisterProducer() + self.deferred.callback(success) + + def pauseProducing(self) -> None: + """Pause streaming. Called by twisted.""" + self._paused = True + + def stopProducing(self) -> None: + """Stop streaming. Called by twisted.""" + self._stop = True + self.stop(False) + + def resumeProducing(self) -> None: + """Resume streaming. 
Called by twisted.""" + self._paused = False + self._run() + + def _run(self) -> None: + """Run the streaming main loop.""" + coro = self._async_run() + Deferred.fromCoroutine(coro) + + async def _async_run(self): + """Internal method that runs the streaming main loop.""" + t0 = self.reactor.seconds() + + async for item in self.search_iter: + # The methods `pauseProducing()` and `stopProducing()` might be called during the + # call to `self.protocol.sendMessage()`. So both `_paused` and `_stop` might change + # during the loop. + if self._paused or self._stop: + break + + match item: + case AddressItem(): + subscribed, errmsg = self.protocol.subscribe_address(item.address) + + if not subscribed: + self.send_message(StreamErrorMessage( + id=self.stream_id, + errmsg=f'Address subscription failed: {errmsg}' + )) + self.stop(False) + return + + self.stats_sent_addresses += 1 + self.send_message(StreamAddressMessage( + id=self.stream_id, + index=item.index, + address=item.address, + subscribed=subscribed, + )) + + case VertexItem(): + self.stats_sent_vertices += 1 + self.send_message(StreamVertexMessage( + id=self.stream_id, + data=item.vertex.to_json_extended(), + )) + + case _: + assert False + + self.stats_total_messages += 1 + if self.stats_total_messages % self.stats_log_interval == 0: + self.protocol.log.info('websocket streaming statistics', + total_messages=self.stats_total_messages, + sent_vertices=self.stats_sent_vertices, + sent_addresses=self.stats_sent_addresses) + + dt = self.reactor.seconds() - t0 + if dt > self.max_seconds_locking_event_loop: + # Let the event loop run at least once. + await deferLater(self.reactor, 0, lambda: None) + t0 = self.reactor.seconds() + + else: + if self._stop: + # If the streamer has been stopped, there is nothing else to do. 
+ return + self.send_message(StreamEndMessage(id=self.stream_id)) + self.stop(True) + + def send_message(self, message: StreamBase) -> None: + """Send a message to the websocket connection.""" + payload = message.json_dumpb() + self.protocol.sendMessage(payload) diff --git a/tests/sysctl/test_websocket.py b/tests/sysctl/test_websocket.py index 3c7749f3e..920eb3b87 100644 --- a/tests/sysctl/test_websocket.py +++ b/tests/sysctl/test_websocket.py @@ -5,8 +5,15 @@ class WebsocketSysctlTestCase(unittest.TestCase): + _enable_sync_v1 = True + _enable_sync_v2 = True + + def setUp(self): + super().setUp() + self.manager = self.create_peer('testnet') + def test_max_subs_addrs_conn(self): - ws_factory = HathorAdminWebsocketFactory() + ws_factory = HathorAdminWebsocketFactory(self.manager) sysctl = WebsocketManagerSysctl(ws_factory) sysctl.unsafe_set('max_subs_addrs_conn', 10) @@ -25,7 +32,7 @@ def test_max_subs_addrs_conn(self): sysctl.unsafe_set('max_subs_addrs_conn', -2) def test_max_subs_addrs_empty(self): - ws_factory = HathorAdminWebsocketFactory() + ws_factory = HathorAdminWebsocketFactory(self.manager) sysctl = WebsocketManagerSysctl(ws_factory) sysctl.unsafe_set('max_subs_addrs_empty', 10) diff --git a/tests/unittest.py b/tests/unittest.py index cccc417a9..bac8fca7c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -118,6 +118,7 @@ def setUp(self) -> None: self.tmpdirs: list[str] = [] self.clock = TestMemoryReactorClock() self.clock.advance(time.time()) + self.reactor = self.clock self.log = logger.new() self.reset_peer_id_pool() self.seed = secrets.randbits(64) if self.seed_config is None else self.seed_config diff --git a/tests/websocket/test_async_iterators.py b/tests/websocket/test_async_iterators.py new file mode 100644 index 000000000..a3435e723 --- /dev/null +++ b/tests/websocket/test_async_iterators.py @@ -0,0 +1,139 @@ +from typing import AsyncIterator, TypeVar + +from twisted.internet.defer import Deferred + +from hathor.wallet import HDWallet +from 
hathor.websocket.exception import InvalidAddress, InvalidXPub +from hathor.websocket.iterators import ( + AddressItem, + ManualAddressSequencer, + VertexItem, + aiter_xpub_addresses, + gap_limit_search, +) +from tests.unittest import TestCase +from tests.utils import GENESIS_ADDRESS_B58 + +T = TypeVar('T') + + +async def async_islice(iterable: AsyncIterator[T], stop: int) -> AsyncIterator[T]: + count = 0 + async for item in iterable: + if count >= stop: + break + yield item + count += 1 + + +class AsyncIteratorsTestCase(TestCase): + _enable_sync_v1 = True + _enable_sync_v2 = True + + def setUp(self) -> None: + super().setUp() + + self.manager = self.create_peer('mainnet', wallet_index=True) + self.settings = self.manager._settings + + # Create wallet. + wallet = HDWallet() + wallet.unlock(self.manager.tx_storage) + + # Create xpub and list of addresses. + self.xpub = wallet.get_xpub() + self.xpub_addresses = [ + AddressItem(idx, wallet.get_address(wallet.get_key_at_index(idx))) + for idx in range(20) + ] + + async def test_xpub_sequencer_default_first_index(self) -> None: + xpub = self.xpub + expected_result = self.xpub_addresses + + sequencer = aiter_xpub_addresses(xpub) + result = [item async for item in async_islice(aiter(sequencer), len(expected_result))] + self.assertEqual(result, expected_result) + + async def test_xpub_sequencer_other_first_index(self) -> None: + xpub = self.xpub + first_index = 8 + expected_result = self.xpub_addresses[first_index:] + + sequencer = aiter_xpub_addresses(xpub, first_index=first_index) + result = [item async for item in async_islice(aiter(sequencer), len(expected_result))] + self.assertEqual(result, expected_result) + + async def test_xpub_sequencer_invalid(self) -> None: + with self.assertRaises(InvalidXPub): + async for _ in aiter_xpub_addresses('invalid xpub'): + pass + + async def test_manual_invalid(self) -> None: + address_iter = ManualAddressSequencer() + with self.assertRaises(InvalidAddress): + 
address_iter.add_addresses([AddressItem(0, 'a')], last=True) + + async def test_manual_last_true(self) -> None: + expected_result = self.xpub_addresses + + iterable = ManualAddressSequencer() + iterable.add_addresses(expected_result, last=True) + + result = [item async for item in iterable] + self.assertEqual(result, expected_result) + + async def test_manual_two_tranches(self) -> None: + expected_result = self.xpub_addresses + + iterable = ManualAddressSequencer() + n = 8 + iterable.add_addresses(expected_result[:n], last=False) + + result = [] + is_running = False + + async def collect_results(): + nonlocal is_running + nonlocal result + is_running = True + result = [item async for item in iterable] + is_running = False + + self.reactor.callLater(0, lambda: Deferred.fromCoroutine(collect_results())) + self.reactor.advance(5) + self.assertTrue(is_running) + + self.reactor.callLater(0, lambda: iterable.add_addresses(expected_result[n:], last=True)) + self.reactor.advance(5) + self.assertFalse(is_running) + + self.assertEqual(result, expected_result) + + async def test_gap_limit_xpub(self) -> None: + xpub = self.xpub + gap_limit = 8 + expected_result = self.xpub_addresses[:gap_limit] + + address_iter = aiter_xpub_addresses(xpub) + search = gap_limit_search(self.manager, address_iter, gap_limit=gap_limit) + + result = [item async for item in search] + self.assertEqual(result, expected_result) + + async def test_gap_limit_manual(self) -> None: + genesis = self.manager.tx_storage.get_genesis(self.settings.GENESIS_BLOCK_HASH) + genesis_address = GENESIS_ADDRESS_B58 + + gap_limit = 8 + addresses: list[AddressItem] = [AddressItem(0, genesis_address)] + self.xpub_addresses + expected_result: list[AddressItem | VertexItem] = list(addresses[:gap_limit + 1]) + expected_result.insert(1, VertexItem(genesis)) + + address_iter = ManualAddressSequencer() + # Adding more addresses than the gap limit. 
+ address_iter.add_addresses(addresses, last=True) + search = gap_limit_search(self.manager, address_iter, gap_limit=gap_limit) + + result = [item async for item in search] + self.assertEqual(result, expected_result) diff --git a/tests/websocket/test_streamer.py b/tests/websocket/test_streamer.py new file mode 100644 index 000000000..a76148952 --- /dev/null +++ b/tests/websocket/test_streamer.py @@ -0,0 +1,99 @@ +import json +from typing import Any, Iterator + +from twisted.internet.testing import StringTransport + +from hathor.wallet import HDWallet +from hathor.websocket.factory import HathorAdminWebsocketFactory +from hathor.websocket.iterators import AddressItem, ManualAddressSequencer, gap_limit_search +from hathor.websocket.streamer import HistoryStreamer +from tests.unittest import TestCase +from tests.utils import GENESIS_ADDRESS_B58 + + +class AsyncIteratorsTestCase(TestCase): + _enable_sync_v1 = True + _enable_sync_v2 = True + + WS_PROTOCOL_MESSAGE_SEPARATOR = b'\x81' + + def test_streamer(self) -> None: + manager = self.create_peer('mainnet', wallet_index=True) + settings = manager._settings + + # Settings. + stream_id = 'A001' + gap_limit = 8 + + # Get genesis information. + genesis = manager.tx_storage.get_genesis(settings.GENESIS_BLOCK_HASH) + genesis_address = GENESIS_ADDRESS_B58 + + # Create wallet. + wallet = HDWallet() + wallet.unlock(manager.tx_storage) + + # Create list of addresses. + addresses: list[AddressItem] = [AddressItem(0, genesis_address)] + for idx in range(1, 30): + addresses.append(AddressItem(idx, wallet.get_address(wallet.get_key_at_index(idx)))) + + # Create the expected result. 
+ expected_result: list[dict[str, Any]] = [{'type': 'stream:history:begin', 'id': stream_id}] + expected_result += [ + { + 'type': 'stream:history:address', + 'id': stream_id, + 'index': item.index, + 'address': item.address, + 'subscribed': True + } + for item in addresses[:gap_limit + 1] + ] + expected_result.insert(2, { + 'type': 'stream:history:vertex', + 'id': stream_id, + 'data': genesis.to_json_extended(), + }) + expected_result.append({'type': 'stream:history:end', 'id': stream_id}) + + # Create both the address iterator and the GAP limit searcher. + address_iter = ManualAddressSequencer() + address_iter.add_addresses(addresses, last=True) + search = gap_limit_search(manager, address_iter, gap_limit=gap_limit) + + # Create the websocket factory and protocol. + factory = HathorAdminWebsocketFactory(manager) + factory.openHandshakeTimeout = 0 + protocol = factory.buildProtocol(None) + + # Create the transport and create a fake connection. + transport = StringTransport() + protocol.makeConnection(transport) + factory.connections.add(protocol) + protocol.state = protocol.STATE_OPEN + + # Create the history streamer. + streamer = HistoryStreamer(protocol=protocol, stream_id=stream_id, search=search) + streamer.start() + + # Run the streamer. + manager.reactor.advance(10) + + # Check the results. 
+ items_iter = self._parse_ws_raw(transport.value()) + result = list(items_iter) + self.assertEqual(result, expected_result) + + def _parse_ws_raw(self, content: bytes) -> Iterator[dict]: + raw_messages = content.split(self.WS_PROTOCOL_MESSAGE_SEPARATOR) + for x in raw_messages: + if not x: + continue + if x[-1:] != b'}': + continue + idx = x.find(b'{') + if idx == -1: + continue + json_raw = x[idx:] + yield json.loads(json_raw) diff --git a/tests/websocket/test_websocket.py b/tests/websocket/test_websocket.py index 1583bba09..3b1df54cd 100644 --- a/tests/websocket/test_websocket.py +++ b/tests/websocket/test_websocket.py @@ -22,7 +22,7 @@ def setUp(self): self.network = 'testnet' self.manager = self.create_peer(self.network, wallet_index=True) - self.factory = HathorAdminWebsocketFactory(self.manager.metrics) + self.factory = HathorAdminWebsocketFactory(self.manager, self.manager.metrics) self.factory.subscribe(self.manager.pubsub) self.factory._setup_rate_limit() self.factory.openHandshakeTimeout = 0