From 2bf447f65fb2b298675fb4ffd027be5d39f8919c Mon Sep 17 00:00:00 2001
From: David Brochart
Date: Mon, 24 Jun 2024 22:04:38 +0200
Subject: [PATCH] Add sync protocol (v1)

---
 pyproject.toml            |   2 +
 python/pycrdt/__init__.py |   3 +
 python/pycrdt/_sync.py    | 182 ++++++++++++++++++++++++++++++++++++++
 tests/test_sync.py        | 112 +++++++++++++++++++++++
 4 files changed, 299 insertions(+)
 create mode 100644 python/pycrdt/_sync.py
 create mode 100644 tests/test_sync.py

diff --git a/pyproject.toml b/pyproject.toml
index fe89ff7..af0e737 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,8 @@ classifiers = [
 test = [
     "pytest >=7.4.2,<8",
     "y-py >=0.7.0a1,<0.8",
+    "anyio >=4.4.0,<5",
+    "trio >=0.25.1,<0.26",
     "pydantic >=2.5.2,<3",
     "mypy",
     "coverage[toml] >=7",
diff --git a/python/pycrdt/__init__.py b/python/pycrdt/__init__.py
index 90d202e..79d1f60 100644
--- a/python/pycrdt/__init__.py
+++ b/python/pycrdt/__init__.py
@@ -5,6 +5,9 @@
 from ._map import MapEvent as MapEvent
 from ._pycrdt import Subscription as Subscription
 from ._pycrdt import TransactionEvent as TransactionEvent
+from ._sync import create_sync_message as create_sync_message
+from ._sync import create_update_message as create_update_message
+from ._sync import handle_sync_message as handle_sync_message
 from ._text import Text as Text
 from ._text import TextEvent as TextEvent
 from ._transaction import ReadTransaction as ReadTransaction
diff --git a/python/pycrdt/_sync.py b/python/pycrdt/_sync.py
new file mode 100644
index 0000000..1dd2672
--- /dev/null
+++ b/python/pycrdt/_sync.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+from enum import IntEnum
+
+from pycrdt import Doc
+
+
+class YMessageType(IntEnum):
+    """First byte of a Y protocol message."""
+
+    SYNC = 0
+    AWARENESS = 1
+
+
+class YSyncMessageType(IntEnum):
+    """Second byte of a ``SYNC`` message."""
+
+    SYNC_STEP1 = 0
+    SYNC_STEP2 = 1
+    SYNC_UPDATE = 2
+
+
+def write_var_uint(num: int) -> bytes:
+    """Encode a non-negative integer as a variable-length unsigned integer.
+
+    The value is emitted 7 bits at a time, least significant group first; the
+    high bit of every byte except the last one is set.
+    """
+    res = []
+    while num > 127:
+        res.append(128 | (127 & num))
+        num >>= 7
+    res.append(num)
+    return bytes(res)
+
+
+def create_message(data: bytes, msg_type: int) -> bytes:
+    """Frame ``data`` as a ``SYNC`` message of type ``msg_type``.
+
+    Layout: ``SYNC`` byte, ``msg_type`` byte, var-uint length of ``data``, ``data``.
+    """
+    return bytes([YMessageType.SYNC, msg_type]) + write_var_uint(len(data)) + data
+
+
+def create_sync_step1_message(data: bytes) -> bytes:
+    """Frame ``data`` as a ``SYNC_STEP1`` message."""
+    return create_message(data, YSyncMessageType.SYNC_STEP1)
+
+
+def create_sync_step2_message(data: bytes) -> bytes:
+    """Frame ``data`` as a ``SYNC_STEP2`` message."""
+    return create_message(data, YSyncMessageType.SYNC_STEP2)
+
+
+def create_update_message(data: bytes) -> bytes:
+    """Frame ``data`` as a ``SYNC_UPDATE`` message."""
+    return create_message(data, YSyncMessageType.SYNC_UPDATE)
+
+
+class Decoder:
+    """Sequential reader of var-uint length-prefixed messages from ``stream``."""
+
+    def __init__(self, stream: bytes):
+        self.stream = stream
+        # Number of bytes not consumed yet.
+        self.length = len(stream)
+        # Offset of the next byte to read.
+        self.i0 = 0
+
+    def read_var_uint(self) -> int:
+        """Read a variable-length unsigned integer.
+
+        Raises:
+            RuntimeError: If the stream ends before the integer is complete.
+        """
+        uint = 0
+        i = 0
+        while True:
+            # Checked for every byte so that a truncated integer is reported
+            # as a protocol error rather than an IndexError.
+            if self.length <= 0:
+                raise RuntimeError("Y protocol error")
+            byte = self.stream[self.i0]
+            uint += (byte & 127) << i
+            i += 7
+            self.i0 += 1
+            self.length -= 1
+            if byte < 128:
+                break
+        return uint
+
+    def read_message(self) -> bytes | None:
+        """Read one length-prefixed message, or return ``None`` at end of stream.
+
+        Raises:
+            RuntimeError: If fewer bytes remain than the length prefix announces.
+        """
+        if self.length == 0:
+            return None
+        length = self.read_var_uint()
+        if length == 0:
+            return b""
+        if length > self.length:
+            raise RuntimeError("Y protocol error")
+        i1 = self.i0 + length
+        message = self.stream[self.i0 : i1]
+        self.i0 = i1
+        self.length -= length
+        return message
+
+    def read_messages(self):
+        """Yield messages until the stream is exhausted."""
+        while True:
+            message = self.read_message()
+            if message is None:
+                return
+            yield message
+
+    def read_var_string(self) -> str:
+        """Read one message and decode it as UTF-8 (``""`` at end of stream)."""
+        message = self.read_message()
+        if message is None:
+            return ""
+        return message.decode("utf-8")
+
+
+def read_message(stream: bytes) -> bytes:
+    """Return the first length-prefixed message of ``stream``.
+
+    Raises:
+        RuntimeError: If ``stream`` is empty or truncated.
+    """
+    message = Decoder(stream).read_message()
+    if message is None:
+        # Not an ``assert``: those are stripped when Python runs with ``-O``.
+        raise RuntimeError("Y protocol error")
+    return message
+
+
+def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None:
+    """Apply a sync message to ``ydoc`` and return the reply to send, if any.
+
+    Args:
+        message: The message starting at its sync message type byte, i.e.
+            without the leading ``YMessageType.SYNC`` byte.
+        ydoc: The document to read the update from or apply the update to.
+
+    Returns:
+        A ``SYNC_STEP2`` message in reply to a ``SYNC_STEP1`` message,
+        ``None`` otherwise.
+
+    Raises:
+        RuntimeError: If ``message`` is empty or its payload is truncated.
+    """
+    if not message:
+        raise RuntimeError("Y protocol error")
+    message_type = message[0]
+    msg = message[1:]
+
+    if message_type == YSyncMessageType.SYNC_STEP1:
+        state = read_message(msg)
+        update = ydoc.get_update(state)
+        reply = create_sync_step2_message(update)
+        return reply
+
+    if message_type in (
+        YSyncMessageType.SYNC_STEP2,
+        YSyncMessageType.SYNC_UPDATE,
+    ):
+        update = read_message(msg)
+        # Ignore empty updates
+        if update != b"\x00\x00":
+            ydoc.apply_update(update)
+
+    return None
+
+
+def create_sync_message(ydoc: Doc) -> bytes:
+    """Create the ``SYNC_STEP1`` message carrying the state of ``ydoc``."""
+    state = ydoc.get_state()
+    message = create_sync_step1_message(state)
+    return message
diff --git a/tests/test_sync.py b/tests/test_sync.py
new file mode 100644
index 0000000..51bd756
--- /dev/null
+++ b/tests/test_sync.py
@@ -0,0 +1,112 @@
+import pytest
+from anyio import TASK_STATUS_IGNORED, create_memory_object_stream, create_task_group, sleep
+from anyio.abc import TaskStatus
+from pycrdt import (
+    Array,
+    Doc,
+    create_sync_message,
+    create_update_message,
+    handle_sync_message,
+)
+from pycrdt._sync import Decoder, write_var_uint
+
+pytestmark = pytest.mark.anyio
+
+
+class ConnectedDoc:
+    def __init__(self):
+        self.doc = Doc()
+        self.doc.observe(lambda event: self.send(event.update))
+        self.connected_docs = []
+        self.send_stream, self.receive_stream = create_memory_object_stream[bytes](
+            max_buffer_size=1024
+        )
+
+    async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
+        async with create_task_group() as tg:
+            task_status.started()
+            tg.start_soon(self.process_received_messages)
+
+    def connect(self, *connected_docs):
+        self.connected_docs += connected_docs
+        sync_message = create_sync_message(self.doc)
+        for connected_doc in connected_docs:
+            connected_doc.receive(sync_message, self)
+
+    def receive(self, message: bytes, sender=None):
+        self.send_stream.send_nowait((message, sender))
+
+    async def process_received_messages(self):
+        async for message, sender in self.receive_stream:
+            reply = handle_sync_message(message[1:], self.doc)
+            if reply is not None:
+                sender.receive(reply, self)
+
+    def send(self, message: bytes):
+        for doc in self.connected_docs:
+            doc.receive(create_update_message(message))
+
+
+async def test_sync():
+    async with create_task_group() as tg:
+        doc0 = ConnectedDoc()
+        doc1 = ConnectedDoc()
+
+        await tg.start(doc0.start)
+        await tg.start(doc1.start)
+
+        doc0.connect(doc1)
+        doc1.connect(doc0)
+
+        array0 = doc0.doc.get("array", type=Array)
+        array0.append(0)
+        await sleep(0.1)
+
+        array1 = doc1.doc.get("array", type=Array)
+        assert array1[0] == 0
+
+        # doc2 only connects to doc0
+        # but since doc0 and doc1 are connected,
+        # doc2 is indirectly connected to doc1
+        doc2 = ConnectedDoc()
+        await tg.start(doc2.start)
+        doc2.connect(doc0)
+        doc0.connect(doc2)
+        await sleep(0.1)
+        array2 = doc2.doc.get("array", type=Array)
+        assert array2[0] == 0
+        array2.append(1)
+        await sleep(0.1)
+        assert array0[1] == 1
+        assert array1[1] == 1
+
+        tg.cancel_scope.cancel()
+
+
+def test_write_var_uint():
+    assert write_var_uint(128) == b"\x80\x01"
+
+
+def test_decoder():
+    with pytest.raises(RuntimeError) as exc_info:
+        Decoder(b"").read_var_uint()
+    assert str(exc_info.value) == "Y protocol error"
+
+    assert list(Decoder(b"").read_messages()) == []
+    assert list(Decoder(b"\x00").read_messages()) == [b""]
+    assert Decoder(b"").read_var_string() == ""
+    assert Decoder(b"\x05Hello").read_var_string() == "Hello"
+
+
+def test_protocol_errors():
+    # Truncated var-uint: continuation bit set on the last byte
+    with pytest.raises(RuntimeError):
+        Decoder(b"\x80").read_var_uint()
+    # Payload shorter than its length prefix
+    with pytest.raises(RuntimeError):
+        Decoder(b"\x05Hi").read_message()
+    # Empty sync message, and sync message without payload
+    with pytest.raises(RuntimeError):
+        handle_sync_message(b"", Doc())
+    with pytest.raises(RuntimeError):
+        handle_sync_message(bytes([0]), Doc())