From aa18801b8156088f821f535bf48a2f02aaff74b5 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 17 Nov 2022 11:39:16 +0100 Subject: [PATCH] Add metadata column to SQLiteYStore, fix update broadcast (#42) * Add metadata column to SQLiteYStore, fix update broadcast * Review * Fix async iterator * Observe changes to internal YDoc to broadcast and store * Fix WebsocketProvider --- ypy_websocket/awareness.py | 3 +- ypy_websocket/websocket_provider.py | 14 +++-- ypy_websocket/websocket_server.py | 66 ++++++++++++++++------- ypy_websocket/ystore.py | 83 ++++++++++++++++------------- ypy_websocket/yutils.py | 52 ++++++------------ 5 files changed, 122 insertions(+), 96 deletions(-) diff --git a/ypy_websocket/awareness.py b/ypy_websocket/awareness.py index 244f880..c23c18b 100644 --- a/ypy_websocket/awareness.py +++ b/ypy_websocket/awareness.py @@ -38,7 +38,8 @@ def get_changes(self, message: bytes) -> Dict[str, Any]: if client_id == self.client_id and self.states.get(client_id) is not None: clock += 1 else: - del self.states[client_id] + if client_id in self.states: + del self.states[client_id] else: self.states[client_id] = state self.meta[client_id] = { diff --git a/ypy_websocket/websocket_provider.py b/ypy_websocket/websocket_provider.py index 9002554..73b89da 100644 --- a/ypy_websocket/websocket_provider.py +++ b/ypy_websocket/websocket_provider.py @@ -4,7 +4,13 @@ import y_py as Y -from .yutils import process_message, put_updates, sync +from .yutils import ( + YMessageType, + create_update_message, + process_sync_message, + put_updates, + sync, +) class WebsocketProvider: @@ -24,13 +30,15 @@ async def _run(self): await sync(self._ydoc, self._websocket) send_task = asyncio.create_task(self._send()) async for message in self._websocket: - await process_message(message, self._ydoc, self._websocket, self.log) + if message[0] == YMessageType.SYNC: + await process_sync_message(message[1:], self._ydoc, self._websocket, self.log) send_task.cancel() async def 
_send(self): while True: update = await self._update_queue.get() + message = create_update_message(update) try: - await self._websocket.send(update) + await self._websocket.send(message) except Exception: pass diff --git a/ypy_websocket/websocket_server.py b/ypy_websocket/websocket_server.py index 6b3d0fd..4a383d5 100644 --- a/ypy_websocket/websocket_server.py +++ b/ypy_websocket/websocket_server.py @@ -7,7 +7,13 @@ from .awareness import Awareness from .ystore import BaseYStore -from .yutils import put_updates, sync, update +from .yutils import ( + YMessageType, + create_update_message, + process_sync_message, + put_updates, + sync, +) class YRoom: @@ -17,13 +23,16 @@ class YRoom: ystore: Optional[BaseYStore] _on_message: Optional[Callable] _update_queue: asyncio.Queue + _ready: bool - def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None): + def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None): self.ydoc = Y.YDoc() self.awareness = Awareness(self.ydoc) self._update_queue = asyncio.Queue() + self._ready = False self.ready = ready self.ystore = ystore + self.log = log or logging.getLogger(__name__) self.clients = [] self._on_message = None self._broadcast_task = asyncio.create_task(self._broadcast_updates()) @@ -47,17 +56,17 @@ def on_message(self, value: Optional[Callable]): self._on_message = value async def _broadcast_updates(self): - try: - while True: - update = await self._update_queue.get() - # broadcast internal ydoc's update to all clients - for client in self.clients: - try: - await client.send(update) - except Exception: - pass - except Exception: - pass + while True: + update = await self._update_queue.get() + # broadcast internal ydoc's update to all clients, that includes changes from the + # clients and changes from the backend (out-of-band changes) + for client in self.clients: + self.log.debug("Sending Y update to client with endpoint: %s", client.path) + message = create_update_message(update) 
+ asyncio.create_task(client.send(message)) + if self.ystore: + self.log.debug("Writing Y update to YStore") + asyncio.create_task(self.ystore.write(update)) def _clean(self): self._broadcast_task.cancel() @@ -76,7 +85,7 @@ def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log= def get_room(self, path: str) -> YRoom: if path not in self.rooms.keys(): - self.rooms[path] = YRoom(ready=self.rooms_ready) + self.rooms[path] = YRoom(ready=self.rooms_ready, log=self.log) return self.rooms[path] def get_room_name(self, room): @@ -111,12 +120,29 @@ async def serve(self, websocket): skip = await room.on_message(message) if skip: continue - # update our internal state and the YStore (if any) - asyncio.create_task(update(message, room, websocket, self.log)) - # forward messages to every other client in the background - for client in [c for c in room.clients if c != websocket]: - self.log.debug("Sending Y update to client with endpoint: %s", client.path) - asyncio.create_task(client.send(message)) + message_type = message[0] + if message_type == YMessageType.SYNC: + # update our internal state in the background + # changes to the internal state are then forwarded to all clients + # and stored in the YStore (if any) + asyncio.create_task( + process_sync_message(message[1:], room.ydoc, websocket, self.log) + ) + elif message_type == YMessageType.AWARENESS: + # forward awareness messages from this client to all clients, + # including itself, because it's used to keep the connection alive + self.log.debug( + "Received %s message from endpoint: %s", + YMessageType.AWARENESS.raw_str(), + websocket.path, + ) + for client in room.clients: + self.log.debug( + "Sending Y awareness from client with endpoint %s to client with endpoint: %s", + websocket.path, + client.path, + ) + asyncio.create_task(client.send(message)) # remove this client room.clients = [c for c in room.clients if c != websocket] if self.auto_clean_rooms and not room.clients: diff --git 
a/ypy_websocket/ystore.py b/ypy_websocket/ystore.py index 2b68fc8..75c1a15 100644 --- a/ypy_websocket/ystore.py +++ b/ypy_websocket/ystore.py @@ -1,5 +1,6 @@ import asyncio import tempfile +from abc import ABC, abstractmethod from pathlib import Path from typing import AsyncIterator, Callable, Optional, Tuple @@ -14,19 +15,21 @@ class YDocNotFound(Exception): pass -class BaseYStore: +class BaseYStore(ABC): metadata_callback: Optional[Callable] = None + @abstractmethod def __init__(self, path: str, metadata_callback=None): - raise RuntimeError("Not implemented") + ... + @abstractmethod async def write(self, data: bytes) -> None: - raise RuntimeError("Not implemented") + ... + @abstractmethod async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: - raise RuntimeError("Not implemented") - yield b"", b"" + ... async def get_metadata(self) -> bytes: metadata = b"" if not self.metadata_callback else await self.metadata_callback() @@ -37,30 +40,35 @@ async def encode_state_as_update(self, ydoc: Y.YDoc): await self.write(update) async def apply_updates(self, ydoc: Y.YDoc): - async for update, metadata in self.read(): + async for update, metadata in self.read(): # type: ignore Y.apply_update(ydoc, update) # type: ignore class FileYStore(BaseYStore): - """A YStore which uses the local file system.""" + """A YStore which uses one file per document.""" path: str + metadata_callback: Optional[Callable] + lock: asyncio.Lock - def __init__(self, path: str, metadata_callback=None): + def __init__(self, path: str, metadata_callback: Optional[Callable] = None): self.path = path self.metadata_callback = metadata_callback - - async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: - try: - async with aiofiles.open(self.path, "rb") as f: - data = await f.read() - except Exception: - raise YDocNotFound + self.lock = asyncio.Lock() + + async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore + async with self.lock: + try: + async with aiofiles.open(self.path, 
"rb") as f: + data = await f.read() + except BaseException: + raise YDocNotFound is_data = True for d in Decoder(data).read_messages(): if is_data: update = d else: + # yield data and metadata yield update, d is_data = not is_data @@ -71,16 +79,18 @@ async def write(self, data: bytes) -> None: mode = "wb" else: mode = "ab" - async with aiofiles.open(self.path, mode) as f: - data_len = write_var_uint(len(data)) - await f.write(data_len + data) - metadata = await self.get_metadata() - metadata_len = write_var_uint(len(metadata)) - await f.write(metadata_len + metadata) + async with self.lock: + async with aiofiles.open(self.path, mode) as f: + data_len = write_var_uint(len(data)) + await f.write(data_len + data) + metadata = await self.get_metadata() + metadata_len = write_var_uint(len(metadata)) + await f.write(metadata_len + metadata) class TempFileYStore(FileYStore): """A YStore which uses the system's temporary directory. + Files are written under a common directory. To prefix the directory name (e.g. /tmp/my_prefix_b4whmm7y/): class PrefixTempFileYStore(TempFileYStore): @@ -90,7 +100,7 @@ class PrefixTempFileYStore(TempFileYStore): prefix_dir: Optional[str] = None base_dir: Optional[str] = None - def __init__(self, path: str, metadata_callback=None): + def __init__(self, path: str, metadata_callback: Optional[Callable] = None): full_path = str(Path(self.get_base_dir()) / path) super().__init__(full_path, metadata_callback=metadata_callback) @@ -106,6 +116,8 @@ def make_directory(self): class SQLiteYStore(BaseYStore): """A YStore which uses an SQLite database. + Unlike file-based YStores, the Y updates of all documents are stored in the same database. 
+ Subclass to point to your database file: class MySQLiteYStore(SQLiteYStore): @@ -116,7 +128,7 @@ class MySQLiteYStore(SQLiteYStore): path: str db_created: asyncio.Event - def __init__(self, path: str, metadata_callback=None): + def __init__(self, path: str, metadata_callback: Optional[Callable] = None): self.path = path self.metadata_callback = metadata_callback self.db_created = asyncio.Event() @@ -124,34 +136,31 @@ def __init__(self, path: str, metadata_callback=None): async def create_db(self): async with aiosqlite.connect(self.db_path) as db: - await db.execute("CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB)") + await db.execute( + "CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB, metadata BLOB)" + ) await db.commit() self.db_created.set() - async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: + async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore + await self.db_created.wait() try: async with aiosqlite.connect(self.db_path) as db: async with db.execute( - "SELECT * FROM yupdates WHERE path = ?", (self.path,) + "SELECT yupdate, metadata FROM yupdates WHERE path = ?", (self.path,) ) as cursor: found = False - is_data = True - async for _, d in cursor: + async for update, metadata in cursor: found = True - if is_data: - update = d - else: - yield update, d - is_data = not is_data + yield update, metadata if not found: raise YDocNotFound - except Exception: + except BaseException: raise YDocNotFound async def write(self, data: bytes) -> None: await self.db_created.wait() + metadata = await self.get_metadata() async with aiosqlite.connect(self.db_path) as db: - await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, data)) - metadata = await self.get_metadata() - await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, metadata)) + await db.execute("INSERT INTO yupdates VALUES (?, ?, ?)", (self.path, data, metadata)) await db.commit() diff --git a/ypy_websocket/yutils.py 
b/ypy_websocket/yutils.py index ddc2895..c4249fa 100644 --- a/ypy_websocket/yutils.py +++ b/ypy_websocket/yutils.py @@ -101,54 +101,36 @@ def read_var_string(self): def put_updates(update_queue: asyncio.Queue, ydoc: Y.YDoc, event: Y.AfterTransactionEvent) -> None: - message = create_update_message(event.get_update()) - update_queue.put_nowait(message) + update_queue.put_nowait(event.get_update()) -async def process_message(message: bytes, ydoc: Y.YDoc, websocket, log) -> Optional[bytes]: +async def process_sync_message(message: bytes, ydoc: Y.YDoc, websocket, log) -> None: message_type = message[0] + msg = message[1:] log.debug( "Received %s message from endpoint: %s", - YMessageType(message_type).raw_str(), + YSyncMessageType(message_type).raw_str(), websocket.path, ) - if message_type == YMessageType.SYNC: - message_type = message[1] - msg = message[2:] + if message_type == YSyncMessageType.SYNC_STEP1: + state = read_message(msg) + update = Y.encode_state_as_update(ydoc, state) + reply = create_sync_step2_message(update) log.debug( - "Received %s message from endpoint: %s", - YSyncMessageType(message_type).raw_str(), + "Sending %s message to endpoint: %s", + YSyncMessageType.SYNC_STEP2.raw_str(), websocket.path, ) - if message_type == YSyncMessageType.SYNC_STEP1: - state = read_message(msg) - update = Y.encode_state_as_update(ydoc, state) - reply = create_sync_step2_message(update) - log.debug( - "Sending %s message to endpoint: %s", - YSyncMessageType.SYNC_STEP2.raw_str(), - websocket.path, - ) - await websocket.send(reply) - elif message_type in ( - YSyncMessageType.SYNC_STEP2, - YSyncMessageType.SYNC_UPDATE, - ): - update = read_message(msg) - Y.apply_update(ydoc, update) - return update - - return None + await websocket.send(reply) + elif message_type in ( + YSyncMessageType.SYNC_STEP2, + YSyncMessageType.SYNC_UPDATE, + ): + update = read_message(msg) + Y.apply_update(ydoc, update) async def sync(ydoc: Y.YDoc, websocket): state = Y.encode_state_vector(ydoc) 
msg = create_sync_step1_message(state) await websocket.send(msg) - - -async def update(message, room, websocket, log): - yupdate = await process_message(message, room.ydoc, websocket, log) - if room.ystore and yupdate: - log.debug("Writing Y update to YStore from endpoint: %s", websocket.path) - await room.ystore.write(yupdate)