Skip to content

Commit fd74268

Browse files
brichetpre-commit-ci[bot]davidbrochart
authored
Add awareness features to handle server state (#170)
* Move the Awareness from pycrdt_websocket to pycrdt project, and add some features to it * Add tests on awareness * use google style docstring * Generate the message in test for clarity * Add docstring and tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove the unused logger * Remove typing from test * Apply suggestions from code review Co-authored-by: David Brochart <[email protected]> * Add missing docstring * Apply suggestions from code review Co-authored-by: David Brochart <[email protected]> * Remove the default user in the awareness * Remove totally the conept of user in the awareness * Add subscription id * update docstring according to review * Remove on_change callback * Observe both local and remote changes, and add a function to encode the changes * mypy * Mimic awareness.js * Check if state is set before deleting * Add create_awareness_message() and write_message() --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: David Brochart <[email protected]>
1 parent bbf4b1d commit fd74268

File tree

5 files changed

+574
-31
lines changed

5 files changed

+574
-31
lines changed

docs/api_reference.md

+4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
- BaseType
99
- Array
1010
- ArrayEvent
11+
- Awareness
1112
- Decoder
1213
- Doc
14+
- Encoder
1315
- Map
1416
- MapEvent
1517
- NewTransaction
@@ -24,11 +26,13 @@
2426
- UndoManager
2527
- YMessageType
2628
- YSyncMessageType
29+
- create_awareness_message
2730
- create_sync_message
2831
- create_update_message
2932
- handle_sync_message
3033
- get_state
3134
- get_update
3235
- merge_updates
3336
- read_message
37+
- write_message
3438
- write_var_uint

python/pycrdt/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
from ._pycrdt import Subscription as Subscription
1010
from ._pycrdt import TransactionEvent as TransactionEvent
1111
from ._sync import Decoder as Decoder
12+
from ._sync import Encoder as Encoder
1213
from ._sync import YMessageType as YMessageType
1314
from ._sync import YSyncMessageType as YSyncMessageType
15+
from ._sync import create_awareness_message as create_awareness_message
1416
from ._sync import create_sync_message as create_sync_message
1517
from ._sync import create_update_message as create_update_message
1618
from ._sync import handle_sync_message as handle_sync_message
1719
from ._sync import read_message as read_message
20+
from ._sync import write_message as write_message
1821
from ._sync import write_var_uint as write_var_uint
1922
from ._text import Text as Text
2023
from ._text import TextEvent as TextEvent

python/pycrdt/_awareness.py

+163-27
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,161 @@
22

33
import json
44
import time
5-
from typing import Any
5+
from typing import Any, Callable, cast
6+
from uuid import uuid4
67

78
from ._doc import Doc
8-
from ._sync import Decoder, read_message
9+
from ._sync import Decoder, Encoder
910

1011

11-
class Awareness: # pragma: no cover
12+
class Awareness:
13+
client_id: int
14+
_meta: dict[int, dict[str, Any]]
15+
_states: dict[int, dict[str, Any]]
16+
_subscriptions: dict[str, Callable[[str, tuple[dict[str, Any], Any]], None]]
17+
1218
def __init__(self, ydoc: Doc):
19+
"""
20+
Args:
21+
ydoc: The [Doc][pycrdt.Doc] to associate the awareness with.
22+
"""
1323
self.client_id = ydoc.client_id
14-
self.meta: dict[int, dict[str, Any]] = {}
15-
self.states: dict[int, dict[str, Any]] = {}
24+
self._meta = {}
25+
self._states = {}
26+
self._subscriptions = {}
27+
self.set_local_state({})
28+
29+
@property
30+
def meta(self) -> dict[int, dict[str, Any]]:
31+
"""The clients' metadata."""
32+
return self._meta
33+
34+
@property
35+
def states(self) -> dict[int, dict[str, Any]]:
36+
"""The client states."""
37+
return self._states
38+
39+
def get_local_state(self) -> dict[str, Any] | None:
40+
"""
41+
Returns:
42+
The local state, if any.
43+
"""
44+
return self._states.get(self.client_id)
45+
46+
def set_local_state(self, state: dict[str, Any] | None) -> None:
47+
"""
48+
Updates the local state and meta, and sends the changes to subscribers.
49+
50+
Args:
51+
state: The new local state, if any.
52+
"""
53+
client_id = self.client_id
54+
curr_local_meta = self._meta.get(client_id)
55+
clock = 0 if curr_local_meta is None else curr_local_meta["clock"] + 1
56+
prev_state = self._states.get(client_id)
57+
if state is None:
58+
if client_id in self._states:
59+
del self._states[client_id]
60+
else:
61+
self._states[client_id] = state
62+
timestamp = int(time.time() * 1000)
63+
self._meta[client_id] = {"clock": clock, "lastUpdated": timestamp}
64+
added = []
65+
updated = []
66+
filtered_updated = []
67+
removed = []
68+
if state is None:
69+
removed.append(client_id)
70+
elif prev_state is None:
71+
if state is not None:
72+
added.append(client_id)
73+
else:
74+
updated.append(client_id)
75+
if prev_state != state:
76+
filtered_updated.append(client_id)
77+
if added or filtered_updated or removed:
78+
for callback in self._subscriptions.values():
79+
callback(
80+
"change",
81+
({"added": added, "updated": filtered_updated, "removed": removed}, "local"),
82+
)
83+
for callback in self._subscriptions.values():
84+
callback("update", ({"added": added, "updated": updated, "removed": removed}, "local"))
1685

17-
def get_changes(self, message: bytes) -> dict[str, Any]:
18-
message = read_message(message)
19-
decoder = Decoder(message)
86+
def set_local_state_field(self, field: str, value: Any) -> None:
87+
"""
88+
Sets a local state field.
89+
90+
Args:
91+
field: The field of the local state to set.
92+
value: The value associated with the field.
93+
"""
94+
state = self.get_local_state()
95+
if state is not None:
96+
state[field] = value
97+
self.set_local_state(state)
98+
99+
def encode_awareness_update(self, client_ids: list[int]) -> bytes:
100+
"""
101+
Creates an encoded awareness update of the clients given by their IDs.
102+
103+
Args:
104+
client_ids: The list of client IDs for which to create an update.
105+
106+
Returns:
107+
The encoded awareness update.
108+
"""
109+
encoder = Encoder()
110+
encoder.write_var_uint(len(client_ids))
111+
for client_id in client_ids:
112+
state = self._states.get(client_id)
113+
clock = cast(int, self._meta.get(client_id, {}).get("clock"))
114+
encoder.write_var_uint(client_id)
115+
encoder.write_var_uint(clock)
116+
encoder.write_var_string(json.dumps(state, separators=(",", ":")))
117+
return encoder.to_bytes()
118+
119+
def apply_awareness_update(self, update: bytes, origin: Any) -> None:
120+
"""
121+
Applies the binary update and notifies subscribers with changes.
122+
123+
Args:
124+
update: The binary update.
125+
origin: The origin of the update.
126+
"""
127+
decoder = Decoder(update)
20128
timestamp = int(time.time() * 1000)
21129
added = []
22130
updated = []
23131
filtered_updated = []
24132
removed = []
25-
states = []
26133
length = decoder.read_var_uint()
27134
for _ in range(length):
28135
client_id = decoder.read_var_uint()
29136
clock = decoder.read_var_uint()
30137
state_str = decoder.read_var_string()
31138
state = None if not state_str else json.loads(state_str)
32-
if state is not None:
33-
states.append(state)
34-
client_meta = self.meta.get(client_id)
35-
prev_state = self.states.get(client_id)
139+
client_meta = self._meta.get(client_id)
140+
prev_state = self._states.get(client_id)
36141
curr_clock = 0 if client_meta is None else client_meta["clock"]
37142
if curr_clock < clock or (
38-
curr_clock == clock and state is None and client_id in self.states
143+
curr_clock == clock and state is None and client_id in self._states
39144
):
40145
if state is None:
41-
if client_id == self.client_id and self.states.get(client_id) is not None:
146+
# Never let a remote client remove this local state.
147+
if client_id == self.client_id and self.get_local_state() is not None:
148+
# Remote client removed the local state. Do not remove state.
149+
# Broadcast a message indicating that this client still exists by increasing
150+
# the clock.
42151
clock += 1
43152
else:
44-
if client_id in self.states:
45-
del self.states[client_id]
153+
if client_id in self._states:
154+
del self._states[client_id]
46155
else:
47-
self.states[client_id] = state
48-
self.meta[client_id] = {
156+
self._states[client_id] = state
157+
self._meta[client_id] = {
49158
"clock": clock,
50-
"last_updated": timestamp,
159+
"lastUpdated": timestamp,
51160
}
52161
if client_meta is None and state is not None:
53162
added.append(client_id)
@@ -57,10 +166,37 @@ def get_changes(self, message: bytes) -> dict[str, Any]:
57166
if state != prev_state:
58167
filtered_updated.append(client_id)
59168
updated.append(client_id)
60-
return {
61-
"added": added,
62-
"updated": updated,
63-
"filtered_updated": filtered_updated,
64-
"removed": removed,
65-
"states": states,
66-
}
169+
if added or filtered_updated or removed:
170+
for callback in self._subscriptions.values():
171+
callback(
172+
"change",
173+
({"added": added, "updated": filtered_updated, "removed": removed}, origin),
174+
)
175+
if added or updated or removed:
176+
for callback in self._subscriptions.values():
177+
callback(
178+
"update", ({"added": added, "updated": updated, "removed": removed}, origin)
179+
)
180+
181+
def observe(self, callback: Callable[[str, tuple[dict[str, Any], Any]], None]) -> str:
182+
"""
183+
Registers the given callback to awareness changes.
184+
185+
Args:
186+
callback: The callback to call with the awareness changes.
187+
188+
Returns:
189+
The subscription ID that can be used to unobserve.
190+
"""
191+
id = str(uuid4())
192+
self._subscriptions[id] = callback
193+
return id
194+
195+
def unobserve(self, id: str) -> None:
196+
"""
197+
Unregisters the given subscription ID from awareness changes.
198+
199+
Args:
200+
id: The subscription ID to unregister.
201+
"""
202+
del self._subscriptions[id]

python/pycrdt/_sync.py

+67-4
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,31 @@ def write_var_uint(num: int) -> bytes:
5454
return bytes(res)
5555

5656

57+
def create_awareness_message(data: bytes) -> bytes:
58+
"""
59+
Creates an [AWARENESS][pycrdt.YMessageType] message.
60+
61+
Args:
62+
data: The data to send in the message.
63+
64+
Returns:
65+
The [AWARENESS][pycrdt.YMessageType] message.
66+
"""
67+
return bytes([YMessageType.AWARENESS]) + write_message(data)
68+
69+
5770
def create_message(data: bytes, msg_type: int) -> bytes:
5871
"""
59-
Creates a binary Y message.
72+
Creates a SYNC message.
6073
6174
Args:
6275
data: The data to send in the message.
63-
msg_type: The [message type][pycrdt.YSyncMessageType].
76+
msg_type: The [SYNC message type][pycrdt.YSyncMessageType].
6477
6578
Returns:
66-
The binary Y message.
79+
The SYNC message.
6780
"""
68-
return bytes([YMessageType.SYNC, msg_type]) + write_var_uint(len(data)) + data
81+
return bytes([YMessageType.SYNC, msg_type]) + write_message(data)
6982

7083

7184
def create_sync_step1_message(data: bytes) -> bytes:
@@ -110,6 +123,43 @@ def create_update_message(data: bytes) -> bytes:
110123
return create_message(data, YSyncMessageType.SYNC_UPDATE)
111124

112125

126+
class Encoder:
127+
"""
128+
An encoder capable of writing messages to a binary stream.
129+
"""
130+
131+
stream: list[bytes]
132+
133+
def __init__(self) -> None:
134+
self.stream = []
135+
136+
def write_var_uint(self, num: int) -> None:
137+
"""
138+
Encodes a number.
139+
140+
Args:
141+
num: The number to encode.
142+
"""
143+
self.stream.append(write_var_uint(num))
144+
145+
def write_var_string(self, text: str) -> None:
146+
"""
147+
Encodes a string.
148+
149+
Args:
150+
text: The string to encode.
151+
"""
152+
self.stream.append(write_var_uint(len(text)))
153+
self.stream.append(text.encode())
154+
155+
def to_bytes(self) -> bytes:
156+
"""
157+
Returns:
158+
The binary stream.
159+
"""
160+
return b"".join(self.stream)
161+
162+
113163
class Decoder:
114164
"""
115165
A decoder capable of reading messages from a byte stream.
@@ -205,6 +255,19 @@ def read_message(stream: bytes) -> bytes:
205255
return message
206256

207257

258+
def write_message(stream: bytes) -> bytes:
259+
"""
260+
Writes a stream in a message.
261+
262+
Args:
263+
stream: The byte stream to write in a message.
264+
265+
Returns:
266+
The message containing the stream.
267+
"""
268+
return write_var_uint(len(stream)) + stream
269+
270+
208271
def handle_sync_message(message: bytes, ydoc: Doc) -> bytes | None:
209272
"""
210273
Processes a [synchronization message][pycrdt.YSyncMessageType] on a document.

0 commit comments

Comments
 (0)