Skip to content

Commit

Permalink
Add sync protocol (v1) (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Jun 25, 2024
1 parent fb07272 commit 2bfb5a5
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ classifiers = [
test = [
"pytest >=7.4.2,<8",
"y-py >=0.7.0a1,<0.8",
"anyio >=4.4.0,<5",
"trio >=0.25.1,<0.26",
"pydantic >=2.5.2,<3",
"mypy",
"coverage[toml] >=7",
Expand Down
3 changes: 3 additions & 0 deletions python/pycrdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from ._map import MapEvent as MapEvent
from ._pycrdt import Subscription as Subscription
from ._pycrdt import TransactionEvent as TransactionEvent
from ._sync import create_sync_message as create_sync_message
from ._sync import create_update_message as create_update_message
from ._sync import handle_sync_message as handle_sync_message
from ._text import Text as Text
from ._text import TextEvent as TextEvent
from ._transaction import ReadTransaction as ReadTransaction
Expand Down
122 changes: 122 additions & 0 deletions python/pycrdt/_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from enum import IntEnum

from pycrdt import Doc


class YMessageType(IntEnum):
SYNC = 0
AWARENESS = 1


class YSyncMessageType(IntEnum):
SYNC_STEP1 = 0
SYNC_STEP2 = 1
SYNC_UPDATE = 2


def write_var_uint(num: int) -> bytes:
res = []
while num > 127:
res.append(128 | (127 & num))
num >>= 7
res.append(num)
return bytes(res)


def create_message(data: bytes, msg_type: int) -> bytes:
return bytes([YMessageType.SYNC, msg_type]) + write_var_uint(len(data)) + data


def create_sync_step1_message(data: bytes) -> bytes:
return create_message(data, YSyncMessageType.SYNC_STEP1)


def create_sync_step2_message(data: bytes) -> bytes:
return create_message(data, YSyncMessageType.SYNC_STEP2)


def create_update_message(data: bytes) -> bytes:
return create_message(data, YSyncMessageType.SYNC_UPDATE)


class Decoder:
def __init__(self, stream: bytes):
self.stream = stream
self.length = len(stream)
self.i0 = 0

def read_var_uint(self) -> int:
if self.length <= 0:
raise RuntimeError("Y protocol error")
uint = 0
i = 0
while True:
byte = self.stream[self.i0]
uint += (byte & 127) << i
i += 7
self.i0 += 1
self.length -= 1
if byte < 128:
break
return uint

def read_message(self) -> bytes | None:
if self.length == 0:
return None
length = self.read_var_uint()
if length == 0:
return b""
i1 = self.i0 + length
message = self.stream[self.i0 : i1]
self.i0 = i1
self.length -= length
return message

def read_messages(self):
while True:
message = self.read_message()
if message is None:
return
yield message

def read_var_string(self):
message = self.read_message()
if message is None:
return ""
return message.decode("utf-8")


def read_message(stream: bytes) -> bytes:
message = Decoder(stream).read_message()
assert message is not None
return message


def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None:
message_type = message[0]
msg = message[1:]

if message_type == YSyncMessageType.SYNC_STEP1:
state = read_message(msg)
update = ydoc.get_update(state)
reply = create_sync_step2_message(update)
return reply

if message_type in (
YSyncMessageType.SYNC_STEP2,
YSyncMessageType.SYNC_UPDATE,
):
update = read_message(msg)
# Ignore empty updates
if update != b"\x00\x00":
ydoc.apply_update(update)

return None


def create_sync_message(ydoc: Doc) -> bytes:
state = ydoc.get_state()
message = create_sync_step1_message(state)
return message
98 changes: 98 additions & 0 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import pytest
from anyio import TASK_STATUS_IGNORED, create_memory_object_stream, create_task_group, sleep
from anyio.abc import TaskStatus
from pycrdt import (
Array,
Doc,
create_sync_message,
create_update_message,
handle_sync_message,
)
from pycrdt._sync import Decoder, write_var_uint

pytestmark = pytest.mark.anyio


class ConnectedDoc:
def __init__(self):
self.doc = Doc()
self.doc.observe(lambda event: self.send(event.update))
self.connected_docs = []
self.send_stream, self.receive_stream = create_memory_object_stream[bytes](
max_buffer_size=1024
)

async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
async with create_task_group() as tg:
task_status.started()
tg.start_soon(self.process_received_messages)

def connect(self, *connected_docs):
self.connected_docs += connected_docs
sync_message = create_sync_message(self.doc)
for connected_doc in connected_docs:
connected_doc.receive(sync_message, self)

def receive(self, message: bytes, sender=None):
self.send_stream.send_nowait((message, sender))

async def process_received_messages(self):
async for message, sender in self.receive_stream:
reply = handle_sync_message(message[1:], self.doc)
if reply is not None:
sender.receive(reply, self)

def send(self, message: bytes):
for doc in self.connected_docs:
doc.receive(create_update_message(message))


async def test_sync():
async with create_task_group() as tg:
doc0 = ConnectedDoc()
doc1 = ConnectedDoc()

await tg.start(doc0.start)
await tg.start(doc1.start)

doc0.connect(doc1)
doc1.connect(doc0)

array0 = doc0.doc.get("array", type=Array)
array0.append(0)
await sleep(0.1)

array1 = doc1.doc.get("array", type=Array)
assert array1[0] == 0

# doc2 only connects to doc0
# but since doc0 and doc1 are connected,
# doc2 is indirectly connected to doc1
doc2 = ConnectedDoc()
await tg.start(doc2.start)
doc2.connect(doc0)
doc0.connect(doc2)
await sleep(0.1)
array2 = doc2.doc.get("array", type=Array)
assert array2[0] == 0
array2.append(1)
await sleep(0.1)
assert array0[1] == 1
assert array1[1] == 1

tg.cancel_scope.cancel()


def test_write_var_uint():
assert write_var_uint(128) == b"\x80\x01"


def test_decoder():
with pytest.raises(RuntimeError) as exc_info:
Decoder(b"").read_var_uint()
assert str(exc_info.value) == "Y protocol error"

assert list(Decoder(b"").read_messages()) == []
assert list(Decoder(b"\x00").read_messages()) == [b""]
assert Decoder(b"").read_var_string() == ""
assert Decoder(b"\x05Hello").read_var_string() == "Hello"

0 comments on commit 2bfb5a5

Please sign in to comment.