Skip to content

Commit

Permalink
Add metadata column to SQLiteYStore, fix update broadcast (#42)
Browse files Browse the repository at this point in the history
* Add metadata column to SQLiteYStore, fix update broadcast

* Review

* Fix async iterator

* Observe changes to internal YDoc to broadcast and store

* Fix WebsocketProvider
  • Loading branch information
davidbrochart authored Nov 17, 2022
1 parent 0d9a54b commit aa18801
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 96 deletions.
3 changes: 2 additions & 1 deletion ypy_websocket/awareness.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def get_changes(self, message: bytes) -> Dict[str, Any]:
if client_id == self.client_id and self.states.get(client_id) is not None:
clock += 1
else:
del self.states[client_id]
if client_id in self.states:
del self.states[client_id]
else:
self.states[client_id] = state
self.meta[client_id] = {
Expand Down
14 changes: 11 additions & 3 deletions ypy_websocket/websocket_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

import y_py as Y

from .yutils import process_message, put_updates, sync
from .yutils import (
YMessageType,
create_update_message,
process_sync_message,
put_updates,
sync,
)


class WebsocketProvider:
Expand All @@ -24,13 +30,15 @@ async def _run(self):
await sync(self._ydoc, self._websocket)
send_task = asyncio.create_task(self._send())
async for message in self._websocket:
await process_message(message, self._ydoc, self._websocket, self.log)
if message[0] == YMessageType.SYNC:
await process_sync_message(message[1:], self._ydoc, self._websocket, self.log)
send_task.cancel()

async def _send(self):
while True:
update = await self._update_queue.get()
message = create_update_message(update)
try:
await self._websocket.send(update)
await self._websocket.send(message)
except Exception:
pass
66 changes: 46 additions & 20 deletions ypy_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from .awareness import Awareness
from .ystore import BaseYStore
from .yutils import put_updates, sync, update
from .yutils import (
YMessageType,
create_update_message,
process_sync_message,
put_updates,
sync,
)


class YRoom:
Expand All @@ -17,13 +23,16 @@ class YRoom:
ystore: Optional[BaseYStore]
_on_message: Optional[Callable]
_update_queue: asyncio.Queue
_ready: bool

def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None):
def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None):
self.ydoc = Y.YDoc()
self.awareness = Awareness(self.ydoc)
self._update_queue = asyncio.Queue()
self._ready = False
self.ready = ready
self.ystore = ystore
self.log = log or logging.getLogger(__name__)
self.clients = []
self._on_message = None
self._broadcast_task = asyncio.create_task(self._broadcast_updates())
Expand All @@ -47,17 +56,17 @@ def on_message(self, value: Optional[Callable]):
self._on_message = value

async def _broadcast_updates(self):
try:
while True:
update = await self._update_queue.get()
# broadcast internal ydoc's update to all clients
for client in self.clients:
try:
await client.send(update)
except Exception:
pass
except Exception:
pass
while True:
update = await self._update_queue.get()
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
asyncio.create_task(client.send(message))
if self.ystore:
self.log.debug("Writing Y update to YStore")
asyncio.create_task(self.ystore.write(update))

def _clean(self):
self._broadcast_task.cancel()
Expand All @@ -76,7 +85,7 @@ def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log=

def get_room(self, path: str) -> YRoom:
if path not in self.rooms.keys():
self.rooms[path] = YRoom(ready=self.rooms_ready)
self.rooms[path] = YRoom(ready=self.rooms_ready, log=self.log)
return self.rooms[path]

def get_room_name(self, room):
Expand Down Expand Up @@ -111,12 +120,29 @@ async def serve(self, websocket):
skip = await room.on_message(message)
if skip:
continue
# update our internal state and the YStore (if any)
asyncio.create_task(update(message, room, websocket, self.log))
# forward messages to every other client in the background
for client in [c for c in room.clients if c != websocket]:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
asyncio.create_task(client.send(message))
message_type = message[0]
if message_type == YMessageType.SYNC:
# update our internal state in the background
# changes to the internal state are then forwarded to all clients
# and stored in the YStore (if any)
asyncio.create_task(
process_sync_message(message[1:], room.ydoc, websocket, self.log)
)
elif message_type == YMessageType.AWARENESS:
# forward awareness messages from this client to all clients,
# including itself, because it's used to keep the connection alive
self.log.debug(
"Received %s message from endpoint: %s",
YMessageType.AWARENESS.raw_str(),
websocket.path,
)
for client in room.clients:
self.log.debug(
"Sending Y awareness from client with endpoint %s to client with endpoint: %s",
websocket.path,
client.path,
)
asyncio.create_task(client.send(message))
# remove this client
room.clients = [c for c in room.clients if c != websocket]
if self.auto_clean_rooms and not room.clients:
Expand Down
83 changes: 46 additions & 37 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import AsyncIterator, Callable, Optional, Tuple

Expand All @@ -14,19 +15,21 @@ class YDocNotFound(Exception):
pass


class BaseYStore:
class BaseYStore(ABC):

metadata_callback: Optional[Callable] = None

@abstractmethod
def __init__(self, path: str, metadata_callback=None):
raise RuntimeError("Not implemented")
...

@abstractmethod
async def write(self, data: bytes) -> None:
raise RuntimeError("Not implemented")
...

@abstractmethod
async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
raise RuntimeError("Not implemented")
yield b"", b""
...

async def get_metadata(self) -> bytes:
metadata = b"" if not self.metadata_callback else await self.metadata_callback()
Expand All @@ -37,30 +40,35 @@ async def encode_state_as_update(self, ydoc: Y.YDoc):
await self.write(update)

async def apply_updates(self, ydoc: Y.YDoc):
async for update, metadata in self.read():
async for update, metadata in self.read(): # type: ignore
Y.apply_update(ydoc, update) # type: ignore


class FileYStore(BaseYStore):
"""A YStore which uses the local file system."""
"""A YStore which uses one file per document."""

path: str
metadata_callback: Optional[Callable]
lock: asyncio.Lock

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
self.path = path
self.metadata_callback = metadata_callback

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
try:
async with aiofiles.open(self.path, "rb") as f:
data = await f.read()
except Exception:
raise YDocNotFound
self.lock = asyncio.Lock()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
async with self.lock:
try:
async with aiofiles.open(self.path, "rb") as f:
data = await f.read()
except BaseException:
raise YDocNotFound
is_data = True
for d in Decoder(data).read_messages():
if is_data:
update = d
else:
# yield data and metadata
yield update, d
is_data = not is_data

Expand All @@ -71,16 +79,18 @@ async def write(self, data: bytes) -> None:
mode = "wb"
else:
mode = "ab"
async with aiofiles.open(self.path, mode) as f:
data_len = write_var_uint(len(data))
await f.write(data_len + data)
metadata = await self.get_metadata()
metadata_len = write_var_uint(len(metadata))
await f.write(metadata_len + metadata)
async with self.lock:
async with aiofiles.open(self.path, mode) as f:
data_len = write_var_uint(len(data))
await f.write(data_len + data)
metadata = await self.get_metadata()
metadata_len = write_var_uint(len(metadata))
await f.write(metadata_len + metadata)


class TempFileYStore(FileYStore):
"""A YStore which uses the system's temporary directory.
Files are writen under a common directory.
To prefix the directory name (e.g. /tmp/my_prefix_b4whmm7y/):
class PrefixTempFileYStore(TempFileYStore):
Expand All @@ -90,7 +100,7 @@ class PrefixTempFileYStore(TempFileYStore):
prefix_dir: Optional[str] = None
base_dir: Optional[str] = None

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
full_path = str(Path(self.get_base_dir()) / path)
super().__init__(full_path, metadata_callback=metadata_callback)

Expand All @@ -106,6 +116,8 @@ def make_directory(self):

class SQLiteYStore(BaseYStore):
"""A YStore which uses an SQLite database.
Unlike file-based YStores, the Y updates of all documents are stored in the same database.
Subclass to point to your database file:
class MySQLiteYStore(SQLiteYStore):
Expand All @@ -116,42 +128,39 @@ class MySQLiteYStore(SQLiteYStore):
path: str
db_created: asyncio.Event

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
self.path = path
self.metadata_callback = metadata_callback
self.db_created = asyncio.Event()
asyncio.create_task(self.create_db())

async def create_db(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute("CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB)")
await db.execute(
"CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB, metadata BLOB)"
)
await db.commit()
self.db_created.set()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
await self.db_created.wait()
try:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT * FROM yupdates WHERE path = ?", (self.path,)
"SELECT yupdate, metadata FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
found = False
is_data = True
async for _, d in cursor:
async for update, metadata in cursor:
found = True
if is_data:
update = d
else:
yield update, d
is_data = not is_data
yield update, metadata
if not found:
raise YDocNotFound
except Exception:
except BaseException:
raise YDocNotFound

async def write(self, data: bytes) -> None:
await self.db_created.wait()
metadata = await self.get_metadata()
async with aiosqlite.connect(self.db_path) as db:
await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, data))
metadata = await self.get_metadata()
await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, metadata))
await db.execute("INSERT INTO yupdates VALUES (?, ?, ?)", (self.path, data, metadata))
await db.commit()
52 changes: 17 additions & 35 deletions ypy_websocket/yutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,54 +101,36 @@ def read_var_string(self):


def put_updates(update_queue: asyncio.Queue, ydoc: Y.YDoc, event: Y.AfterTransactionEvent) -> None:
message = create_update_message(event.get_update())
update_queue.put_nowait(message)
update_queue.put_nowait(event.get_update())


async def process_message(message: bytes, ydoc: Y.YDoc, websocket, log) -> Optional[bytes]:
async def process_sync_message(message: bytes, ydoc: Y.YDoc, websocket, log) -> None:
message_type = message[0]
msg = message[1:]
log.debug(
"Received %s message from endpoint: %s",
YMessageType(message_type).raw_str(),
YSyncMessageType(message_type).raw_str(),
websocket.path,
)
if message_type == YMessageType.SYNC:
message_type = message[1]
msg = message[2:]
if message_type == YSyncMessageType.SYNC_STEP1:
state = read_message(msg)
update = Y.encode_state_as_update(ydoc, state)
reply = create_sync_step2_message(update)
log.debug(
"Received %s message from endpoint: %s",
YSyncMessageType(message_type).raw_str(),
"Sending %s message to endpoint: %s",
YSyncMessageType.SYNC_STEP2.raw_str(),
websocket.path,
)
if message_type == YSyncMessageType.SYNC_STEP1:
state = read_message(msg)
update = Y.encode_state_as_update(ydoc, state)
reply = create_sync_step2_message(update)
log.debug(
"Sending %s message to endpoint: %s",
YSyncMessageType.SYNC_STEP2.raw_str(),
websocket.path,
)
await websocket.send(reply)
elif message_type in (
YSyncMessageType.SYNC_STEP2,
YSyncMessageType.SYNC_UPDATE,
):
update = read_message(msg)
Y.apply_update(ydoc, update)
return update

return None
await websocket.send(reply)
elif message_type in (
YSyncMessageType.SYNC_STEP2,
YSyncMessageType.SYNC_UPDATE,
):
update = read_message(msg)
Y.apply_update(ydoc, update)


async def sync(ydoc: Y.YDoc, websocket):
state = Y.encode_state_vector(ydoc)
msg = create_sync_step1_message(state)
await websocket.send(msg)


async def update(message, room, websocket, log):
yupdate = await process_message(message, room.ydoc, websocket, log)
if room.ystore and yupdate:
log.debug("Writing Y update to YStore from endpoint: %s", websocket.path)
await room.ystore.write(yupdate)

0 comments on commit aa18801

Please sign in to comment.