diff --git a/hathor/builder.py b/hathor/builder.py index 822654249..31ff70bc9 100644 --- a/hathor/builder.py +++ b/hathor/builder.py @@ -26,6 +26,7 @@ from twisted.web import server from twisted.web.resource import Resource +from hathor.event import EventManager from hathor.exception import BuilderError from hathor.indexes import IndexesManager from hathor.manager import HathorManager @@ -56,6 +57,7 @@ def create_manager(self, reactor: PosixReactorBase, args: Namespace) -> HathorMa from hathor.conf.get_settings import get_settings_module from hathor.daa import TestMode, _set_test_mode from hathor.event.storage import EventMemoryStorage, EventRocksDBStorage, EventStorage + from hathor.event.websocket.factory import EventWebsocketFactory from hathor.p2p.netfilter.utils import add_peer_id_blacklist from hathor.p2p.peer_discovery import BootstrapPeerDiscovery, DNSPeerDiscovery from hathor.storage import RocksDBStorage @@ -90,13 +92,15 @@ def create_manager(self, reactor: PosixReactorBase, args: Namespace) -> HathorMa tx_storage: TransactionStorage rocksdb_storage: RocksDBStorage - event_storage: Optional[EventStorage] = None + self.event_storage: Optional[EventStorage] = None + self.event_ws_factory: Optional[EventWebsocketFactory] = None + if args.memory_storage: self.check_or_raise(not args.data, '--data should not be used with --memory-storage') # if using MemoryStorage, no need to have cache tx_storage = TransactionMemoryStorage() if args.x_enable_event_queue: - event_storage = EventMemoryStorage() + self.event_storage = EventMemoryStorage() self.check_or_raise(not args.x_rocksdb_indexes, 'RocksDB indexes require RocksDB data') self.log.info('with storage', storage_class=type(tx_storage).__name__) else: @@ -109,7 +113,7 @@ def create_manager(self, reactor: PosixReactorBase, args: Namespace) -> HathorMa with_index=(not args.cache), use_memory_indexes=args.memory_indexes) if args.x_enable_event_queue: - event_storage = EventRocksDBStorage(rocksdb_storage) + self.event_storage = EventRocksDBStorage(rocksdb_storage) self.log.info('with storage', storage_class=type(tx_storage).__name__, path=args.data) if args.cache: @@ -135,6 +139,17 @@ def create_manager(self, reactor: PosixReactorBase, args: Namespace) -> HathorMa pubsub = PubSubManager(reactor) + event_manager: Optional[EventManager] = None + if args.x_enable_event_queue: + assert self.event_storage is not None, 'cannot create EventManager without EventStorage' + self.event_ws_factory = EventWebsocketFactory(reactor, self.event_storage) + event_manager = EventManager( + event_storage=self.event_storage, + event_ws_factory=self.event_ws_factory, + pubsub=pubsub, + reactor=reactor + ) + if args.wallet_index and tx_storage.indexes is not None: self.log.debug('enable wallet indexes') self.enable_wallet_index(tx_storage.indexes, pubsub) @@ -150,7 +165,7 @@ def create_manager(self, reactor: PosixReactorBase, args: Namespace) -> HathorMa network=network, hostname=hostname, tx_storage=tx_storage, - event_storage=event_storage, + event_manager=event_manager, wallet=self.wallet, stratum_port=args.stratum, ssl=True, @@ -481,6 +496,10 @@ def create_resources(self, args: Namespace) -> server.Site: ws_factory.subscribe(self.manager.pubsub) + # Event websocket resource + if args.x_enable_event_queue and self.event_ws_factory is not None: + root.putChild(b'event_ws', WebSocketResource(self.event_ws_factory)) + # Websocket stats resource root.putChild(b'websocket_stats', WebsocketStatsResource(ws_factory)) diff --git a/hathor/conf/unittests.py b/hathor/conf/unittests.py index 88e0853c7..d2c43c4bd 100644 --- a/hathor/conf/unittests.py +++ b/hathor/conf/unittests.py @@ -32,4 +32,5 @@ GENESIS_TX2_HASH=bytes.fromhex('33e14cb555a96967841dcbe0f95e9eab5810481d01de8f4f73afb8cce365e869'), REWARD_SPEND_MIN_BLOCKS=10, SLOW_ASSERTS=True, + ENABLE_EVENT_QUEUE_FEATURE=True, ) diff --git a/hathor/event/base_event.py b/hathor/event/base_event.py index 64923f52a..8c7be101c 100644 --- a/hathor/event/base_event.py +++ b/hathor/event/base_event.py @@ -12,24 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Dict, Optional +from pydantic import NonNegativeInt -@dataclass -class BaseEvent: +from hathor.utils.pydantic import BaseModel + + +class BaseEvent(BaseModel): # Full node id, because different full nodes can have different sequences of events peer_id: str # Event unique id, determines event order - id: int + id: NonNegativeInt # Timestamp in which the event was emitted, this follows the unix_timestamp format, it's only informative, events # aren't guaranteed to always have sequential timestamps, for example, if the system clock changes between two # events it's possible that timestamps will temporarily decrease. timestamp: float # One of the event types - type: str + type: str # TODO: Convert type and data to enum and classes # Variable for event type data: Dict # Used to link events, for example, many TX_METADATA_CHANGED will have the same group_id when they belong to the # same reorg process - group_id: Optional[int] = None + group_id: Optional[NonNegativeInt] = None diff --git a/hathor/event/event_manager.py b/hathor/event/event_manager.py index 386e96013..0db787623 100644 --- a/hathor/event/event_manager.py +++ b/hathor/event/event_manager.py @@ -18,6 +18,7 @@ from hathor.event.base_event import BaseEvent from hathor.event.storage import EventStorage +from hathor.event.websocket import EventWebsocketFactory from hathor.pubsub import EventArguments, HathorEvents, PubSubManager from hathor.util import Reactor @@ -85,61 +86,133 @@ def _extract_reorg(args: EventArguments) -> Dict[str, Any]: } -def build_event_data(event: HathorEvents, event_args: EventArguments) -> Dict[str, Any]: +def _build_event_data(event_type: HathorEvents, event_args: EventArguments) -> Dict[str, Any]: """Extract and build event data from event_args for a given event type.""" - event_extract_fn = _EVENT_EXTRACT_MAP.get(event) + event_extract_fn = _EVENT_EXTRACT_MAP.get(event_type) if event_extract_fn is None: - raise ValueError(f'The given event type ({event}) is not a supported event') + raise ValueError(f'The given event type ({event_type}) is not a supported event') return event_extract_fn(event_args) class EventManager: - def __init__(self, event_storage: EventStorage, reactor: Reactor, peer_id: str): + """Class that manages integration events. + + Events are received from PubSub, persisted on the storage and sent to WebSocket clients. + """ + + _peer_id: str + + def __init__( + self, + event_storage: EventStorage, + event_ws_factory: EventWebsocketFactory, + pubsub: PubSubManager, + reactor: Reactor + ): self.log = logger.new() - self.clock = reactor - self.event_storage = event_storage - last_event = event_storage.get_last_event() - last_event_type = HathorEvents(last_event.type) if last_event is not None else None + + self._clock = reactor + self._event_storage = event_storage + self._event_ws_factory = event_ws_factory + self._pubsub = pubsub + + self._last_event = self._event_storage.get_last_event() + self._last_existing_group_id = self._event_storage.get_last_group_id() + + self._assert_closed_event_group() + self._subscribe_events() + + def start(self, peer_id: str) -> None: + self._peer_id = peer_id + self._event_ws_factory.start() + + def stop(self): + self._event_ws_factory.stop() + + def _assert_closed_event_group(self): # XXX: we must check that the last event either does not belong to an event group or that it just closed an # event group, because we cannot resume an open group of events that wasn't properly closed before exit assert ( - last_event is None or - last_event.group_id is None or - last_event_type in _GROUP_END_EVENTS + self._event_group_is_closed() ), 'an unclosed event group was detected, which indicates the node crashed, cannot resume' - self._next_event_id = 0 if last_event is None else last_event.id + 1 - last_group_id = event_storage.get_last_group_id() - self._next_group_id = 0 if last_group_id is None else last_group_id + 1 - self._current_group_id: Optional[int] = None - self._peer_id = peer_id - def subscribe(self, pubsub: PubSubManager) -> None: + def _event_group_is_closed(self): + return ( + self._last_event is None or + self._last_event.group_id is None or + HathorEvents(self._last_event.type) in _GROUP_END_EVENTS + ) + + def _subscribe_events(self): """ Subscribe to defined events for the pubsub received """ for event in _SUBSCRIBE_EVENTS: - pubsub.subscribe(event, self._persist_event) + self._pubsub.subscribe(event, self._handle_event) + + def _handle_event(self, event_type: HathorEvents, event_args: EventArguments) -> None: + create_event_fn: Callable[[HathorEvents, EventArguments], BaseEvent] - def _persist_event(self, event: HathorEvents, event_args: EventArguments) -> None: - group_id: Optional[int] - if event in _GROUP_START_EVENTS: - assert self._current_group_id is None, 'cannot start an event group before the last one is ended' - group_id = self._next_group_id + if event_type in _GROUP_START_EVENTS: + create_event_fn = self._create_group_start_event + elif event_type in _GROUP_END_EVENTS: + create_event_fn = self._create_group_end_event else: - group_id = self._current_group_id - if event in _GROUP_END_EVENTS: - assert self._current_group_id is not None, 'cannot end group twice' - event_to_store = BaseEvent( - id=self._next_event_id, + create_event_fn = self._create_non_group_edge_event + + event = create_event_fn(event_type, event_args) + + self._event_storage.save_event(event) + self._event_ws_factory.broadcast_event(event) + + self._last_event = event + + def _create_group_start_event(self, event_type: HathorEvents, event_args: EventArguments) -> BaseEvent: + assert self._event_group_is_closed(), 'A new event group cannot be started as one is already in progress.' + + new_group_id = 0 if self._last_existing_group_id is None else self._last_existing_group_id + 1 + + self._last_existing_group_id = new_group_id + + return self._create_event( + event_type=event_type, + event_args=event_args, + group_id=new_group_id, + ) + + def _create_group_end_event(self, event_type: HathorEvents, event_args: EventArguments) -> BaseEvent: + assert self._last_event is not None, 'Cannot end event group if there are no events.' + assert not self._event_group_is_closed(), 'Cannot end event group as none is in progress.' + + return self._create_event( + event_type=event_type, + event_args=event_args, + group_id=self._last_event.group_id, + ) + + def _create_non_group_edge_event(self, event_type: HathorEvents, event_args: EventArguments) -> BaseEvent: + group_id = None + + if not self._event_group_is_closed(): + assert self._last_event is not None, 'Cannot continue event group if there are no events.' + group_id = self._last_event.group_id + + return self._create_event( + event_type=event_type, + event_args=event_args, + group_id=group_id, + ) + + def _create_event( + self, + event_type: HathorEvents, + event_args: EventArguments, + group_id: Optional[int], + ) -> BaseEvent: + return BaseEvent( + id=0 if self._last_event is None else self._last_event.id + 1, peer_id=self._peer_id, - timestamp=self.clock.seconds(), - type=event.value, - data=build_event_data(event, event_args), + timestamp=self._clock.seconds(), + type=event_type.value, + data=_build_event_data(event_type, event_args), group_id=group_id, ) - self.event_storage.save_event(event_to_store) - self._next_event_id += 1 - if event in _GROUP_START_EVENTS: - self._current_group_id = self._next_group_id - self._next_group_id += 1 - if event in _GROUP_END_EVENTS: - self._current_group_id = None diff --git a/hathor/event/storage/event_storage.py b/hathor/event/storage/event_storage.py index a16ffc59f..0c16a4b33 100644 --- a/hathor/event/storage/event_storage.py +++ b/hathor/event/storage/event_storage.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Optional +from typing import Iterator, Optional from hathor.event.base_event import BaseEvent @@ -38,3 +38,8 @@ def get_last_event(self) -> Optional[BaseEvent]: def get_last_group_id(self) -> Optional[int]: """ Get the last group-id that was emitted, this is used to help resume when restarting.""" raise NotImplementedError + + @abstractmethod + def iter_from_event(self, key: int) -> Iterator[BaseEvent]: + """ Iterate through events starting from the event with the given key""" + raise NotImplementedError diff --git a/hathor/event/storage/memory_storage.py b/hathor/event/storage/memory_storage.py index c61bb0471..4472de8bb 100644 --- a/hathor/event/storage/memory_storage.py +++ b/hathor/event/storage/memory_storage.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Iterator, List, Optional from hathor.event.base_event import BaseEvent from hathor.event.storage.event_storage import EventStorage @@ -25,8 +25,6 @@ def __init__(self): self._last_group_id: Optional[int] = None def save_event(self, event: BaseEvent) -> None: - if event.id < 0: - raise ValueError('event.id must be non-negative') if event.id != len(self._events): raise ValueError('invalid event.id, ids must be sequential and leave no gaps') self._last_event = event @@ -36,7 +34,7 @@ def save_event(self, event: BaseEvent) -> None: def get_event(self, key: int) -> Optional[BaseEvent]: if key < 0: - raise ValueError('key must be non-negative') + raise ValueError(f'event.id \'{key}\' must be non-negative') if key >= len(self._events): return None event = self._events[key] @@ -48,3 +46,11 @@ def get_last_event(self) -> Optional[BaseEvent]: def get_last_group_id(self) -> Optional[int]: return self._last_group_id + + def iter_from_event(self, key: int) -> Iterator[BaseEvent]: + if key < 0: + raise ValueError(f'event.id \'{key}\' must be non-negative') + + while key < len(self._events): + yield self._events[key] + key += 1 diff --git a/hathor/event/storage/rocksdb_storage.py b/hathor/event/storage/rocksdb_storage.py index 83c8b48b3..30848d373 100644 --- a/hathor/event/storage/rocksdb_storage.py +++ b/hathor/event/storage/rocksdb_storage.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Iterator, Optional from hathor.event.base_event import BaseEvent from hathor.event.storage.event_storage import EventStorage from hathor.storage.rocksdb_storage import RocksDBStorage from hathor.transaction.util import int_to_bytes -from hathor.util import json_dumpb, json_loadb +from hathor.util import json_dumpb _CF_NAME_EVENT = b'event' _CF_NAME_META = b'event-metadata' @@ -33,16 +33,15 @@ def __init__(self, rocksdb_storage: RocksDBStorage): self._last_event: Optional[BaseEvent] = self._db_get_last_event() self._last_group_id: Optional[int] = self._db_get_last_group_id() - def _load_from_bytes(self, event_data: bytes) -> BaseEvent: - event_dict = json_loadb(event_data) - return BaseEvent( - id=event_dict['id'], - peer_id=event_dict['peer_id'], - timestamp=event_dict['timestamp'], - type=event_dict['type'], - group_id=event_dict['group_id'], - data=event_dict['data'], - ) + def iter_from_event(self, key: int) -> Iterator[BaseEvent]: + if key < 0: + raise ValueError(f'event.id \'{key}\' must be non-negative') + + it = self._db.itervalues(self._cf_event) + it.seek(int_to_bytes(key, 8)) + + for event_bytes in it: + yield BaseEvent.parse_raw(event_bytes) def _db_get_last_event(self) -> Optional[BaseEvent]: last_element: Optional[bytes] = None @@ -52,7 +51,7 @@ def _db_get_last_event(self) -> Optional[BaseEvent]: for i in it: last_element = i break - return None if last_element is None else self._load_from_bytes(last_element) + return None if last_element is None else BaseEvent.parse_raw(last_element) def _db_get_last_group_id(self) -> Optional[int]: last_group_id = self._db.get((self._cf_meta, _KEY_LAST_GROUP_ID)) @@ -61,12 +60,10 @@ def _db_get_last_group_id(self) -> Optional[int]: return int.from_bytes(last_group_id, byteorder='big', signed=False) def save_event(self, event: BaseEvent) -> None: - if event.id < 0: - raise ValueError('event.id must be non-negative') if (self._last_event is None and event.id != 0) or \ (self._last_event is not None and event.id > self._last_event.id + 1): raise ValueError('invalid event.id, ids must be sequential and leave no gaps') - event_data = json_dumpb(event.__dict__) + event_data = json_dumpb(event.dict()) key = int_to_bytes(event.id, 8) self._db.put((self._cf_event, key), event_data) self._last_event = event @@ -76,11 +73,11 @@ def save_event(self, event: BaseEvent) -> None: def get_event(self, key: int) -> Optional[BaseEvent]: if key < 0: - raise ValueError('key must be non-negative') + raise ValueError(f'event.id \'{key}\' must be non-negative') event = self._db.get((self._cf_event, int_to_bytes(key, 8))) if event is None: return None - return self._load_from_bytes(event_data=event) + return BaseEvent.parse_raw(event) def get_last_event(self) -> Optional[BaseEvent]: return self._last_event diff --git a/hathor/event/websocket/__init__.py b/hathor/event/websocket/__init__.py new file mode 100644 index 000000000..b831cff42 --- /dev/null +++ b/hathor/event/websocket/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2022 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hathor.event.websocket.factory import EventWebsocketFactory +from hathor.event.websocket.protocol import EventWebsocketProtocol + +__all__ = ['EventWebsocketFactory', 'EventWebsocketProtocol'] diff --git a/hathor/event/websocket/factory.py b/hathor/event/websocket/factory.py new file mode 100644 index 000000000..697302393 --- /dev/null +++ b/hathor/event/websocket/factory.py @@ -0,0 +1,99 @@ +# Copyright 2023 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Set + +from autobahn.twisted.websocket import WebSocketServerFactory +from structlog import get_logger + +from hathor.event import BaseEvent +from hathor.event.storage import EventStorage +from hathor.event.websocket.protocol import EventWebsocketProtocol +from hathor.event.websocket.response import EventResponse, InvalidRequestType +from hathor.util import Reactor + +logger = get_logger() + + +class EventWebsocketFactory(WebSocketServerFactory): + """ Websocket that will handle events + """ + + protocol = EventWebsocketProtocol + _is_running = False + _latest_event_id: Optional[int] = None + + def __init__(self, reactor: Reactor, event_storage: EventStorage): + super().__init__() + self.log = logger.new() + self._reactor = reactor + self._event_storage = event_storage + self._connections: Set[EventWebsocketProtocol] = set() + + latest_event = self._event_storage.get_last_event() + + if latest_event is not None: + self._latest_event_id = latest_event.id + + def start(self): + """Start the WebSocket server. Required to be able to send events.""" + self._is_running = True + + def stop(self): + """Stop the WebSocket server. No events can be sent.""" + self._is_running = False + + for connection in self._connections: + connection.sendClose() + + self._connections.clear() + + def broadcast_event(self, event: BaseEvent) -> None: + """Broadcast the event to each registered client.""" + self._latest_event_id = event.id + + for connection in self._connections: + self._send_event_to_connection(connection, event) + + def register(self, connection: EventWebsocketProtocol) -> None: + """Registers a client. Called when a ws connection is opened (after handshaking).""" + if not self._is_running: + return connection.send_invalid_request_response(InvalidRequestType.EVENT_WS_NOT_RUNNING) + + self.log.info('registering connection', client_peer=connection.client_peer) + + self._connections.add(connection) + + def unregister(self, connection: EventWebsocketProtocol) -> None: + """Unregisters a client. Called when a ws connection is closed.""" + self.log.info('unregistering connection', client_peer=connection.client_peer) + self._connections.discard(connection) + + def send_next_event_to_connection(self, connection: EventWebsocketProtocol) -> None: + next_event_id = connection.next_expected_event_id() + + if not connection.can_receive_event(next_event_id): + return + + if event := self._event_storage.get_event(next_event_id): + self._send_event_to_connection(connection, event) + self._reactor.callLater(0, self.send_next_event_to_connection, connection) + + def _send_event_to_connection(self, connection: EventWebsocketProtocol, event: BaseEvent) -> None: + if not connection.can_receive_event(event.id): + return + + response = EventResponse(event=event, latest_event_id=self._latest_event_id) + + connection.send_event_response(response) diff --git a/hathor/event/websocket/protocol.py b/hathor/event/websocket/protocol.py new file mode 100644 index 000000000..26c70891b --- /dev/null +++ b/hathor/event/websocket/protocol.py @@ -0,0 +1,186 @@ +# Copyright 2022 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, Dict, Optional, Type + +from autobahn.exception import Disconnected +from autobahn.twisted.websocket import WebSocketServerProtocol +from autobahn.websocket import ConnectionRequest +from pydantic import ValidationError +from structlog import get_logger + +from hathor.event.websocket.request import AckRequest, Request, RequestWrapper, StartStreamRequest, StopStreamRequest +from hathor.event.websocket.response import EventResponse, InvalidRequestResponse, InvalidRequestType, Response +from hathor.util import json_dumpb + +if TYPE_CHECKING: + from hathor.event.websocket import EventWebsocketFactory + +logger = get_logger() + + +class EventWebsocketProtocol(WebSocketServerProtocol): + """ Websocket protocol, basically forwards some events to the Websocket factory. + """ + + factory: 'EventWebsocketFactory' + client_peer: Optional[str] = None + + _last_sent_event_id: Optional[int] = None + _ack_event_id: Optional[int] = None + _window_size: int = 0 + _stream_is_active: bool = False + + def __init__(self): + super().__init__() + self.log = logger.new() + + def can_receive_event(self, event_id: int) -> bool: + """Returns whether this client is available to receive an event.""" + number_of_pending_events = 0 + + if self._last_sent_event_id is not None: + ack_offset = -1 if self._ack_event_id is None else self._ack_event_id + number_of_pending_events = self._last_sent_event_id - ack_offset + + return ( + self._stream_is_active + and event_id == self.next_expected_event_id() + and number_of_pending_events < self._window_size + ) + + def next_expected_event_id(self) -> int: + """Returns the ID of the next event the client expects.""" + return 0 if self._last_sent_event_id is None else self._last_sent_event_id + 1 + + def onConnect(self, request: ConnectionRequest) -> None: + self.client_peer = request.peer + self.log = self.log.new(client_peer=self.client_peer) + self.log.info('connection opened to the event websocket, starting handshake...') + + def onOpen(self) -> None: + self.log.info('connection established to the event websocket') + self.factory.register(self) + + def onClose(self, wasClean: bool, code: int, reason: str) -> None: + self.log.info('connection closed to the event websocket', reason=reason) + self.factory.unregister(self) + + def onMessage(self, payload: bytes, isBinary: bool) -> None: + self.log.debug('message', payload=payload.hex() if isBinary else payload.decode('utf8')) + + try: + request = RequestWrapper.parse_raw_request(payload) + self._handle_request(request) + except ValidationError as error: + self.send_invalid_request_response(InvalidRequestType.VALIDATION_ERROR, payload, str(error)) + except InvalidRequestError as error: + self.send_invalid_request_response(error.type, payload) + + def _handle_request(self, request: Request) -> None: + # This could be a pattern match in Python 3.10 + request_type = type(request) + handlers: Dict[Type, Callable] = { + StartStreamRequest: self._handle_start_stream_request, + AckRequest: self._handle_ack_request, + StopStreamRequest: lambda _: self._handle_stop_stream_request() + } + handle_fn = handlers.get(request_type) + + assert handle_fn is not None, f'cannot handle request of unknown type "{request_type}"' + + handle_fn(request) + + def _handle_start_stream_request(self, request: StartStreamRequest) -> None: + if self._stream_is_active: + raise InvalidRequestError(InvalidRequestType.STREAM_IS_ACTIVE) + + self._validate_ack(request.last_ack_event_id) + + self._last_sent_event_id = request.last_ack_event_id + self._ack_event_id = request.last_ack_event_id + self._window_size = request.window_size + self._stream_is_active = True + + self.factory.send_next_event_to_connection(self) + + def _handle_ack_request(self, request: AckRequest) -> None: + if not self._stream_is_active: + raise InvalidRequestError(InvalidRequestType.STREAM_IS_INACTIVE) + + self._validate_ack(request.ack_event_id) + + self._ack_event_id = request.ack_event_id + self._window_size = request.window_size + + self.factory.send_next_event_to_connection(self) + + def _handle_stop_stream_request(self) -> None: + if not self._stream_is_active: + raise InvalidRequestError(InvalidRequestType.STREAM_IS_INACTIVE) + + self._stream_is_active = False + + def _validate_ack(self, ack_event_id: Optional[int]) -> None: + """Validates an ack_event_id from a request. + + The ack_event_id can't be smaller than the last ack we've received + and can't be larger than the last event we've sent. + """ + if self._ack_event_id is not None and ( + ack_event_id is None or ack_event_id < self._ack_event_id + ): + raise InvalidRequestError(InvalidRequestType.ACK_TOO_SMALL) + + if ack_event_id is not None and ( + self._last_sent_event_id is None or self._last_sent_event_id < ack_event_id + ): + raise InvalidRequestError(InvalidRequestType.ACK_TOO_LARGE) + + def send_event_response(self, event_response: EventResponse) -> None: + self._send_response(event_response) + self._last_sent_event_id = event_response.event.id + + def send_invalid_request_response( + self, + _type: InvalidRequestType, + invalid_payload: Optional[bytes] = None, + error_message: Optional[str] = None + ) -> None: + invalid_request = None if invalid_payload is None else invalid_payload.decode('utf8') + response = InvalidRequestResponse( + type=_type, + invalid_request=invalid_request, + error_message=error_message + ) + + self._send_response(response) + + def _send_response(self, response: Response) -> None: + payload = json_dumpb(response.dict()) + + try: + self.sendMessage(payload) + except Disconnected: + # Connection is closed. Nothing to do. + pass + # XXX: unfortunately autobahn can raise 3 different exceptions and one of them is a bare Exception + # https://github.com/crossbario/autobahn-python/blob/v20.12.3/autobahn/websocket/protocol.py#L2201-L2294 + except Exception: + self.log.error('send failed, moving on', exc_info=True) + + +class InvalidRequestError(Exception): + def __init__(self, _type: InvalidRequestType): + self.type = _type diff --git a/hathor/event/websocket/request.py b/hathor/event/websocket/request.py new file mode 100644 index 000000000..c4c5efd04 --- /dev/null +++ b/hathor/event/websocket/request.py @@ -0,0 +1,68 @@ +# Copyright 2023 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal, Optional, Union + +from pydantic import NonNegativeInt + +from hathor.utils.pydantic import BaseModel + + +class StartStreamRequest(BaseModel): + """Class that represents a client request to start streaming events. + + Args: + type: The type of the request. + last_ack_event_id: The ID of the last event acknowledged by the client. + window_size: The amount of events the client is able to process. + """ + type: Literal['START_STREAM'] + last_ack_event_id: Optional[NonNegativeInt] + window_size: NonNegativeInt + + +class AckRequest(BaseModel): + """Class that represents a client request to ack and event and change the window size. + + Args: + type: The type of the request. + ack_event_id: The ID of the last event acknowledged by the client. + window_size: The amount of events the client is able to process. + """ + type: Literal['ACK'] + ack_event_id: NonNegativeInt + window_size: NonNegativeInt + + +class StopStreamRequest(BaseModel): + """Class that represents a client request to stop streaming events. + + Args: + type: The type of the request. + """ + type: Literal['STOP_STREAM'] + + +# This could be more performatic in Python 3.9: +# Request = Annotated[StartStreamRequest | AckRequest | StopStreamRequest, Field(discriminator='type')] +Request = Union[StartStreamRequest, AckRequest, StopStreamRequest] + + +class RequestWrapper(BaseModel): + """Class that wraps the Request union type for parsing.""" + __root__: Request + + @classmethod + def parse_raw_request(cls, raw: bytes) -> Request: + return cls.parse_raw(raw).__root__ diff --git a/hathor/event/websocket/response.py b/hathor/event/websocket/response.py new file mode 100644 index 000000000..bb23e0e5c --- /dev/null +++ b/hathor/event/websocket/response.py @@ -0,0 +1,61 @@ +# Copyright 2023 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum +from typing import Optional + +from pydantic import Field, NonNegativeInt + +from hathor.event import BaseEvent +from hathor.utils.pydantic import BaseModel + + +class Response(BaseModel): + pass + + +class EventResponse(Response): + """Class that represents an event to be sent to the client. + + Args: + type: The type of the response. + event: The event. + latest_event_id: The ID of the latest event known by the server. + """ + + type: str = Field(default='EVENT', const=True) + event: BaseEvent + latest_event_id: NonNegativeInt + + +class InvalidRequestType(Enum): + EVENT_WS_NOT_RUNNING = 'EVENT_WS_NOT_RUNNING' + STREAM_IS_ACTIVE = 'STREAM_IS_ACTIVE' + STREAM_IS_INACTIVE = 'STREAM_IS_INACTIVE' + VALIDATION_ERROR = 'VALIDATION_ERROR' + ACK_TOO_SMALL = 'ACK_TOO_SMALL' + ACK_TOO_LARGE = 'ACK_TOO_LARGE' + + +class InvalidRequestResponse(Response, use_enum_values=True): + """Class to let the client know that it performed an invalid request. + + Args: + type: The type of the response. + invalid_request: The request that was invalid. + error_message: A message describing why the request was invalid. + """ + + type: InvalidRequestType + invalid_request: Optional[str] + error_message: Optional[str] diff --git a/hathor/manager.py b/hathor/manager.py index baf9199b9..f18600a41 100644 --- a/hathor/manager.py +++ b/hathor/manager.py @@ -30,7 +30,6 @@ from hathor.conf import HathorSettings from hathor.consensus import ConsensusAlgorithm from hathor.event.event_manager import EventManager -from hathor.event.storage import EventStorage from hathor.exception import ( DoubleSpendingError, HathorError, @@ -84,7 +83,7 @@ class UnhealthinessReason(str, Enum): def __init__(self, reactor: Reactor, *, pubsub: PubSubManager, peer_id: Optional[PeerId] = None, network: Optional[str] = None, hostname: Optional[str] = None, wallet: Optional[BaseWallet] = None, tx_storage: Optional[TransactionStorage] = None, - event_storage: Optional[EventStorage] = None, + event_manager: Optional[EventManager] = None, stratum_port: Optional[int] = None, ssl: bool = True, enable_sync_v1: bool = True, enable_sync_v2: bool = False, capabilities: Optional[List[str]] = None, checkpoints: Optional[List[Checkpoint]] = None, @@ -158,10 +157,9 @@ def __init__(self, reactor: Reactor, *, pubsub: PubSubManager, peer_id: Optional self.pubsub = pubsub self.tx_storage = tx_storage self.tx_storage.pubsub = self.pubsub - self.event_manager: Optional[EventManager] = None - if event_storage is not None: - self.event_manager = EventManager(event_storage, self.reactor, not_none(self.my_peer.id)) - self.event_manager.subscribe(self.pubsub) + + self._event_manager = event_manager + if enable_sync_v2: assert self.tx_storage.indexes is not None self.log.debug('enable sync-v2 indexes') @@ -311,6 +309,9 @@ def start(self) -> None: if self.stratum_factory: self.stratum_factory.start() + if self._event_manager: + self._event_manager.start(not_none(self.my_peer.id)) + # Start running self.tx_storage.start_running_manager() @@ -342,6 +343,9 @@ def stop(self) -> Deferred: if wait_stratum: waits.append(wait_stratum) + if self._event_manager: + self._event_manager.stop() + self.tx_storage.flush() return defer.DeferredList(waits) diff --git a/tests/event/test_base_event.py b/tests/event/test_base_event.py new file mode 100644 index 000000000..26e24a58e --- /dev/null +++ b/tests/event/test_base_event.py @@ -0,0 +1,67 @@ +# Copyright 2023 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from pydantic import ValidationError + +from hathor.event import BaseEvent + + +@pytest.mark.parametrize('event_id', [0, 1, 1000]) +@pytest.mark.parametrize('group_id', [None, 0, 1, 1000]) +def test_create_base_event(event_id, group_id): + event = BaseEvent( + peer_id='some_peer', + id=event_id, + timestamp=123.3, + type='some_type', + data=dict(some_data='some_value'), + group_id=group_id + ) + + expected = dict( + peer_id='some_peer', + id=event_id, + timestamp=123.3, + type='some_type', + data=dict(some_data='some_value'), + group_id=group_id + ) + + assert event.dict() == expected + + +@pytest.mark.parametrize('event_id', [None, -1, -1000]) +def test_create_base_event_fail_id(event_id): + with pytest.raises(ValidationError): + BaseEvent( + peer_id='some_peer', + id=event_id, + timestamp=123.3, + type='some_type', + data=dict(some_data='some_value') + ) + + +@pytest.mark.parametrize('group_id', [-1, -1000]) +def test_create_base_event_fail_group_id(group_id): + with pytest.raises(ValidationError): + BaseEvent( + peer_id='some_peer', + id=0, + timestamp=123.3, + type='some_type', + data=dict(some_data='some_value'), + group_id=group_id + ) diff --git a/tests/event/test_event_manager.py b/tests/event/test_event_manager.py index a90e9e6b8..2e38bf913 100644 --- a/tests/event/test_event_manager.py +++ b/tests/event/test_event_manager.py @@ -1,5 +1,9 @@ +from unittest.mock import Mock + +from hathor.event import EventManager from hathor.event.storage.memory_storage import EventMemoryStorage -from hathor.pubsub import HathorEvents +from hathor.event.websocket import EventWebsocketFactory +from hathor.pubsub import HathorEvents, PubSubManager from tests import unittest @@ -9,9 +13,20 @@ class BaseEventManagerTest(unittest.TestCase): def setUp(self): super().setUp() self.event_storage = EventMemoryStorage() + self.event_ws_factory = Mock(spec_set=EventWebsocketFactory) self.network = 'testnet' - self.manager = self.create_peer(self.network, event_storage=self.event_storage) - self.event_manager = self.manager.event_manager + pubsub = PubSubManager(self.clock) + self.event_manager = EventManager( + event_storage=self.event_storage, + event_ws_factory=self.event_ws_factory, + pubsub=pubsub, + reactor=self.clock + ) + self.manager = self.create_peer( + self.network, + event_manager=self.event_manager, + pubsub=pubsub + ) def test_if_event_is_persisted(self): block = self.manager.tx_storage.get_best_block() diff --git a/tests/event/test_event_reorg.py b/tests/event/test_event_reorg.py index 53dcb02ed..6f0b52a97 100644 --- a/tests/event/test_event_reorg.py +++ b/tests/event/test_event_reorg.py @@ -1,4 +1,10 @@ +from unittest.mock import Mock + from hathor.conf import HathorSettings +from hathor.event import EventManager +from hathor.event.storage import EventMemoryStorage +from hathor.event.websocket import EventWebsocketFactory +from hathor.pubsub import PubSubManager from tests import unittest from tests.utils import add_new_blocks, get_genesis_key @@ -11,9 +17,20 @@ class BaseEventReorgTest(unittest.TestCase): def setUp(self): super().setUp() self.network = 'testnet' - self.manager = self.create_peer(self.network, event_storage=True) - self.event_manager = self.manager.event_manager - self.event_storage = self.event_manager.event_storage + self.event_ws_factory = Mock(spec_set=EventWebsocketFactory) + self.event_storage = EventMemoryStorage() + pubsub = PubSubManager(self.clock) + self.event_manager = EventManager( + event_storage=self.event_storage, + event_ws_factory=self.event_ws_factory, + pubsub=pubsub, + reactor=self.clock + ) + self.manager = self.create_peer( + self.network, + event_manager=self.event_manager, + pubsub=pubsub + ) # read genesis keys self.genesis_private_key = get_genesis_key() diff --git a/tests/event/test_event_storage.py b/tests/event/test_event_storage.py index aa99b9276..f2441f3a5 100644 --- a/tests/event/test_event_storage.py +++ b/tests/event/test_event_storage.py @@ -23,10 +23,15 @@ def test_save_event_and_retrieve(self): assert event_retrieved == event - def test_get_key_nonpositive(self): - with self.assertRaises(ValueError): + def test_get_negative_key(self): + with self.assertRaises(ValueError) as cm: self.event_storage.get_event(-1) + self.assertEqual( + 'event.id \'-1\' must be non-negative', + str(cm.exception) + ) + def test_get_nonexistent_event(self): assert self.event_storage.get_event(0) is None assert self.event_storage.get_event(9999) is None @@ -41,15 +46,61 @@ def test_save_events_and_retrieve_the_last(self): assert event_retrieved.id == last_event.id def test_save_non_sequential(self): - last_event = None for i in range(10): - last_event = self.event_mocker.generate_mocked_event(i) - self.event_storage.save_event(last_event) + event = self.event_mocker.generate_mocked_event(i) + self.event_storage.save_event(event) + + non_sequential_event = self.event_mocker.generate_mocked_event(100) - non_sequential_event = self.event_mocker.generate_mocked_event(11) - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: self.event_storage.save_event(non_sequential_event) + self.assertEqual( + 'invalid event.id, ids must be sequential and leave no gaps', + str(cm.exception) + ) + + def test_iter_from_event_empty(self): + self._test_iter_from_event(0) + + def test_iter_from_event_single(self): + self._test_iter_from_event(1) + + def test_iter_from_event_multiple(self): + self._test_iter_from_event(20) + + def _test_iter_from_event(self, n_events): + expected_events = [] + for i in range(n_events): + event = self.event_mocker.generate_mocked_event(i) + expected_events.append(event) + self.event_storage.save_event(event) + + actual_events = list(self.event_storage.iter_from_event(0)) + + self.assertEqual(expected_events, actual_events) + + def test_iter_from_event_negative_key(self): + with self.assertRaises(ValueError) as cm: + events = self.event_storage.iter_from_event(-10) + list(events) + + self.assertEqual( + 'event.id \'-10\' must be non-negative', + str(cm.exception) + ) + + def test_save_events_and_retrieve_last_group_id(self): + expected_group_id = 4 + for i in range(10): + group_id = i if i <= expected_group_id else None + event = self.event_mocker.generate_mocked_event(i, group_id) + self.event_storage.save_event(event) + + actual_group_id = self.event_storage.get_last_group_id() + + assert expected_group_id == actual_group_id + @pytest.mark.skipif(not HAS_ROCKSDB, reason='requires python-rocksdb') class EventStorageRocksDBTest(EventStorageBaseTest): diff --git a/tests/event/websocket/__init__.py b/tests/event/websocket/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/event/websocket/test_factory.py b/tests/event/websocket/test_factory.py new file mode 100644 index 000000000..2d788a93d --- /dev/null +++ b/tests/event/websocket/test_factory.py @@ -0,0 +1,157 @@ +# Copyright 2023 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock, call + +import pytest + +from hathor.event import BaseEvent +from hathor.event.storage import EventMemoryStorage +from hathor.event.websocket.factory import EventWebsocketFactory +from hathor.event.websocket.protocol import EventWebsocketProtocol +from hathor.event.websocket.response import EventResponse, InvalidRequestType +from hathor.simulator.clock import HeapClock + + +def test_started_register(): + factory = _get_factory() + connection = Mock(spec_set=EventWebsocketProtocol) + connection.send_invalid_request_response = Mock() + + factory.start() + factory.register(connection) + + connection.send_invalid_request_response.assert_not_called() + + +def test_non_started_register(): + factory = _get_factory() + connection = Mock(spec_set=EventWebsocketProtocol) + connection.send_invalid_request_response = Mock() + + factory.register(connection) + + connection.send_invalid_request_response.assert_called_once_with(InvalidRequestType.EVENT_WS_NOT_RUNNING) + + +def test_stopped_register(): + factory = _get_factory() + connection = Mock(spec_set=EventWebsocketProtocol) + connection.send_invalid_request_response = Mock() + + factory.start() + factory.stop() + factory.register(connection) + + connection.send_invalid_request_response.assert_called_once_with(InvalidRequestType.EVENT_WS_NOT_RUNNING) + + +@pytest.mark.parametrize('can_receive_event', [False, True]) +def test_broadcast_event(can_receive_event: bool) -> None: + n_starting_events = 10 + factory = _get_factory(n_starting_events) + event = _create_event(n_starting_events - 1) + connection = Mock(spec_set=EventWebsocketProtocol) + connection.can_receive_event = Mock(return_value=can_receive_event) + connection.send_event_response = Mock() + + factory.start() + factory.register(connection) + factory.broadcast_event(event) + + if not can_receive_event: + return connection.send_event_response.assert_not_called() + + response = EventResponse(event=event, latest_event_id=n_starting_events - 1) + connection.send_event_response.assert_called_once_with(response) + + +def test_broadcast_multiple_events_multiple_connections(): + factory = _get_factory(10) + connection1 = Mock(spec_set=EventWebsocketProtocol) + connection1.can_receive_event = Mock(return_value=True) + connection1.send_event_response = Mock() + connection2 = Mock(spec_set=EventWebsocketProtocol) + connection2.can_receive_event = Mock(return_value=True) + connection2.send_event_response = Mock() + + factory.start() + factory.register(connection1) + factory.register(connection2) + + for event_id in range(10): + event = _create_event(event_id) + factory.broadcast_event(event) + + assert connection1.send_event_response.call_count == 10 + assert connection2.send_event_response.call_count == 10 + + +@pytest.mark.parametrize( + ['next_expected_event_id', 'can_receive_event'], + [ + (0, False), + (0, True), + (3, True), + (10, True) + ] +) +def test_send_next_event_to_connection(next_expected_event_id: int, can_receive_event: bool) -> None: + n_starting_events = 10 + clock = HeapClock() + factory = _get_factory(n_starting_events, clock) + connection = Mock(spec_set=EventWebsocketProtocol) + connection.send_event_response = Mock() + connection.can_receive_event = Mock(return_value=can_receive_event) + connection.next_expected_event_id = Mock( + side_effect=lambda: next_expected_event_id + connection.send_event_response.call_count + ) + + factory.start() + factory.register(connection) + factory.send_next_event_to_connection(connection) + + clock.advance(0) + + if not can_receive_event or next_expected_event_id > n_starting_events - 1: + return connection.send_event_response.assert_not_called() + + calls = [] + for _id in range(next_expected_event_id, n_starting_events): + event = _create_event(_id) + response = EventResponse(event=event, latest_event_id=n_starting_events - 1) + calls.append(call(response)) + + assert connection.send_event_response.call_count == n_starting_events - next_expected_event_id + connection.send_event_response.assert_has_calls(calls) + + +def _get_factory(n_starting_events: int = 0, clock: HeapClock = HeapClock()) -> EventWebsocketFactory: + event_storage = EventMemoryStorage() + + for event_id in range(n_starting_events): + event = _create_event(event_id) + event_storage.save_event(event) + + return EventWebsocketFactory(clock, event_storage) + + +def _create_event(event_id: int) -> BaseEvent: + return BaseEvent( + peer_id='123', + id=event_id, + timestamp=123456, + type='type', + data={} + ) diff --git a/tests/event/websocket/test_protocol.py b/tests/event/websocket/test_protocol.py new file mode 100644 index 000000000..9012e5980 --- /dev/null +++ b/tests/event/websocket/test_protocol.py @@ -0,0 +1,342 @@ +# Copyright 2023 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from unittest.mock import ANY, Mock + +import pytest +from autobahn.websocket import ConnectionRequest + +from hathor.event import BaseEvent +from hathor.event.websocket import EventWebsocketFactory +from hathor.event.websocket.protocol import EventWebsocketProtocol +from hathor.event.websocket.response import EventResponse, InvalidRequestType + + +@pytest.fixture +def factory(): + return Mock(spec_set=EventWebsocketFactory) + + +def test_init(): + protocol = EventWebsocketProtocol() + + assert protocol.client_peer is None + assert protocol._last_sent_event_id is None + assert protocol._ack_event_id is None + assert protocol._window_size == 0 + assert not protocol._stream_is_active + + +def test_next_expected_event_id(): + protocol = EventWebsocketProtocol() + + assert protocol.next_expected_event_id() == 0 + + protocol._last_sent_event_id = 5 + + assert protocol.next_expected_event_id() == 6 + + +def test_on_connect(): + protocol = EventWebsocketProtocol() + request = Mock(spec_set=ConnectionRequest) + request.peer = 'some_peer' + + protocol.onConnect(request) + + assert protocol.client_peer == 'some_peer' + + +def test_on_open(factory): + protocol = EventWebsocketProtocol() + protocol.factory = factory + + protocol.onOpen() + + factory.register.assert_called_once_with(protocol) + + +def test_on_close(factory): + protocol = EventWebsocketProtocol() + protocol.factory = factory + + protocol.onClose(True, 1, 'reason') + + factory.unregister.assert_called_once_with(protocol) + + +def test_send_event_response(): + protocol = EventWebsocketProtocol() + protocol.sendMessage = Mock() + response = EventResponse( + event=BaseEvent( + peer_id='some_peer_id', + id=10, + timestamp=123, + type='some_type', + data={} + ), + latest_event_id=10 + ) + + protocol.send_event_response(response) + + expected_payload = b'{"type":"EVENT","event":{"peer_id":"some_peer_id","id":10,"timestamp":123.0,' \ + b'"type":"some_type","data":{},"group_id":null},"latest_event_id":10}' + + protocol.sendMessage.assert_called_once_with(expected_payload) + + +@pytest.mark.parametrize('_type', [InvalidRequestType.VALIDATION_ERROR, InvalidRequestType.STREAM_IS_INACTIVE]) +@pytest.mark.parametrize('invalid_payload', [None, b'some_payload']) +@pytest.mark.parametrize('error_message', [None, 'some error']) +def test_send_invalid_request_response(_type, invalid_payload, error_message): + protocol = EventWebsocketProtocol() + protocol.sendMessage = Mock() + + protocol.send_invalid_request_response(_type, invalid_payload, error_message) + + invalid_request = "null" if invalid_payload is None else f'"{invalid_payload.decode("utf8")}"' + error_message = "null" if error_message is None else f'"{error_message}"' + expected_payload = f'{{"type":"{_type.value}","invalid_request":{invalid_request},' \ + f'"error_message":{error_message}}}' + + protocol.sendMessage.assert_called_once_with(expected_payload.encode('utf8')) + + +@pytest.mark.parametrize( + [ + 'last_sent_event_id', + 'ack_event_id', + 'window_size', + 'stream_is_active', + 'event_id', + 'expected_result', + ], + [ + (None, None, 0, False, 0, False), + (None, None, 0, True, 0, False), + (None, None, 1, True, 0, True), + (0, None, 1, False, 1, False), + (0, None, 1, True, 1, False), + (0, 0, 1, False, 1, False), + (0, 0, 1, True, 1, True), + (1, 0, 1, True, 2, False), + (1, 0, 2, False, 2, False), + (1, 0, 2, True, 2, True), + (2, 2, 3, True, 3, True), + (3, 2, 3, True, 4, True), + (4, 2, 3, False, 5, False), + (4, 2, 3, True, 5, True), + (4, 2, 3, True, 4, False), + (5, 2, 3, True, 6, False), + ] +) +def test_can_receive_event( + last_sent_event_id: Optional[int], + ack_event_id: Optional[int], + window_size: int, + stream_is_active: bool, + event_id: int, + expected_result: bool +) -> None: + protocol = EventWebsocketProtocol() + protocol._last_sent_event_id = last_sent_event_id + protocol._ack_event_id = ack_event_id + protocol._window_size = window_size + protocol._stream_is_active = stream_is_active + + result = protocol.can_receive_event(event_id) + + assert result == expected_result + + +def test_on_valid_stop_message(): + protocol = EventWebsocketProtocol() + protocol._stream_is_active = True + + protocol.onMessage(b'{"type": "STOP_STREAM"}', False) + + assert not protocol._stream_is_active + + +def test_stop_message_on_inactive(): + protocol = EventWebsocketProtocol() + protocol.sendMessage = Mock() + protocol._stream_is_active = False + payload = b'{"type": "STOP_STREAM"}' + + protocol.onMessage(payload, False) + + response = b'{"type":"STREAM_IS_INACTIVE","invalid_request":"{\\"type\\": \\"STOP_STREAM\\"}",' \ + b'"error_message":null}' + protocol.sendMessage.assert_called_once_with(response) + assert not protocol._stream_is_active + + +@pytest.mark.parametrize( + ['ack_event_id', 'window_size', 'last_sent_event_id'], + [ + (0, 0, 0), + (0, 1, 10), + (0, 10, 1), + (1, 0, 1000), + (10, 0, 10), + ] +) +def test_on_valid_ack_message(ack_event_id, window_size, last_sent_event_id): + protocol = EventWebsocketProtocol() + protocol._last_sent_event_id = last_sent_event_id + protocol.factory = Mock() + protocol.factory.send_next_event_to_connection = Mock() + protocol._stream_is_active = True + payload = f'{{"type": "ACK", "ack_event_id": {ack_event_id}, "window_size": {window_size}}}'.encode('utf8') + + protocol.onMessage(payload, False) + + assert protocol._ack_event_id == ack_event_id + assert protocol._window_size == window_size + protocol.factory.send_next_event_to_connection.assert_called_once() + + +@pytest.mark.parametrize( + ['ack_event_id', 'window_size', 'last_sent_event_id'], + [ + (0, 0, 0), + (0, 1, 10), + (0, 10, 1), + (1, 0, 1000), + (10, 0, 10), + ] +) +def test_on_valid_start_message(ack_event_id, window_size, last_sent_event_id): + protocol = EventWebsocketProtocol() + protocol._last_sent_event_id = last_sent_event_id + protocol.factory = Mock() + protocol.factory.send_next_event_to_connection = Mock() + payload = f'{{"type": "START_STREAM", "last_ack_event_id": {ack_event_id}, "window_size": {window_size}}}' + + protocol.onMessage(payload.encode('utf8'), False) + + assert protocol._ack_event_id == ack_event_id + assert protocol._window_size == window_size + assert protocol._last_sent_event_id == ack_event_id + assert protocol._stream_is_active + protocol.factory.send_next_event_to_connection.assert_called_once() + + +def test_ack_message_on_inactive(): + protocol = EventWebsocketProtocol() + protocol.sendMessage = Mock() + protocol._stream_is_active = False + payload = b'{"type": "ACK", "ack_event_id": 10, "window_size": 10}' + + protocol.onMessage(payload, False) + + response = b'{"type":"STREAM_IS_INACTIVE","invalid_request":"{\\"type\\": \\"ACK\\", \\"ack_event_id\\": 10, ' \ + b'\\"window_size\\": 10}","error_message":null}' + protocol.sendMessage.assert_called_once_with(response) + + +def test_start_message_on_active(): + protocol = EventWebsocketProtocol() + protocol.sendMessage = Mock() + protocol._stream_is_active = True + payload = b'{"type": "START_STREAM", "last_ack_event_id": 10, "window_size": 10}' + + protocol.onMessage(payload, False) + + response = b'{"type":"STREAM_IS_ACTIVE","invalid_request":"{\\"type\\": \\"START_STREAM\\", ' \ + b'\\"last_ack_event_id\\": 10, \\"window_size\\": 10}","error_message":null}' + protocol.sendMessage.assert_called_once_with(response) + + +@pytest.mark.parametrize( + ['_ack_event_id', 'last_sent_event_id', 'ack_event_id', '_type'], + [ + (1, None, 0, InvalidRequestType.ACK_TOO_SMALL), + (1, 1, 0, InvalidRequestType.ACK_TOO_SMALL), + (10, None, 5, InvalidRequestType.ACK_TOO_SMALL), + (10, 1, 5, InvalidRequestType.ACK_TOO_SMALL), + (0, None, 1, InvalidRequestType.ACK_TOO_LARGE), + (0, 0, 1, InvalidRequestType.ACK_TOO_LARGE), + (5, None, 10, InvalidRequestType.ACK_TOO_LARGE), + (5, 1, 10, InvalidRequestType.ACK_TOO_LARGE), + ] +) +def test_on_invalid_ack_message(_ack_event_id, last_sent_event_id, ack_event_id, _type): + protocol = EventWebsocketProtocol() + protocol._ack_event_id = _ack_event_id + protocol._last_sent_event_id = last_sent_event_id + protocol.send_invalid_request_response = Mock() + protocol._stream_is_active = True + payload = f'{{"type": "ACK", "ack_event_id": {ack_event_id}, "window_size": 0}}'.encode('utf8') + + protocol.onMessage(payload, False) + + protocol.send_invalid_request_response.assert_called_once_with(_type, payload) + + +@pytest.mark.parametrize( + ['_ack_event_id', 'last_sent_event_id', 'ack_event_id', '_type'], + [ + (0, None, None, InvalidRequestType.ACK_TOO_SMALL), + (0, 1, None, InvalidRequestType.ACK_TOO_SMALL), + (1, None, 0, InvalidRequestType.ACK_TOO_SMALL), + (1, 1, 0, InvalidRequestType.ACK_TOO_SMALL), + (10, None, 5, InvalidRequestType.ACK_TOO_SMALL), + (10, 1, 5, InvalidRequestType.ACK_TOO_SMALL), + (None, None, 0, InvalidRequestType.ACK_TOO_LARGE), + (1, 0, 1, InvalidRequestType.ACK_TOO_LARGE), + (0, None, 1, InvalidRequestType.ACK_TOO_LARGE), + (0, 0, 1, InvalidRequestType.ACK_TOO_LARGE), + (5, None, 10, InvalidRequestType.ACK_TOO_LARGE), + (5, 1, 10, InvalidRequestType.ACK_TOO_LARGE), + ] +) +def test_on_invalid_start_message(_ack_event_id, last_sent_event_id, ack_event_id, _type): + protocol = EventWebsocketProtocol() + protocol._ack_event_id = _ack_event_id + protocol._last_sent_event_id = last_sent_event_id + protocol.send_invalid_request_response = Mock() + ack_event_id = 'null' if ack_event_id is None else ack_event_id + payload = f'{{"type": "START_STREAM", "last_ack_event_id": {ack_event_id}, "window_size": 0}}'.encode('utf8') + + protocol.onMessage(payload, False) + + protocol.send_invalid_request_response.assert_called_once_with(_type, payload) + + +@pytest.mark.parametrize( + 'payload', + [ + b'{"type": "FAKE_TYPE"}', + b'{"type": "STOP_STREAM", "fake_prop": 123}', + b'{"type": "START_STREAM", "last_ack_event_id": "wrong value", "window_size": 10}', + b'{"type": "START_STREAM", "last_ack_event_id": 0, "window_size": -10}', + b'{"type": "START_STREAM", "last_ack_event_id": -10, "window_size": 0}', + b'{"type": "ACK", "ack_event_id": 0, "window_size": "wrong value"}', + b'{"type": "ACK", "ack_event_id": 0, "window_size": -10}', + b'{"type": "ACK", "ack_event_id": -10, "window_size": 0}', + ] +) +def test_validation_error_on_message(payload): + protocol = EventWebsocketProtocol() + protocol.send_invalid_request_response = Mock() + protocol._stream_is_active = False + + protocol.onMessage(payload, False) + + protocol.send_invalid_request_response.assert_called_once_with(InvalidRequestType.VALIDATION_ERROR, payload, ANY) diff --git a/tests/others/test_builder.py b/tests/others/test_builder.py index a1ab2a1b5..3f0b12fcd 100644 --- a/tests/others/test_builder.py +++ b/tests/others/test_builder.py @@ -3,6 +3,9 @@ import pytest from hathor.builder import CliBuilder +from hathor.event import EventManager +from hathor.event.storage import EventMemoryStorage, EventRocksDBStorage +from hathor.event.websocket import EventWebsocketFactory from hathor.exception import BuilderError from hathor.indexes import MemoryIndexesManager, RocksDBIndexesManager from hathor.manager import HathorManager @@ -52,6 +55,7 @@ def test_all_default(self): self.assertNotIn(SyncVersion.V2, manager.connections._sync_factories) self.assertFalse(self.builder._build_prometheus) self.assertFalse(self.builder._build_status) + self.assertIsNone(manager._event_manager) @pytest.mark.skipif(not HAS_ROCKSDB, reason='requires python-rocksdb') def test_cache_storage(self): @@ -139,3 +143,19 @@ def test_memory_and_rocksdb_indexes(self): data_dir = self.mkdtemp() args = ['--memory-indexes', '--x-rocksdb-indexes', '--data', data_dir] self._build_with_error(args, 'You cannot use --memory-indexes and --x-rocksdb-indexes.') + + @pytest.mark.skipif(not HAS_ROCKSDB, reason='requires python-rocksdb') + def test_event_queue_with_rocksdb_storage(self): + data_dir = self.mkdtemp() + manager = self._build(['--x-enable-event-queue', '--rocksdb-storage', '--data', data_dir]) + + self.assertIsInstance(manager._event_manager, EventManager) + self.assertIsInstance(manager._event_manager._event_storage, EventRocksDBStorage) + self.assertIsInstance(manager._event_manager._event_ws_factory, EventWebsocketFactory) + + def test_event_queue_with_memory_storage(self): + manager = self._build(['--x-enable-event-queue', '--memory-storage']) + + self.assertIsInstance(manager._event_manager, EventManager) + self.assertIsInstance(manager._event_manager._event_storage, EventMemoryStorage) + self.assertIsInstance(manager._event_manager._event_ws_factory, EventWebsocketFactory) diff --git a/tests/unittest.py b/tests/unittest.py index ae095370a..f0efcafc3 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -130,7 +130,8 @@ def _create_test_wallet(self): def create_peer(self, network, peer_id=None, wallet=None, tx_storage=None, unlock_wallet=True, wallet_index=False, capabilities=None, full_verification=True, enable_sync_v1=None, enable_sync_v2=None, - checkpoints=None, utxo_index=False, event_storage=None, use_memory_index=None, start_manager=True): + checkpoints=None, utxo_index=False, event_manager=None, use_memory_index=None, start_manager=True, + pubsub=None): if enable_sync_v1 is None: assert hasattr(self, '_enable_sync_v1'), ('`_enable_sync_v1` has no default by design, either set one on ' 'the test class or pass `enable_sync_v1` by argument') @@ -159,7 +160,7 @@ def create_peer(self, network, peer_id=None, wallet=None, tx_storage=None, unloc self._pending_cleanups.append(rocksdb_storage.close) tx_storage = TransactionRocksDBStorage(rocksdb_storage, use_memory_indexes=use_memory_index) - pubsub = PubSubManager(self.clock) + pubsub = pubsub or PubSubManager(self.clock) builder = CliBuilder() if wallet_index: @@ -168,17 +169,6 @@ def create_peer(self, network, peer_id=None, wallet=None, tx_storage=None, unloc if utxo_index: tx_storage.indexes.enable_utxo_index() - if event_storage is True: - # XXX: either bool or Optional[EventStorage] is accepted for event_storage - if self.use_memory_storage: - from hathor.event.storage import EventMemoryStorage - event_storage = EventMemoryStorage() - else: - from hathor.event.storage import EventRocksDBStorage - event_storage = EventRocksDBStorage(rocksdb_storage) - elif event_storage is False: - event_storage = None - manager = HathorManager( self.clock, pubsub=pubsub, @@ -186,7 +176,7 @@ def create_peer(self, network, peer_id=None, wallet=None, tx_storage=None, unloc network=network, wallet=wallet, tx_storage=tx_storage, - event_storage=event_storage, + event_manager=event_manager, capabilities=capabilities, rng=self.rng, enable_sync_v1=enable_sync_v1, diff --git a/tests/utils.py b/tests/utils.py index 75e05c52f..4ef46cfb1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -659,18 +659,18 @@ def gen_next_id(self) -> int: self.next_id += 1 return next_id - def generate_mocked_event(self, id: Optional[int] = None) -> BaseEvent: - """ Generates a mocked event with a best block found message + def generate_mocked_event(self, event_id: Optional[int] = None, group_id: Optional[int] = None) -> BaseEvent: + """ Generates a mocked event with the best block found message """ - hash = hashlib.sha256(self.generate_random_word(10).encode('utf-8')) - peer_id_mock = hash.hexdigest() + _hash = hashlib.sha256(self.generate_random_word(10).encode('utf-8')) + peer_id_mock = _hash.hexdigest() return BaseEvent( - id=id or self.gen_next_id(), + id=event_id or self.gen_next_id(), peer_id=peer_id_mock, timestamp=1658892990, type='network:best_block_found', - group_id=0, + group_id=group_id, data={ "data": "test" },