Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sync protocol (v1) #124

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ classifiers = [
test = [
"pytest >=7.4.2,<8",
"y-py >=0.7.0a1,<0.8",
"anyio >=4.4.0,<5",
"trio >=0.25.1,<0.26",
"pydantic >=2.5.2,<3",
"mypy",
"coverage[toml] >=7",
Expand Down
3 changes: 3 additions & 0 deletions python/pycrdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from ._map import MapEvent as MapEvent
from ._pycrdt import Subscription as Subscription
from ._pycrdt import TransactionEvent as TransactionEvent
from ._sync import create_sync_message as create_sync_message
from ._sync import create_update_message as create_update_message
from ._sync import handle_sync_message as handle_sync_message
from ._text import Text as Text
from ._text import TextEvent as TextEvent
from ._transaction import ReadTransaction as ReadTransaction
Expand Down
122 changes: 122 additions & 0 deletions python/pycrdt/_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from enum import IntEnum

from pycrdt import Doc


class YMessageType(IntEnum):
    """Tag identifying the top-level kind of a message.

    `create_message` writes `SYNC` as the first byte of every message it builds.
    """

    SYNC = 0
    # Declared here but not used by the functions of this module.
    AWARENESS = 1


class YSyncMessageType(IntEnum):
    """Tag identifying the kind of a SYNC message.

    `create_message` writes it as the second byte of a message, and
    `handle_sync_message` dispatches on it.
    """

    # Carries a document state; `handle_sync_message` answers it with a SYNC_STEP2 message.
    SYNC_STEP1 = 0
    # Carries an update; `handle_sync_message` applies it to the document.
    SYNC_STEP2 = 1
    # Carries an update; `handle_sync_message` applies it to the document.
    SYNC_UPDATE = 2


def write_var_uint(num: int) -> bytes:
    """Encode an unsigned integer as a variable-length byte sequence.

    The value is emitted in groups of 7 bits, least-significant group first.
    Every byte except the last one has its high bit set to signal that more
    bytes follow.

    Args:
        num: The integer to encode.

    Returns:
        The encoded bytes.
    """
    encoded = []
    remaining = num
    while True:
        if remaining <= 127:
            # Final group: high bit clear marks the end of the number.
            encoded.append(remaining)
            return bytes(encoded)
        # Low 7 bits with the continuation flag set.
        encoded.append((remaining & 127) | 128)
        remaining >>= 7


def create_message(data: bytes, msg_type: int) -> bytes:
    """Build a SYNC message wrapping `data`.

    The result is laid out as: the `YMessageType.SYNC` byte, the `msg_type`
    byte, the length of `data` as a variable-length unsigned integer, then
    `data` itself.

    Args:
        data: The payload to wrap.
        msg_type: The sync message type byte (see `YSyncMessageType`).

    Returns:
        The complete message.
    """
    header = bytes([YMessageType.SYNC, msg_type])
    size_prefix = write_var_uint(len(data))
    return header + size_prefix + data


def create_sync_step1_message(data: bytes) -> bytes:
    """Wrap `data` in a SYNC message of type `SYNC_STEP1`.

    Args:
        data: The payload to wrap.

    Returns:
        The complete message, as built by `create_message`.
    """
    return create_message(data, msg_type=YSyncMessageType.SYNC_STEP1)


def create_sync_step2_message(data: bytes) -> bytes:
    """Wrap `data` in a SYNC message of type `SYNC_STEP2`.

    Args:
        data: The payload to wrap.

    Returns:
        The complete message, as built by `create_message`.
    """
    return create_message(data, msg_type=YSyncMessageType.SYNC_STEP2)


def create_update_message(data: bytes) -> bytes:
    """Wrap `data` in a SYNC message of type `SYNC_UPDATE`.

    Args:
        data: The payload to wrap.

    Returns:
        The complete message, as built by `create_message`.
    """
    return create_message(data, msg_type=YSyncMessageType.SYNC_UPDATE)


class Decoder:
    """Sequential reader for length-prefixed data held in a bytes object.

    The decoder keeps a cursor into `stream` and consumes variable-length
    unsigned integers and length-prefixed messages from it. Malformed input
    (running out of bytes in the middle of a number or of a message body) is
    reported with `RuntimeError("Y protocol error")`.
    """

    def __init__(self, stream: bytes):
        """Create a decoder positioned at the beginning of `stream`.

        Args:
            stream: The bytes to decode.
        """
        self.stream = stream
        # Number of bytes that have not been consumed yet.
        self.length = len(stream)
        # Index of the next byte to read.
        self.i0 = 0

    def read_var_uint(self) -> int:
        """Read a variable-length unsigned integer and advance the cursor.

        Each byte contributes its low 7 bits, least-significant group first;
        a byte lower than 128 terminates the number.

        Returns:
            The decoded integer.

        Raises:
            RuntimeError: If the stream is exhausted before the number is
                complete (including when nothing is left to read at all).
        """
        uint = 0
        shift = 0
        while True:
            # Checked on every byte so that a number truncated mid-way is
            # reported as a protocol error rather than as an IndexError.
            if self.length <= 0:
                raise RuntimeError("Y protocol error")
            byte = self.stream[self.i0]
            uint += (byte & 127) << shift
            shift += 7
            self.i0 += 1
            self.length -= 1
            if byte < 128:
                return uint

    def read_message(self) -> bytes | None:
        """Read one length-prefixed message and advance the cursor.

        Returns:
            The message payload (`b""` when the length prefix is zero), or
            `None` if there is nothing left to read.

        Raises:
            RuntimeError: If the length prefix is truncated, or if it
                announces more bytes than remain in the stream.
        """
        if self.length == 0:
            return None
        length = self.read_var_uint()
        if length == 0:
            return b""
        # A slice would silently return a short payload; reject it instead.
        if length > self.length:
            raise RuntimeError("Y protocol error")
        i1 = self.i0 + length
        message = self.stream[self.i0 : i1]
        self.i0 = i1
        self.length -= length
        return message

    def read_messages(self):
        """Yield every remaining length-prefixed message until the stream is exhausted.

        Raises:
            RuntimeError: If a message is malformed (see `read_message`).
        """
        while True:
            message = self.read_message()
            if message is None:
                return
            yield message

    def read_var_string(self) -> str:
        """Read one length-prefixed message and decode it as UTF-8.

        Returns:
            The decoded string, or `""` if there is nothing left to read.

        Raises:
            RuntimeError: If the message is malformed (see `read_message`).
        """
        message = self.read_message()
        if message is None:
            return ""
        return message.decode("utf-8")


def read_message(stream: bytes) -> bytes:
    """Read the first length-prefixed message from `stream`.

    Args:
        stream: Bytes starting with a variable-length length prefix followed
            by the message payload.

    Returns:
        The message payload.

    Raises:
        RuntimeError: If `stream` is empty, so that there is no message to read.
    """
    message = Decoder(stream).read_message()
    if message is None:
        # Raised explicitly rather than asserted: assertions are removed when
        # Python runs with -O, which would let None escape as the return value.
        raise RuntimeError("Y protocol error")
    return message


def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None:
    """Process a sync message against `ydoc` and produce the reply, if any.

    Args:
        message: The sync message, starting with its `YSyncMessageType` byte
            followed by a length-prefixed payload.
        ydoc: The document to read the update from or apply the update to.

    Returns:
        A `SYNC_STEP2` message when `message` is a `SYNC_STEP1`, otherwise `None`.
    """
    kind, payload = message[0], message[1:]

    if kind == YSyncMessageType.SYNC_STEP1:
        # The payload is the peer's state: answer with the update computed from it.
        peer_state = read_message(payload)
        return create_sync_step2_message(ydoc.get_update(peer_state))

    if kind == YSyncMessageType.SYNC_STEP2 or kind == YSyncMessageType.SYNC_UPDATE:
        received_update = read_message(payload)
        # Ignore empty updates
        if received_update != b"\x00\x00":
            ydoc.apply_update(received_update)

    return None


def create_sync_message(ydoc: Doc) -> bytes:
    """Build the `SYNC_STEP1` message carrying the state of `ydoc`.

    Args:
        ydoc: The document whose state is sent.

    Returns:
        The complete message.
    """
    return create_sync_step1_message(ydoc.get_state())
98 changes: 98 additions & 0 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import pytest
from anyio import TASK_STATUS_IGNORED, create_memory_object_stream, create_task_group, sleep
from anyio.abc import TaskStatus
from pycrdt import (
Array,
Doc,
create_sync_message,
create_update_message,
handle_sync_message,
)
from pycrdt._sync import Decoder, write_var_uint

# Apply the `anyio` marker to every test of this module.
pytestmark = pytest.mark.anyio


class ConnectedDoc:
    """Test helper pairing a `Doc` with an in-memory inbox and a list of peers.

    Changes observed on the document are sent to every connected peer as
    update messages, and messages received from peers are handled with
    `handle_sync_message`.
    """

    def __init__(self):
        self.doc = Doc()
        # Forward every change of the document to the connected peers.
        self.doc.observe(lambda event: self.send(event.update))
        self.connected_docs = []
        # Inbox of (message, sender) pairs, filled by `receive`.
        self.send_stream, self.receive_stream = create_memory_object_stream[bytes](
            max_buffer_size=1024
        )

    async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
        """Signal that the task has started, then process incoming messages."""
        async with create_task_group() as task_group:
            task_status.started()
            task_group.start_soon(self.process_received_messages)

    def connect(self, *connected_docs):
        """Register peers and send each of them this document's sync message."""
        self.connected_docs.extend(connected_docs)
        sync_message = create_sync_message(self.doc)
        for peer in connected_docs:
            peer.receive(sync_message, self)

    def receive(self, message: bytes, sender=None):
        """Queue `message`, together with its sender, for later processing."""
        inbox_item = (message, sender)
        self.send_stream.send_nowait(inbox_item)

    async def process_received_messages(self):
        """Handle queued messages, sending any reply back to the sender."""
        async for message, sender in self.receive_stream:
            # Drop the leading message type byte before handling.
            reply = handle_sync_message(message[1:], self.doc)
            if reply is None:
                continue
            sender.receive(reply, self)

    def send(self, message: bytes):
        """Send `message` to every connected peer as an update message."""
        for peer in self.connected_docs:
            peer.receive(create_update_message(message))


async def test_sync():
    """Check that changes propagate between directly and indirectly connected docs."""
    async with create_task_group() as task_group:
        first = ConnectedDoc()
        second = ConnectedDoc()

        await task_group.start(first.start)
        await task_group.start(second.start)

        # Connect the first two docs in both directions.
        first.connect(second)
        second.connect(first)

        first_array = first.doc.get("array", type=Array)
        first_array.append(0)
        # Leave time for the messages to be processed.
        await sleep(0.1)

        second_array = second.doc.get("array", type=Array)
        assert second_array[0] == 0

        # The third doc only connects to the first one,
        # but since the first and second docs are connected,
        # the third doc is indirectly connected to the second one.
        third = ConnectedDoc()
        await task_group.start(third.start)
        third.connect(first)
        first.connect(third)
        await sleep(0.1)
        third_array = third.doc.get("array", type=Array)
        assert third_array[0] == 0
        third_array.append(1)
        await sleep(0.1)
        assert first_array[1] == 1
        assert second_array[1] == 1

        # Stop the background tasks so that the task group can exit.
        task_group.cancel_scope.cancel()


def test_write_var_uint():
    """Check the encoding of the smallest value that needs two bytes."""
    encoded = write_var_uint(128)
    assert encoded == b"\x80\x01"


def test_decoder():
    """Check `Decoder` on empty input, an empty message and a short string."""
    # Reading a number from an empty stream is a protocol error.
    empty_decoder = Decoder(b"")
    with pytest.raises(RuntimeError) as raised:
        empty_decoder.read_var_uint()
    assert str(raised.value) == "Y protocol error"

    no_messages = list(Decoder(b"").read_messages())
    assert no_messages == []

    one_empty_message = list(Decoder(b"\x00").read_messages())
    assert one_empty_message == [b""]

    assert Decoder(b"").read_var_string() == ""
    assert Decoder(b"\x05Hello").read_var_string() == "Hello"
Loading