Skip to content

Commit b6857d7

Browse files
Add awareness (#171)
1 parent 33ab6ca commit b6857d7

File tree

7 files changed

+70
-9
lines changed

7 files changed

+70
-9
lines changed

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
run: pytest --color=yes -v tests
6767

6868
- name: Run code coverage
69-
if: ${{ (matrix.python-version == '3.12') && (matrix.os == 'ubuntu-latest') }}
69+
if: ${{ (matrix.python-version == '3.12') && (matrix.os == 'ubuntu') }}
7070
run: |
7171
coverage run -m pytest tests
7272
coverage report --show-missing --fail-under=100

python/pycrdt/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ._array import Array as Array
22
from ._array import ArrayEvent as ArrayEvent
3+
from ._awareness import Awareness as Awareness
34
from ._doc import Doc as Doc
45
from ._map import Map as Map
56
from ._map import MapEvent as MapEvent

python/pycrdt/_array.py

-3
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,6 @@ def __init__(self, array: Array):
382382
self.length = len(array)
383383
self.idx = 0
384384

385-
def __iter__(self) -> ArrayIterator:
386-
return self
387-
388385
def __next__(self) -> Any:
389386
if self.idx == self.length:
390387
raise StopIteration

python/pycrdt/_awareness.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import time
5+
from typing import Any
6+
7+
from ._doc import Doc
8+
from ._sync import Decoder, read_message
9+
10+
11+
class Awareness: # pragma: no cover
12+
def __init__(self, ydoc: Doc):
13+
self.client_id = ydoc.client_id
14+
self.meta: dict[int, dict[str, Any]] = {}
15+
self.states: dict[int, dict[str, Any]] = {}
16+
17+
def get_changes(self, message: bytes) -> dict[str, Any]:
18+
message = read_message(message)
19+
decoder = Decoder(message)
20+
timestamp = int(time.time() * 1000)
21+
added = []
22+
updated = []
23+
filtered_updated = []
24+
removed = []
25+
states = []
26+
length = decoder.read_var_uint()
27+
for _ in range(length):
28+
client_id = decoder.read_var_uint()
29+
clock = decoder.read_var_uint()
30+
state_str = decoder.read_var_string()
31+
state = None if not state_str else json.loads(state_str)
32+
if state is not None:
33+
states.append(state)
34+
client_meta = self.meta.get(client_id)
35+
prev_state = self.states.get(client_id)
36+
curr_clock = 0 if client_meta is None else client_meta["clock"]
37+
if curr_clock < clock or (
38+
curr_clock == clock and state is None and client_id in self.states
39+
):
40+
if state is None:
41+
if client_id == self.client_id and self.states.get(client_id) is not None:
42+
clock += 1
43+
else:
44+
if client_id in self.states:
45+
del self.states[client_id]
46+
else:
47+
self.states[client_id] = state
48+
self.meta[client_id] = {
49+
"clock": clock,
50+
"last_updated": timestamp,
51+
}
52+
if client_meta is None and state is not None:
53+
added.append(client_id)
54+
elif client_meta is not None and state is None:
55+
removed.append(client_id)
56+
elif state is not None:
57+
if state != prev_state:
58+
filtered_updated.append(client_id)
59+
updated.append(client_id)
60+
return {
61+
"added": added,
62+
"updated": updated,
63+
"filtered_updated": filtered_updated,
64+
"removed": removed,
65+
"states": states,
66+
}

python/pycrdt/_sync.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from enum import IntEnum
44
from typing import Iterator
55

6-
from pycrdt import Doc
6+
from ._doc import Doc
77

88

99
class YMessageType(IntEnum):

python/pycrdt/_text.py

-3
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,6 @@ def __init__(self, text: Text):
275275
self.length = len(text)
276276
self.idx = 0
277277

278-
def __iter__(self) -> TextIterator:
279-
return self
280-
281278
def __next__(self) -> str:
282279
if self.idx == self.length:
283280
raise StopIteration

tests/test_transaction.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pycrdt import Array, Doc, Map, Text
88

99
if sys.version_info < (3, 11):
10-
from exceptiongroup import ExceptionGroup
10+
from exceptiongroup import ExceptionGroup # pragma: no cover
1111

1212
pytestmark = pytest.mark.anyio
1313

0 commit comments

Comments
 (0)