2
2
3
3
import json
4
4
import time
5
- from typing import Any
5
+ from typing import Any , Callable , cast
6
+ from uuid import uuid4
6
7
7
8
from ._doc import Doc
8
- from ._sync import Decoder , read_message
9
+ from ._sync import Decoder , Encoder
9
10
10
11
11
- class Awareness : # pragma: no cover
12
+ class Awareness :
13
+ client_id : int
14
+ _meta : dict [int , dict [str , Any ]]
15
+ _states : dict [int , dict [str , Any ]]
16
+ _subscriptions : dict [str , Callable [[str , tuple [dict [str , Any ], Any ]], None ]]
17
+
12
18
def __init__ (self , ydoc : Doc ):
19
+ """
20
+ Args:
21
+ ydoc: The [Doc][pycrdt.Doc] to associate the awareness with.
22
+ """
13
23
self .client_id = ydoc .client_id
14
- self .meta : dict [int , dict [str , Any ]] = {}
15
- self .states : dict [int , dict [str , Any ]] = {}
24
+ self ._meta = {}
25
+ self ._states = {}
26
+ self ._subscriptions = {}
27
+ self .set_local_state ({})
28
+
29
+ @property
30
+ def meta (self ) -> dict [int , dict [str , Any ]]:
31
+ """The clients' metadata."""
32
+ return self ._meta
33
+
34
+ @property
35
+ def states (self ) -> dict [int , dict [str , Any ]]:
36
+ """The client states."""
37
+ return self ._states
38
+
39
+ def get_local_state (self ) -> dict [str , Any ] | None :
40
+ """
41
+ Returns:
42
+ The local state, if any.
43
+ """
44
+ return self ._states .get (self .client_id )
45
+
46
+ def set_local_state (self , state : dict [str , Any ] | None ) -> None :
47
+ """
48
+ Updates the local state and meta, and sends the changes to subscribers.
49
+
50
+ Args:
51
+ state: The new local state, if any.
52
+ """
53
+ client_id = self .client_id
54
+ curr_local_meta = self ._meta .get (client_id )
55
+ clock = 0 if curr_local_meta is None else curr_local_meta ["clock" ] + 1
56
+ prev_state = self ._states .get (client_id )
57
+ if state is None :
58
+ if client_id in self ._states :
59
+ del self ._states [client_id ]
60
+ else :
61
+ self ._states [client_id ] = state
62
+ timestamp = int (time .time () * 1000 )
63
+ self ._meta [client_id ] = {"clock" : clock , "lastUpdated" : timestamp }
64
+ added = []
65
+ updated = []
66
+ filtered_updated = []
67
+ removed = []
68
+ if state is None :
69
+ removed .append (client_id )
70
+ elif prev_state is None :
71
+ if state is not None :
72
+ added .append (client_id )
73
+ else :
74
+ updated .append (client_id )
75
+ if prev_state != state :
76
+ filtered_updated .append (client_id )
77
+ if added or filtered_updated or removed :
78
+ for callback in self ._subscriptions .values ():
79
+ callback (
80
+ "change" ,
81
+ ({"added" : added , "updated" : filtered_updated , "removed" : removed }, "local" ),
82
+ )
83
+ for callback in self ._subscriptions .values ():
84
+ callback ("update" , ({"added" : added , "updated" : updated , "removed" : removed }, "local" ))
16
85
17
- def get_changes (self , message : bytes ) -> dict [str , Any ]:
18
- message = read_message (message )
19
- decoder = Decoder (message )
86
+ def set_local_state_field (self , field : str , value : Any ) -> None :
87
+ """
88
+ Sets a local state field.
89
+
90
+ Args:
91
+ field: The field of the local state to set.
92
+ value: The value associated with the field.
93
+ """
94
+ state = self .get_local_state ()
95
+ if state is not None :
96
+ state [field ] = value
97
+ self .set_local_state (state )
98
+
99
+ def encode_awareness_update (self , client_ids : list [int ]) -> bytes :
100
+ """
101
+ Creates an encoded awareness update of the clients given by their IDs.
102
+
103
+ Args:
104
+ client_ids: The list of client IDs for which to create an update.
105
+
106
+ Returns:
107
+ The encoded awareness update.
108
+ """
109
+ encoder = Encoder ()
110
+ encoder .write_var_uint (len (client_ids ))
111
+ for client_id in client_ids :
112
+ state = self ._states .get (client_id )
113
+ clock = cast (int , self ._meta .get (client_id , {}).get ("clock" ))
114
+ encoder .write_var_uint (client_id )
115
+ encoder .write_var_uint (clock )
116
+ encoder .write_var_string (json .dumps (state , separators = ("," , ":" )))
117
+ return encoder .to_bytes ()
118
+
119
+ def apply_awareness_update (self , update : bytes , origin : Any ) -> None :
120
+ """
121
+ Applies the binary update and notifies subscribers with changes.
122
+
123
+ Args:
124
+ update: The binary update.
125
+ origin: The origin of the update.
126
+ """
127
+ decoder = Decoder (update )
20
128
timestamp = int (time .time () * 1000 )
21
129
added = []
22
130
updated = []
23
131
filtered_updated = []
24
132
removed = []
25
- states = []
26
133
length = decoder .read_var_uint ()
27
134
for _ in range (length ):
28
135
client_id = decoder .read_var_uint ()
29
136
clock = decoder .read_var_uint ()
30
137
state_str = decoder .read_var_string ()
31
138
state = None if not state_str else json .loads (state_str )
32
- if state is not None :
33
- states .append (state )
34
- client_meta = self .meta .get (client_id )
35
- prev_state = self .states .get (client_id )
139
+ client_meta = self ._meta .get (client_id )
140
+ prev_state = self ._states .get (client_id )
36
141
curr_clock = 0 if client_meta is None else client_meta ["clock" ]
37
142
if curr_clock < clock or (
38
- curr_clock == clock and state is None and client_id in self .states
143
+ curr_clock == clock and state is None and client_id in self ._states
39
144
):
40
145
if state is None :
41
- if client_id == self .client_id and self .states .get (client_id ) is not None :
146
+ # Never let a remote client remove this local state.
147
+ if client_id == self .client_id and self .get_local_state () is not None :
148
+ # Remote client removed the local state. Do not remove state.
149
+ # Broadcast a message indicating that this client still exists by increasing
150
+ # the clock.
42
151
clock += 1
43
152
else :
44
- if client_id in self .states :
45
- del self .states [client_id ]
153
+ if client_id in self ._states :
154
+ del self ._states [client_id ]
46
155
else :
47
- self .states [client_id ] = state
48
- self .meta [client_id ] = {
156
+ self ._states [client_id ] = state
157
+ self ._meta [client_id ] = {
49
158
"clock" : clock ,
50
- "last_updated " : timestamp ,
159
+ "lastUpdated " : timestamp ,
51
160
}
52
161
if client_meta is None and state is not None :
53
162
added .append (client_id )
@@ -57,10 +166,37 @@ def get_changes(self, message: bytes) -> dict[str, Any]:
57
166
if state != prev_state :
58
167
filtered_updated .append (client_id )
59
168
updated .append (client_id )
60
- return {
61
- "added" : added ,
62
- "updated" : updated ,
63
- "filtered_updated" : filtered_updated ,
64
- "removed" : removed ,
65
- "states" : states ,
66
- }
169
+ if added or filtered_updated or removed :
170
+ for callback in self ._subscriptions .values ():
171
+ callback (
172
+ "change" ,
173
+ ({"added" : added , "updated" : filtered_updated , "removed" : removed }, origin ),
174
+ )
175
+ if added or updated or removed :
176
+ for callback in self ._subscriptions .values ():
177
+ callback (
178
+ "update" , ({"added" : added , "updated" : updated , "removed" : removed }, origin )
179
+ )
180
+
181
+ def observe (self , callback : Callable [[str , tuple [dict [str , Any ], Any ]], None ]) -> str :
182
+ """
183
+ Registers the given callback to awareness changes.
184
+
185
+ Args:
186
+ callback: The callback to call with the awareness changes.
187
+
188
+ Returns:
189
+ The subscription ID that can be used to unobserve.
190
+ """
191
+ id = str (uuid4 ())
192
+ self ._subscriptions [id ] = callback
193
+ return id
194
+
195
+ def unobserve (self , id : str ) -> None :
196
+ """
197
+ Unregisters the given subscription ID from awareness changes.
198
+
199
+ Args:
200
+ id: The subscription ID to unregister.
201
+ """
202
+ del self ._subscriptions [id ]
0 commit comments