Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metadata column to SQLiteYStore, fix update broadcast #42

Merged
merged 5 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ypy_websocket/awareness.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def get_changes(self, message: bytes) -> Dict[str, Any]:
if client_id == self.client_id and self.states.get(client_id) is not None:
clock += 1
else:
del self.states[client_id]
if client_id in self.states:
del self.states[client_id]
else:
self.states[client_id] = state
self.meta[client_id] = {
Expand Down
14 changes: 11 additions & 3 deletions ypy_websocket/websocket_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

import y_py as Y

from .yutils import process_message, put_updates, sync
from .yutils import (
YMessageType,
create_update_message,
process_sync_message,
put_updates,
sync,
)


class WebsocketProvider:
Expand All @@ -24,13 +30,15 @@ async def _run(self):
await sync(self._ydoc, self._websocket)
send_task = asyncio.create_task(self._send())
async for message in self._websocket:
await process_message(message, self._ydoc, self._websocket, self.log)
if message[0] == YMessageType.SYNC:
await process_sync_message(message[1:], self._ydoc, self._websocket, self.log)
send_task.cancel()

async def _send(self):
while True:
update = await self._update_queue.get()
message = create_update_message(update)
try:
await self._websocket.send(update)
await self._websocket.send(message)
except Exception:
pass
66 changes: 46 additions & 20 deletions ypy_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from .awareness import Awareness
from .ystore import BaseYStore
from .yutils import put_updates, sync, update
from .yutils import (
YMessageType,
create_update_message,
process_sync_message,
put_updates,
sync,
)


class YRoom:
Expand All @@ -17,13 +23,16 @@ class YRoom:
ystore: Optional[BaseYStore]
_on_message: Optional[Callable]
_update_queue: asyncio.Queue
_ready: bool

def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None):
def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None):
self.ydoc = Y.YDoc()
self.awareness = Awareness(self.ydoc)
self._update_queue = asyncio.Queue()
self._ready = False
self.ready = ready
self.ystore = ystore
self.log = log or logging.getLogger(__name__)
self.clients = []
self._on_message = None
self._broadcast_task = asyncio.create_task(self._broadcast_updates())
Expand All @@ -47,17 +56,17 @@ def on_message(self, value: Optional[Callable]):
self._on_message = value

async def _broadcast_updates(self):
try:
while True:
update = await self._update_queue.get()
# broadcast internal ydoc's update to all clients
for client in self.clients:
try:
await client.send(update)
except Exception:
pass
except Exception:
pass
while True:
update = await self._update_queue.get()
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
asyncio.create_task(client.send(message))
if self.ystore:
self.log.debug("Writing Y update to YStore")
asyncio.create_task(self.ystore.write(update))

def _clean(self):
self._broadcast_task.cancel()
Expand All @@ -76,7 +85,7 @@ def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log=

def get_room(self, path: str) -> YRoom:
if path not in self.rooms.keys():
self.rooms[path] = YRoom(ready=self.rooms_ready)
self.rooms[path] = YRoom(ready=self.rooms_ready, log=self.log)
return self.rooms[path]

def get_room_name(self, room):
Expand Down Expand Up @@ -111,12 +120,29 @@ async def serve(self, websocket):
skip = await room.on_message(message)
if skip:
continue
# update our internal state and the YStore (if any)
asyncio.create_task(update(message, room, websocket, self.log))
# forward messages to every other client in the background
for client in [c for c in room.clients if c != websocket]:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
asyncio.create_task(client.send(message))
message_type = message[0]
if message_type == YMessageType.SYNC:
# update our internal state in the background
# changes to the internal state are then forwarded to all clients
# and stored in the YStore (if any)
asyncio.create_task(
process_sync_message(message[1:], room.ydoc, websocket, self.log)
)
elif message_type == YMessageType.AWARENESS:
# forward awareness messages from this client to all clients,
# including itself, because it's used to keep the connection alive
self.log.debug(
"Received %s message from endpoint: %s",
YMessageType.AWARENESS.raw_str(),
websocket.path,
)
for client in room.clients:
self.log.debug(
"Sending Y awareness from client with endpoint %s to client with endpoint: %s",
websocket.path,
client.path,
)
asyncio.create_task(client.send(message))
# remove this client
room.clients = [c for c in room.clients if c != websocket]
if self.auto_clean_rooms and not room.clients:
Expand Down
83 changes: 46 additions & 37 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import AsyncIterator, Callable, Optional, Tuple

Expand All @@ -14,19 +15,21 @@ class YDocNotFound(Exception):
pass


class BaseYStore:
class BaseYStore(ABC):

metadata_callback: Optional[Callable] = None

@abstractmethod
def __init__(self, path: str, metadata_callback=None):
raise RuntimeError("Not implemented")
...

@abstractmethod
async def write(self, data: bytes) -> None:
raise RuntimeError("Not implemented")
...

@abstractmethod
async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
raise RuntimeError("Not implemented")
yield b"", b""
...

async def get_metadata(self) -> bytes:
metadata = b"" if not self.metadata_callback else await self.metadata_callback()
Expand All @@ -37,30 +40,35 @@ async def encode_state_as_update(self, ydoc: Y.YDoc):
await self.write(update)

async def apply_updates(self, ydoc: Y.YDoc):
async for update, metadata in self.read():
async for update, metadata in self.read(): # type: ignore
Y.apply_update(ydoc, update) # type: ignore


class FileYStore(BaseYStore):
"""A YStore which uses the local file system."""
"""A YStore which uses one file per document."""

path: str
metadata_callback: Optional[Callable]
lock: asyncio.Lock

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
self.path = path
self.metadata_callback = metadata_callback

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
try:
async with aiofiles.open(self.path, "rb") as f:
data = await f.read()
except Exception:
raise YDocNotFound
self.lock = asyncio.Lock()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
async with self.lock:
try:
async with aiofiles.open(self.path, "rb") as f:
data = await f.read()
except BaseException:
raise YDocNotFound
is_data = True
for d in Decoder(data).read_messages():
if is_data:
update = d
else:
# yield data and metadata
yield update, d
is_data = not is_data

Expand All @@ -71,16 +79,18 @@ async def write(self, data: bytes) -> None:
mode = "wb"
else:
mode = "ab"
async with aiofiles.open(self.path, mode) as f:
data_len = write_var_uint(len(data))
await f.write(data_len + data)
metadata = await self.get_metadata()
metadata_len = write_var_uint(len(metadata))
await f.write(metadata_len + metadata)
async with self.lock:
async with aiofiles.open(self.path, mode) as f:
data_len = write_var_uint(len(data))
await f.write(data_len + data)
metadata = await self.get_metadata()
metadata_len = write_var_uint(len(metadata))
await f.write(metadata_len + metadata)


class TempFileYStore(FileYStore):
"""A YStore which uses the system's temporary directory.
Files are writen under a common directory.
To prefix the directory name (e.g. /tmp/my_prefix_b4whmm7y/):

class PrefixTempFileYStore(TempFileYStore):
Expand All @@ -90,7 +100,7 @@ class PrefixTempFileYStore(TempFileYStore):
prefix_dir: Optional[str] = None
base_dir: Optional[str] = None

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
full_path = str(Path(self.get_base_dir()) / path)
super().__init__(full_path, metadata_callback=metadata_callback)

Expand All @@ -106,6 +116,8 @@ def make_directory(self):

class SQLiteYStore(BaseYStore):
"""A YStore which uses an SQLite database.
Unlike file-based YStores, the Y updates of all documents are stored in the same database.

Subclass to point to your database file:

class MySQLiteYStore(SQLiteYStore):
Expand All @@ -116,42 +128,39 @@ class MySQLiteYStore(SQLiteYStore):
path: str
db_created: asyncio.Event

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
self.path = path
self.metadata_callback = metadata_callback
self.db_created = asyncio.Event()
asyncio.create_task(self.create_db())

async def create_db(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute("CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB)")
await db.execute(
"CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB, metadata BLOB)"
)
await db.commit()
self.db_created.set()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
await self.db_created.wait()
try:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT * FROM yupdates WHERE path = ?", (self.path,)
"SELECT yupdate, metadata FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
found = False
is_data = True
async for _, d in cursor:
async for update, metadata in cursor:
found = True
if is_data:
update = d
else:
yield update, d
is_data = not is_data
yield update, metadata
if not found:
raise YDocNotFound
except Exception:
except BaseException:
raise YDocNotFound

async def write(self, data: bytes) -> None:
await self.db_created.wait()
metadata = await self.get_metadata()
async with aiosqlite.connect(self.db_path) as db:
await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, data))
metadata = await self.get_metadata()
await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, metadata))
await db.execute("INSERT INTO yupdates VALUES (?, ?, ?)", (self.path, data, metadata))
await db.commit()
52 changes: 17 additions & 35 deletions ypy_websocket/yutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,54 +101,36 @@ def read_var_string(self):


def put_updates(update_queue: asyncio.Queue, ydoc: Y.YDoc, event: Y.AfterTransactionEvent) -> None:
message = create_update_message(event.get_update())
update_queue.put_nowait(message)
update_queue.put_nowait(event.get_update())


async def process_message(message: bytes, ydoc: Y.YDoc, websocket, log) -> Optional[bytes]:
async def process_sync_message(message: bytes, ydoc: Y.YDoc, websocket, log) -> None:
message_type = message[0]
msg = message[1:]
log.debug(
"Received %s message from endpoint: %s",
YMessageType(message_type).raw_str(),
YSyncMessageType(message_type).raw_str(),
websocket.path,
)
if message_type == YMessageType.SYNC:
message_type = message[1]
msg = message[2:]
if message_type == YSyncMessageType.SYNC_STEP1:
state = read_message(msg)
update = Y.encode_state_as_update(ydoc, state)
reply = create_sync_step2_message(update)
log.debug(
"Received %s message from endpoint: %s",
YSyncMessageType(message_type).raw_str(),
"Sending %s message to endpoint: %s",
YSyncMessageType.SYNC_STEP2.raw_str(),
websocket.path,
)
if message_type == YSyncMessageType.SYNC_STEP1:
state = read_message(msg)
update = Y.encode_state_as_update(ydoc, state)
reply = create_sync_step2_message(update)
log.debug(
"Sending %s message to endpoint: %s",
YSyncMessageType.SYNC_STEP2.raw_str(),
websocket.path,
)
await websocket.send(reply)
elif message_type in (
YSyncMessageType.SYNC_STEP2,
YSyncMessageType.SYNC_UPDATE,
):
update = read_message(msg)
Y.apply_update(ydoc, update)
return update

return None
await websocket.send(reply)
elif message_type in (
YSyncMessageType.SYNC_STEP2,
YSyncMessageType.SYNC_UPDATE,
):
update = read_message(msg)
Y.apply_update(ydoc, update)


async def sync(ydoc: Y.YDoc, websocket):
state = Y.encode_state_vector(ydoc)
msg = create_sync_step1_message(state)
await websocket.send(msg)


async def update(message, room, websocket, log):
yupdate = await process_message(message, room.ydoc, websocket, log)
if room.ystore and yupdate:
log.debug("Writing Y update to YStore from endpoint: %s", websocket.path)
await room.ystore.write(yupdate)