diff --git a/ypy_websocket/__init__.py b/ypy_websocket/__init__.py index e217214..6f66dc0 100644 --- a/ypy_websocket/__init__.py +++ b/ypy_websocket/__init__.py @@ -1,5 +1,6 @@ from .websocket_provider import WebsocketProvider # noqa from .websocket_server import WebsocketServer, YRoom # noqa +from .ydoc import YDoc # noqa from .yutils import YMessageType # noqa __version__ = "0.4.0" diff --git a/ypy_websocket/awareness.py b/ypy_websocket/awareness.py index 244f880..c23c18b 100644 --- a/ypy_websocket/awareness.py +++ b/ypy_websocket/awareness.py @@ -38,7 +38,8 @@ def get_changes(self, message: bytes) -> Dict[str, Any]: if client_id == self.client_id and self.states.get(client_id) is not None: clock += 1 else: - del self.states[client_id] + if client_id in self.states: + del self.states[client_id] else: self.states[client_id] = state self.meta[client_id] = { diff --git a/ypy_websocket/websocket_server.py b/ypy_websocket/websocket_server.py index 6b3d0fd..33484cb 100644 --- a/ypy_websocket/websocket_server.py +++ b/ypy_websocket/websocket_server.py @@ -1,29 +1,31 @@ import asyncio import logging -from functools import partial from typing import Callable, Dict, List, Optional -import y_py as Y - from .awareness import Awareness +from .ydoc import YDoc from .ystore import BaseYStore -from .yutils import put_updates, sync, update +from .yutils import sync, update class YRoom: clients: List - ydoc: Y.YDoc + ydoc: YDoc ystore: Optional[BaseYStore] _on_message: Optional[Callable] _update_queue: asyncio.Queue + _ready: bool - def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None): - self.ydoc = Y.YDoc() - self.awareness = Awareness(self.ydoc) + def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None): self._update_queue = asyncio.Queue() + self.ydoc = YDoc() + self.ydoc.init(self._update_queue) # FIXME: overriding Y.YDoc.__init__ doesn't seem to work + self.awareness = Awareness(self.ydoc) + self._ready = False self.ready = ready self.ystore = ystore + self.log = log or logging.getLogger(__name__) self.clients = [] self._on_message = None self._broadcast_task = asyncio.create_task(self._broadcast_updates()) @@ -36,7 +38,7 @@ def ready(self) -> bool: def ready(self, value: bool) -> None: self._ready = value if value: - self.ydoc.observe_after_transaction(partial(put_updates, self._update_queue, self.ydoc)) + self.ydoc._ready = True @property def on_message(self) -> Optional[Callable]: @@ -47,17 +49,14 @@ def on_message(self, value: Optional[Callable]): self._on_message = value async def _broadcast_updates(self): - try: - while True: - update = await self._update_queue.get() - # broadcast internal ydoc's update to all clients - for client in self.clients: - try: - await client.send(update) - except Exception: - pass - except Exception: - pass + while True: + update = await self._update_queue.get() + # broadcast internal ydoc's update made from the backend to all clients + for client in self.clients: + self.log.debug( + "Sending Y update from backend to client with endpoint: %s", client.path + ) + asyncio.create_task(client.send(update)) def _clean(self): self._broadcast_task.cancel() @@ -76,7 +75,7 @@ def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log= def get_room(self, path: str) -> YRoom: if path not in self.rooms.keys(): - self.rooms[path] = YRoom(ready=self.rooms_ready) + self.rooms[path] = YRoom(ready=self.rooms_ready, log=self.log) return self.rooms[path] def get_room_name(self, room): @@ -113,9 +112,10 @@ async def serve(self, websocket): continue # update our internal state and the YStore (if any) asyncio.create_task(update(message, room, websocket, self.log)) - # forward messages to every other client in the background + # forward messages from this client to every other client in the background for client in [c for c in room.clients if c != websocket]: - self.log.debug("Sending Y update to client with endpoint: %s", client.path) + self.log.debug("Sending Y update from client with endpoint: %s", websocket.path) + self.log.debug("... to client with endpoint: %s", client.path) asyncio.create_task(client.send(message)) # remove this client room.clients = [c for c in room.clients if c != websocket] diff --git a/ypy_websocket/ydoc.py b/ypy_websocket/ydoc.py new file mode 100644 index 0000000..5120bef --- /dev/null +++ b/ypy_websocket/ydoc.py @@ -0,0 +1,58 @@ +import asyncio +from types import TracebackType +from typing import Optional, Type + +import y_py as Y + +from .yutils import create_update_message + + +class YDoc(Y.YDoc): + """A YDoc with a custom transaction context manager that allows to put updates into a queue. + `Y.YDoc.observe_after_transaction` catches all updates, while updates made through this YDoc + can be processed separately, which is useful for e.g. updates made only from the backend. + """ + + _begin_transaction = Y.YDoc.begin_transaction + _update_queue: asyncio.Queue + _ready: bool + + def init(self, update_queue: asyncio.Queue): + self._ready = False + self._update_queue = update_queue + + def begin_transaction(self): + return Transaction(self, self._update_queue, self._ready) + + +class Transaction: + + ydoc: YDoc + update_queue: asyncio.Queue + state: bytes + transaction: Y.YTransaction + ready: bool + + def __init__(self, ydoc: YDoc, update_queue: asyncio.Queue, ready: bool): + self.ydoc = ydoc + self.update_queue = update_queue + self.ready = ready + + def __enter__(self): + if self.ready: + self.state = Y.encode_state_vector(self.ydoc) + self.transaction = self.ydoc._begin_transaction() + return self.transaction.__enter__() + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + res = self.transaction.__exit__(exc_type, exc_value, exc_tb) # type: ignore + if self.ready: + update = Y.encode_state_as_update(self.ydoc, self.state) + message = create_update_message(update) + self.update_queue.put_nowait(message) + return res diff --git a/ypy_websocket/ystore.py b/ypy_websocket/ystore.py index 2b68fc8..9c46eb7 100644 --- a/ypy_websocket/ystore.py +++ b/ypy_websocket/ystore.py @@ -1,5 +1,6 @@ import asyncio import tempfile +from abc import ABC, abstractmethod from pathlib import Path from typing import AsyncIterator, Callable, Optional, Tuple @@ -14,19 +15,21 @@ class YDocNotFound(Exception): pass -class BaseYStore: +class BaseYStore(ABC): metadata_callback: Optional[Callable] = None + @abstractmethod def __init__(self, path: str, metadata_callback=None): - raise RuntimeError("Not implemented") + ... + @abstractmethod async def write(self, data: bytes) -> None: - raise RuntimeError("Not implemented") + ... + @abstractmethod async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: - raise RuntimeError("Not implemented") - yield b"", b"" + ... async def get_metadata(self) -> bytes: metadata = b"" if not self.metadata_callback else await self.metadata_callback() @@ -37,30 +40,35 @@ async def encode_state_as_update(self, ydoc: Y.YDoc): await self.write(update) async def apply_updates(self, ydoc: Y.YDoc): - async for update, metadata in self.read(): + async for update, metadata in await self.read(): Y.apply_update(ydoc, update) # type: ignore class FileYStore(BaseYStore): - """A YStore which uses the local file system.""" + """A YStore which uses one file per document.""" path: str + metadata_callback: Optional[Callable] + lock: asyncio.Lock - def __init__(self, path: str, metadata_callback=None): + def __init__(self, path: str, metadata_callback: Optional[Callable] = None): self.path = path self.metadata_callback = metadata_callback - - async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: - try: - async with aiofiles.open(self.path, "rb") as f: - data = await f.read() - except Exception: - raise YDocNotFound + self.lock = asyncio.Lock() + + async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore + async with self.lock: + try: + async with aiofiles.open(self.path, "rb") as f: + data = await f.read() + except BaseException: + raise YDocNotFound is_data = True for d in Decoder(data).read_messages(): if is_data: update = d else: + # yield data and metadata yield update, d is_data = not is_data @@ -71,16 +79,18 @@ async def write(self, data: bytes) -> None: mode = "wb" else: mode = "ab" - async with aiofiles.open(self.path, mode) as f: - data_len = write_var_uint(len(data)) - await f.write(data_len + data) - metadata = await self.get_metadata() - metadata_len = write_var_uint(len(metadata)) - await f.write(metadata_len + metadata) + async with self.lock: + async with aiofiles.open(self.path, mode) as f: + data_len = write_var_uint(len(data)) + await f.write(data_len + data) + metadata = await self.get_metadata() + metadata_len = write_var_uint(len(metadata)) + await f.write(metadata_len + metadata) class TempFileYStore(FileYStore): """A YStore which uses the system's temporary directory. + Files are writen under a common directory. To prefix the directory name (e.g. /tmp/my_prefix_b4whmm7y/): class PrefixTempFileYStore(TempFileYStore): @@ -90,7 +100,7 @@ class PrefixTempFileYStore(TempFileYStore): prefix_dir: Optional[str] = None base_dir: Optional[str] = None - def __init__(self, path: str, metadata_callback=None): + def __init__(self, path: str, metadata_callback: Optional[Callable] = None): full_path = str(Path(self.get_base_dir()) / path) super().__init__(full_path, metadata_callback=metadata_callback) @@ -106,6 +116,8 @@ def make_directory(self): class SQLiteYStore(BaseYStore): """A YStore which uses an SQLite database. + Unlike file-based YStores, the Y updates of all documents are stored in the same database. + Subclass to point to your database file: class MySQLiteYStore(SQLiteYStore): @@ -116,7 +128,7 @@ class MySQLiteYStore(SQLiteYStore): path: str db_created: asyncio.Event - def __init__(self, path: str, metadata_callback=None): + def __init__(self, path: str, metadata_callback: Optional[Callable] = None): self.path = path self.metadata_callback = metadata_callback self.db_created = asyncio.Event() @@ -124,34 +136,31 @@ def __init__(self, path: str, metadata_callback=None): async def create_db(self): async with aiosqlite.connect(self.db_path) as db: - await db.execute("CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB)") + await db.execute( + "CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB, metadata BLOB)" + ) await db.commit() self.db_created.set() - async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: + async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore + await self.db_created.wait() try: async with aiosqlite.connect(self.db_path) as db: async with db.execute( - "SELECT * FROM yupdates WHERE path = ?", (self.path,) + "SELECT yupdate, metadata FROM yupdates WHERE path = ?", (self.path,) ) as cursor: found = False - is_data = True - async for _, d in cursor: + async for update, metadata in cursor: found = True - if is_data: - update = d - else: - yield update, d - is_data = not is_data + yield update, metadata if not found: raise YDocNotFound - except Exception: + except BaseException: raise YDocNotFound async def write(self, data: bytes) -> None: await self.db_created.wait() + metadata = await self.get_metadata() async with aiosqlite.connect(self.db_path) as db: - await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, data)) - metadata = await self.get_metadata() - await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, metadata)) + await db.execute("INSERT INTO yupdates VALUES (?, ?, ?)", (self.path, data, metadata)) await db.commit()