-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
228ffb9
commit 2bf447f
Showing
4 changed files
with
225 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from __future__ import annotations | ||
|
||
from enum import IntEnum | ||
|
||
from pycrdt import Doc | ||
|
||
|
||
class YMessageType(IntEnum): | ||
SYNC = 0 | ||
AWARENESS = 1 | ||
|
||
|
||
class YSyncMessageType(IntEnum): | ||
SYNC_STEP1 = 0 | ||
SYNC_STEP2 = 1 | ||
SYNC_UPDATE = 2 | ||
|
||
|
||
def write_var_uint(num: int) -> bytes: | ||
res = [] | ||
while num > 127: | ||
res.append(128 | (127 & num)) | ||
num >>= 7 | ||
res.append(num) | ||
return bytes(res) | ||
|
||
|
||
def create_message(data: bytes, msg_type: int) -> bytes: | ||
return bytes([YMessageType.SYNC, msg_type]) + write_var_uint(len(data)) + data | ||
|
||
|
||
def create_sync_step1_message(data: bytes) -> bytes: | ||
return create_message(data, YSyncMessageType.SYNC_STEP1) | ||
|
||
|
||
def create_sync_step2_message(data: bytes) -> bytes: | ||
return create_message(data, YSyncMessageType.SYNC_STEP2) | ||
|
||
|
||
def create_update_message(data: bytes) -> bytes: | ||
return create_message(data, YSyncMessageType.SYNC_UPDATE) | ||
|
||
|
||
class Decoder: | ||
def __init__(self, stream: bytes): | ||
self.stream = stream | ||
self.length = len(stream) | ||
self.i0 = 0 | ||
|
||
def read_var_uint(self) -> int: | ||
if self.length <= 0: | ||
raise RuntimeError("Y protocol error") | ||
uint = 0 | ||
i = 0 | ||
while True: | ||
byte = self.stream[self.i0] | ||
uint += (byte & 127) << i | ||
i += 7 | ||
self.i0 += 1 | ||
self.length -= 1 | ||
if byte < 128: | ||
break | ||
return uint | ||
|
||
def read_message(self) -> bytes | None: | ||
if self.length == 0: | ||
return None | ||
length = self.read_var_uint() | ||
if length == 0: | ||
return b"" | ||
i1 = self.i0 + length | ||
message = self.stream[self.i0 : i1] | ||
self.i0 = i1 | ||
self.length -= length | ||
return message | ||
|
||
def read_messages(self): | ||
while True: | ||
message = self.read_message() | ||
if message is None: | ||
return | ||
yield message | ||
|
||
def read_var_string(self): | ||
message = self.read_message() | ||
if message is None: | ||
return "" | ||
return message.decode("utf-8") | ||
|
||
|
||
def read_message(stream: bytes) -> bytes: | ||
message = Decoder(stream).read_message() | ||
assert message is not None | ||
return message | ||
|
||
|
||
def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None: | ||
message_type = message[0] | ||
msg = message[1:] | ||
|
||
if message_type == YSyncMessageType.SYNC_STEP1: | ||
state = read_message(msg) | ||
update = ydoc.get_update(state) | ||
reply = create_sync_step2_message(update) | ||
return reply | ||
|
||
if message_type in ( | ||
YSyncMessageType.SYNC_STEP2, | ||
YSyncMessageType.SYNC_UPDATE, | ||
): | ||
update = read_message(msg) | ||
# Ignore empty updates | ||
if update != b"\x00\x00": | ||
ydoc.apply_update(update) | ||
|
||
return None | ||
|
||
|
||
def create_sync_message(ydoc: Doc) -> bytes: | ||
state = ydoc.get_state() | ||
message = create_sync_step1_message(state) | ||
return message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import pytest | ||
from anyio import TASK_STATUS_IGNORED, create_memory_object_stream, create_task_group, sleep | ||
from anyio.abc import TaskStatus | ||
from pycrdt import ( | ||
Array, | ||
Doc, | ||
create_sync_message, | ||
create_update_message, | ||
handle_sync_message, | ||
) | ||
from pycrdt._sync import Decoder, write_var_uint | ||
|
||
pytestmark = pytest.mark.anyio | ||
|
||
|
||
class ConnectedDoc: | ||
def __init__(self): | ||
self.doc = Doc() | ||
self.doc.observe(lambda event: self.send(event.update)) | ||
self.connected_docs = [] | ||
self.send_stream, self.receive_stream = create_memory_object_stream[bytes]( | ||
max_buffer_size=1024 | ||
) | ||
|
||
async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): | ||
async with create_task_group() as tg: | ||
task_status.started() | ||
tg.start_soon(self.process_received_messages) | ||
|
||
def connect(self, *connected_docs): | ||
self.connected_docs += connected_docs | ||
sync_message = create_sync_message(self.doc) | ||
for connected_doc in connected_docs: | ||
connected_doc.receive(sync_message, self) | ||
|
||
def receive(self, message: bytes, sender=None): | ||
self.send_stream.send_nowait((message, sender)) | ||
|
||
async def process_received_messages(self): | ||
async for message, sender in self.receive_stream: | ||
reply = handle_sync_message(message[1:], self.doc) | ||
if reply is not None: | ||
sender.receive(reply, self) | ||
|
||
def send(self, message: bytes): | ||
for doc in self.connected_docs: | ||
doc.receive(create_update_message(message)) | ||
|
||
|
||
async def test_sync(): | ||
async with create_task_group() as tg: | ||
doc0 = ConnectedDoc() | ||
doc1 = ConnectedDoc() | ||
|
||
await tg.start(doc0.start) | ||
await tg.start(doc1.start) | ||
|
||
doc0.connect(doc1) | ||
doc1.connect(doc0) | ||
|
||
array0 = doc0.doc.get("array", type=Array) | ||
array0.append(0) | ||
await sleep(0.1) | ||
|
||
array1 = doc1.doc.get("array", type=Array) | ||
assert array1[0] == 0 | ||
|
||
# doc2 only connects to doc0 | ||
# but since doc0 and doc1 are connected, | ||
# doc2 is indirectly connected to doc1 | ||
doc2 = ConnectedDoc() | ||
await tg.start(doc2.start) | ||
doc2.connect(doc0) | ||
doc0.connect(doc2) | ||
await sleep(0.1) | ||
array2 = doc2.doc.get("array", type=Array) | ||
assert array2[0] == 0 | ||
array2.append(1) | ||
await sleep(0.1) | ||
assert array0[1] == 1 | ||
assert array1[1] == 1 | ||
|
||
tg.cancel_scope.cancel() | ||
|
||
|
||
def test_write_var_uint(): | ||
assert write_var_uint(128) == b"\x80\x01" | ||
|
||
|
||
def test_decoder(): | ||
with pytest.raises(RuntimeError) as exc_info: | ||
Decoder(b"").read_var_uint() | ||
assert str(exc_info.value) == "Y protocol error" | ||
|
||
assert list(Decoder(b"").read_messages()) == [] | ||
assert list(Decoder(b"\x00").read_messages()) == [b""] | ||
assert Decoder(b"").read_var_string() == "" | ||
assert Decoder(b"\x05Hello").read_var_string() == "Hello" |