Skip to content

Commit 31975b3

Browse files
committed
Add sync protocol (v1)
1 parent 228ffb9 commit 31975b3

File tree

4 files changed

+210
-0
lines changed

4 files changed

+210
-0
lines changed

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ classifiers = [
3333
test = [
3434
"pytest >=7.4.2,<8",
3535
"y-py >=0.7.0a1,<0.8",
36+
"anyio >=4.4.0,<5",
37+
"trio >=0.25.1,<0.26",
3638
"pydantic >=2.5.2,<3",
3739
"mypy",
3840
"coverage[toml] >=7",

python/pycrdt/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from ._map import MapEvent as MapEvent
66
from ._pycrdt import Subscription as Subscription
77
from ._pycrdt import TransactionEvent as TransactionEvent
8+
from ._sync import create_sync_message as create_sync_message
9+
from ._sync import create_update_message as create_update_message
10+
from ._sync import handle_sync_message as handle_sync_message
811
from ._text import Text as Text
912
from ._text import TextEvent as TextEvent
1013
from ._transaction import ReadTransaction as ReadTransaction

python/pycrdt/_sync.py

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from __future__ import annotations
2+
3+
from enum import IntEnum
4+
5+
from pycrdt import Doc
6+
7+
8+
class YMessageType(IntEnum):
    """Top-level message kind, written as the first byte of a message."""

    # Document synchronization messages (the only kind built by this module).
    SYNC = 0
    # Awareness messages; declared here but not produced by this module.
    AWARENESS = 1
11+
12+
13+
class YSyncMessageType(IntEnum):
    """Kind of a sync message, written as the byte following ``YMessageType.SYNC``."""

    # Carries the sender's document state; answered with a SYNC_STEP2 message.
    SYNC_STEP1 = 0
    # Carries the update computed from the state received in SYNC_STEP1.
    SYNC_STEP2 = 1
    # Carries a document update to apply.
    SYNC_UPDATE = 2
17+
18+
19+
def write_var_uint(num: int) -> bytes:
    """Encode an unsigned integer as a variable-length byte string.

    The value is split into 7-bit groups, least-significant group first. Every byte
    except the last one has its high bit (0x80) set to signal that more bytes follow.

    Args:
        num: The integer to encode.

    Returns:
        The encoded bytes.
    """
    chunks = []
    remaining = num
    while remaining >= 0x80:
        # Low 7 bits of the value, flagged as "continued".
        chunks.append((remaining & 0x7F) | 0x80)
        remaining >>= 7
    # Final group: high bit clear, which terminates the sequence.
    chunks.append(remaining)
    return bytes(chunks)
26+
27+
28+
def create_message(data: bytes, msg_type: int) -> bytes:
    """Wrap a payload into a sync message.

    Args:
        data: The payload to send.
        msg_type: The sync message type byte placed after ``YMessageType.SYNC``.

    Returns:
        Two header bytes (``YMessageType.SYNC`` and ``msg_type``), followed by the
        payload length encoded with ``write_var_uint``, followed by the payload.
    """
    header = bytes([YMessageType.SYNC, msg_type])
    size = write_var_uint(len(data))
    return header + size + data
30+
31+
32+
def create_sync_step1_message(data: bytes) -> bytes:
    """Build a sync message of type ``SYNC_STEP1`` carrying ``data`` as its payload."""
    step = YSyncMessageType.SYNC_STEP1
    return create_message(data, step)
34+
35+
36+
def create_sync_step2_message(data: bytes) -> bytes:
    """Build a sync message of type ``SYNC_STEP2`` carrying ``data`` as its payload."""
    step = YSyncMessageType.SYNC_STEP2
    return create_message(data, step)
38+
39+
40+
def create_update_message(data: bytes) -> bytes:
    """Build a sync message of type ``SYNC_UPDATE`` carrying ``data`` as its payload."""
    step = YSyncMessageType.SYNC_UPDATE
    return create_message(data, step)
42+
43+
44+
class Decoder:
    """Cursor over a byte string holding variable-length encoded values.

    The ``read_*`` methods consume bytes from the front of the stream and keep the
    following attributes up to date:

    - ``stream``: the bytes being decoded.
    - ``i0``: index in ``stream`` of the next byte to read.
    - ``length``: number of bytes that have not been consumed yet.
    """

    def __init__(self, stream: bytes):
        """Start decoding ``stream`` from its first byte."""
        self.stream = stream
        self.length = len(stream)
        self.i0 = 0

    def read_var_uint(self) -> int:
        """Read one variable-length unsigned integer.

        Each byte contributes its low 7 bits, least-significant group first; a byte
        below 128 (high bit clear) ends the integer.

        Returns:
            The decoded integer.

        Raises:
            RuntimeError: If the stream ends before the integer is complete (this
                includes an empty stream).
        """
        uint = 0
        shift = 0
        while True:
            # Checked on every byte so that a truncated integer is reported as a
            # protocol error rather than an IndexError.
            if self.length <= 0:
                raise RuntimeError("Y protocol error")
            byte = self.stream[self.i0]
            uint += (byte & 127) << shift
            shift += 7
            self.i0 += 1
            self.length -= 1
            if byte < 128:
                break
        return uint

    def read_message(self) -> bytes | None:
        """Read one length-prefixed message.

        Returns:
            The message payload (possibly empty), or ``None`` if the stream has been
            fully consumed.

        Raises:
            RuntimeError: If the length prefix is truncated, or announces more bytes
                than the stream still holds.
        """
        if self.length == 0:
            return None
        length = self.read_var_uint()
        if length == 0:
            return b""
        if length > self.length:
            # Slicing would silently return a short payload and leave a negative
            # remaining length behind.
            raise RuntimeError("Y protocol error")
        i1 = self.i0 + length
        message = self.stream[self.i0 : i1]
        self.i0 = i1
        self.length -= length
        return message

    def read_messages(self):
        """Yield every remaining length-prefixed message until the stream is exhausted."""
        while True:
            message = self.read_message()
            if message is None:
                return
            yield message

    def read_var_string(self) -> str:
        """Read one length-prefixed message and decode it as UTF-8.

        Returns:
            The decoded string, or ``""`` if the stream has been fully consumed.
        """
        message = self.read_message()
        if message is None:
            return ""
        return message.decode("utf-8")
89+
90+
91+
def read_message(stream: bytes) -> bytes:
    """Read the first length-prefixed message of ``stream``.

    Args:
        stream: Bytes starting with a variable-length size followed by the payload.

    Returns:
        The payload of the first message.

    Raises:
        RuntimeError: If ``stream`` holds no message.
    """
    message = Decoder(stream).read_message()
    if message is None:
        # Raised explicitly rather than asserted: assertions are removed when Python
        # runs with -O, which would let None escape despite the declared return type.
        raise RuntimeError("Y protocol error")
    return message
95+
96+
97+
def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None:
    """Process a sync message for ``ydoc`` and build the reply, if any.

    Args:
        message: The sync message, starting at its ``YSyncMessageType`` byte (i.e.
            without the leading ``YMessageType`` byte).
        ydoc: The document that the message applies to.

    Returns:
        A ``SYNC_STEP2`` message in response to a ``SYNC_STEP1`` message, otherwise
        ``None``.
    """
    kind, payload = message[0], message[1:]

    if kind == YSyncMessageType.SYNC_STEP1:
        # The payload is the sender's state: answer with the update computed from it.
        remote_state = read_message(payload)
        return create_sync_step2_message(ydoc.get_update(remote_state))

    applies_update = kind in (
        YSyncMessageType.SYNC_STEP2,
        YSyncMessageType.SYNC_UPDATE,
    )
    if applies_update:
        update = read_message(payload)
        # Ignore empty updates
        if update != b"\x00\x00":
            ydoc.apply_update(update)

    return None
117+
118+
119+
def create_sync_message(ydoc: Doc) -> bytes:
    """Build the ``SYNC_STEP1`` message carrying the current state of ``ydoc``."""
    return create_sync_step1_message(ydoc.get_state())

tests/test_sync.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import pytest
2+
from anyio import TASK_STATUS_IGNORED, create_memory_object_stream, create_task_group, sleep
3+
from anyio.abc import TaskStatus
4+
from pycrdt import (
5+
Array,
6+
Doc,
7+
create_sync_message,
8+
create_update_message,
9+
handle_sync_message,
10+
)
11+
from pycrdt._sync import Decoder, write_var_uint
12+
13+
pytestmark = pytest.mark.anyio
14+
15+
16+
class ConnectedDoc:
    """Test helper pairing a ``Doc`` with an in-memory inbox of sync messages."""

    def __init__(self):
        self.doc = Doc()
        # Relay the update carried by each observed event to the connected docs.
        self.doc.observe(lambda event: self.send(event.update))
        self.connected_docs = []
        # NOTE(review): the stream is parametrized with ``bytes`` but ``receive``
        # enqueues ``(message, sender)`` tuples.
        streams = create_memory_object_stream[bytes](max_buffer_size=1024)
        self.send_stream, self.receive_stream = streams

    async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
        """Run the inbox processing loop in a task group."""
        async with create_task_group() as tg:
            tg.start_soon(self.process_received_messages, task_status)

    def connect(self, *connected_docs):
        """Register ``connected_docs`` as peers and send each one this doc's sync message."""
        self.connected_docs.extend(connected_docs)
        handshake = create_sync_message(self.doc)
        for peer in connected_docs:
            peer.receive(handshake, self)

    def receive(self, message: bytes, sender=None):
        """Enqueue ``message`` (and who sent it) without blocking."""
        self.send_stream.send_nowait((message, sender))

    async def process_received_messages(self, task_status):
        """Handle queued messages in order, sending any reply back to the sender."""
        task_status.started()
        async for payload, origin in self.receive_stream:
            # Drop the leading message type byte before handing over the sync message.
            response = handle_sync_message(payload[1:], self.doc)
            if response is None:
                continue
            origin.receive(response, self)

    def send(self, message: bytes):
        """Deliver ``message`` to every connected doc, wrapped as an update message."""
        for peer in self.connected_docs:
            peer.receive(create_update_message(message))
48+
49+
50+
async def test_sync():
    """A change made in one connected doc shows up in the other one."""
    async with create_task_group() as tg:
        local, remote = ConnectedDoc(), ConnectedDoc()

        for peer in (local, remote):
            await tg.start(peer.start)

        local.connect(remote)
        remote.connect(local)

        local.doc.get("array", type=Array).append(0)

        # Leave time for the messages to be exchanged and processed.
        await sleep(0.1)
        replica = remote.doc.get("array", type=Array)
        assert replica[0] == 0

        tg.cancel_scope.cancel()
69+
70+
71+
def test_write_var_uint():
    """128 needs two bytes: a continued zero group, then the value 1."""
    encoded = write_var_uint(128)
    assert encoded == b"\x80\x01"
73+
74+
75+
def test_decoder():
    """Check the decoder on an empty stream and on single-message streams."""
    empty = b""

    # Reading an integer from an empty stream is a protocol error.
    with pytest.raises(RuntimeError) as exc_info:
        Decoder(empty).read_var_uint()
    assert str(exc_info.value) == "Y protocol error"

    no_messages = list(Decoder(empty).read_messages())
    assert no_messages == []

    one_empty_message = list(Decoder(b"\x00").read_messages())
    assert one_empty_message == [b""]

    assert Decoder(empty).read_var_string() == ""
    assert Decoder(b"\x05Hello").read_var_string() == "Hello"

0 commit comments

Comments
 (0)