From 8303e17f5c7cb4822ddddec4c0dd77ea05330ab8 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 17 Nov 2022 10:07:27 +0100 Subject: [PATCH] Observe changes to internal YDoc to broadcast and store --- ypy_websocket/__init__.py | 1 - ypy_websocket/websocket_provider.py | 5 ++- ypy_websocket/websocket_server.py | 51 ++++++++++++++-------- ypy_websocket/ydoc.py | 66 ----------------------------- ypy_websocket/yutils.py | 52 ++++++++--------------- 5 files changed, 54 insertions(+), 121 deletions(-) delete mode 100644 ypy_websocket/ydoc.py diff --git a/ypy_websocket/__init__.py b/ypy_websocket/__init__.py index 6f66dc0..e217214 100644 --- a/ypy_websocket/__init__.py +++ b/ypy_websocket/__init__.py @@ -1,6 +1,5 @@ from .websocket_provider import WebsocketProvider # noqa from .websocket_server import WebsocketServer, YRoom # noqa -from .ydoc import YDoc # noqa from .yutils import YMessageType # noqa __version__ = "0.4.0" diff --git a/ypy_websocket/websocket_provider.py b/ypy_websocket/websocket_provider.py index 9002554..32c1e42 100644 --- a/ypy_websocket/websocket_provider.py +++ b/ypy_websocket/websocket_provider.py @@ -4,7 +4,7 @@ import y_py as Y -from .yutils import process_message, put_updates, sync +from .yutils import YMessageType, process_sync_message, put_updates, sync class WebsocketProvider: @@ -24,7 +24,8 @@ async def _run(self): await sync(self._ydoc, self._websocket) send_task = asyncio.create_task(self._send()) async for message in self._websocket: - await process_message(message, self._ydoc, self._websocket, self.log) + if message[0] == YMessageType.SYNC: + await process_sync_message(message[1:], self._ydoc, self._websocket, self.log) send_task.cancel() async def _send(self): diff --git a/ypy_websocket/websocket_server.py b/ypy_websocket/websocket_server.py index 5d905d6..153e047 100644 --- a/ypy_websocket/websocket_server.py +++ b/ypy_websocket/websocket_server.py @@ -1,27 +1,28 @@ import asyncio import logging +from functools import partial from typing import Callable, Dict, List, Optional +import y_py as Y + from .awareness import Awareness -from .ydoc import YDoc from .ystore import BaseYStore -from .yutils import sync, update +from .yutils import YMessageType, create_update_message, process_sync_message, put_updates, sync class YRoom: clients: List - ydoc: YDoc + ydoc: Y.YDoc ystore: Optional[BaseYStore] _on_message: Optional[Callable] _update_queue: asyncio.Queue _ready: bool def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None): - self._update_queue = asyncio.Queue() - self.ydoc = YDoc() - self.ydoc.init(self._update_queue) # FIXME: overriding Y.YDoc.__init__ doesn't seem to work + self.ydoc = Y.YDoc() self.awareness = Awareness(self.ydoc) + self._update_queue = asyncio.Queue() self._ready = False self.ready = ready self.ystore = ystore @@ -38,7 +39,7 @@ def ready(self) -> bool: def ready(self, value: bool) -> None: self._ready = value if value: - self.ydoc.ready = True + self.ydoc.observe_after_transaction(partial(put_updates, self._update_queue, self.ydoc)) @property def on_message(self) -> Optional[Callable]: @@ -51,12 +52,17 @@ def on_message(self, value: Optional[Callable]): async def _broadcast_updates(self): while True: update = await self._update_queue.get() - # broadcast internal ydoc's update made from the backend to all clients + # broadcast internal ydoc's update to all clients, that includes changes from the + # clients and changes from the backend (out-of-band changes) for client in self.clients: self.log.debug( - "Sending Y update from backend to client with endpoint: %s", client.path + "Sending Y update to client with endpoint: %s", client.path ) - asyncio.create_task(client.send(update)) + message = create_update_message(update) + asyncio.create_task(client.send(message)) + if self.ystore: + self.log.debug("Writing Y update to YStore") + asyncio.create_task(self.ystore.write(update)) def _clean(self): self._broadcast_task.cancel() @@ -110,16 +116,27 @@ async def serve(self, websocket): skip = await room.on_message(message) if skip: continue - # update our internal state and the YStore (if any) - asyncio.create_task(update(message, room, websocket, self.log)) - # forward messages from this client to every other client in the background - for client in [c for c in room.clients if c != websocket]: + message_type = message[0] + if message_type == YMessageType.SYNC: + # update our internal state in the background + # changes to the internal state are then forwarded to all clients + # and stored in the YStore (if any) + asyncio.create_task(process_sync_message(message[1:], room.ydoc, websocket, self.log)) + elif message_type == YMessageType.AWARENESS: + # forward awareness messages from this client to all clients, + # including itself, because it's used to keep the connection alive self.log.debug( - "Sending Y update from client with endpoint %s to client with endpoint: %s", + "Received %s message from endpoint: %s", + YMessageType.AWARENESS.raw_str(), websocket.path, - client.path, ) - asyncio.create_task(client.send(message)) + for client in room.clients: + self.log.debug( + "Sending Y awareness from client with endpoint %s to client with endpoint: %s", + websocket.path, + client.path, + ) + asyncio.create_task(client.send(message)) # remove this client room.clients = [c for c in room.clients if c != websocket] if self.auto_clean_rooms and not room.clients: diff --git a/ypy_websocket/ydoc.py b/ypy_websocket/ydoc.py deleted file mode 100644 index f30b9d9..0000000 --- a/ypy_websocket/ydoc.py +++ /dev/null @@ -1,66 +0,0 @@ -import asyncio -from types import TracebackType -from typing import Optional, Type - -import y_py as Y - -from .yutils import create_update_message - - -class YDoc(Y.YDoc): - """A YDoc with a custom transaction context manager that allows to put updates into a queue. - `Y.YDoc.observe_after_transaction` catches all updates, while updates made through this YDoc - can be processed separately, which is useful for e.g. updates made only from the backend. - """ - - _begin_transaction = Y.YDoc.begin_transaction - _update_queue: asyncio.Queue - _ready: bool - - def init(self, update_queue: asyncio.Queue): - self._ready = False - self._update_queue = update_queue - - def begin_transaction(self): - return Transaction(self, self._update_queue, self._ready) - - @property - def ready(self) -> bool: - return self._ready - - @ready.setter - def ready(self, value: bool) -> None: - self._ready = value - - -class Transaction: - - ydoc: YDoc - update_queue: asyncio.Queue - state: bytes - transaction: Y.YTransaction - ready: bool - - def __init__(self, ydoc: YDoc, update_queue: asyncio.Queue, ready: bool): - self.ydoc = ydoc - self.update_queue = update_queue - self.ready = ready - - def __enter__(self): - if self.ready: - self.state = Y.encode_state_vector(self.ydoc) - self.transaction = self.ydoc._begin_transaction() - return self.transaction.__enter__() - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> bool: - res = self.transaction.__exit__(exc_type, exc_value, exc_tb) # type: ignore - if self.ready: - update = Y.encode_state_as_update(self.ydoc, self.state) - message = create_update_message(update) - self.update_queue.put_nowait(message) - return res diff --git a/ypy_websocket/yutils.py b/ypy_websocket/yutils.py index ddc2895..c4249fa 100644 --- a/ypy_websocket/yutils.py +++ b/ypy_websocket/yutils.py @@ -101,54 +101,36 @@ def read_var_string(self): def put_updates(update_queue: asyncio.Queue, ydoc: Y.YDoc, event: Y.AfterTransactionEvent) -> None: - message = create_update_message(event.get_update()) - update_queue.put_nowait(message) + update_queue.put_nowait(event.get_update()) -async def process_message(message: bytes, ydoc: Y.YDoc, websocket, log) -> Optional[bytes]: +async def process_sync_message(message: bytes, ydoc: Y.YDoc, websocket, log) -> None: message_type = message[0] + msg = message[1:] log.debug( "Received %s message from endpoint: %s", - YMessageType(message_type).raw_str(), + YSyncMessageType(message_type).raw_str(), websocket.path, ) - if message_type == YMessageType.SYNC: - message_type = message[1] - msg = message[2:] + if message_type == YSyncMessageType.SYNC_STEP1: + state = read_message(msg) + update = Y.encode_state_as_update(ydoc, state) + reply = create_sync_step2_message(update) log.debug( - "Received %s message from endpoint: %s", - YSyncMessageType(message_type).raw_str(), + "Sending %s message to endpoint: %s", + YSyncMessageType.SYNC_STEP2.raw_str(), websocket.path, ) - if message_type == YSyncMessageType.SYNC_STEP1: - state = read_message(msg) - update = Y.encode_state_as_update(ydoc, state) - reply = create_sync_step2_message(update) - log.debug( - "Sending %s message to endpoint: %s", - YSyncMessageType.SYNC_STEP2.raw_str(), - websocket.path, - ) - await websocket.send(reply) - elif message_type in ( - YSyncMessageType.SYNC_STEP2, - YSyncMessageType.SYNC_UPDATE, - ): - update = read_message(msg) - Y.apply_update(ydoc, update) - return update - - return None + await websocket.send(reply) + elif message_type in ( + YSyncMessageType.SYNC_STEP2, + YSyncMessageType.SYNC_UPDATE, + ): + update = read_message(msg) + Y.apply_update(ydoc, update) async def sync(ydoc: Y.YDoc, websocket): state = Y.encode_state_vector(ydoc) msg = create_sync_step1_message(state) await websocket.send(msg) - - -async def update(message, room, websocket, log): - yupdate = await process_message(message, room.ydoc, websocket, log) - if room.ystore and yupdate: - log.debug("Writing Y update to YStore from endpoint: %s", websocket.path) - await room.ystore.write(yupdate)