20
20
from contextlib import contextmanager
21
21
from types import TracebackType
22
22
from typing import (
23
+ TYPE_CHECKING ,
23
24
AsyncContextManager ,
24
25
ContextManager ,
25
26
Dict ,
49
50
from synapse .storage .types import Cursor
50
51
from synapse .storage .util .sequence import PostgresSequenceGenerator
51
52
53
+ if TYPE_CHECKING :
54
+ from synapse .notifier import ReplicationNotifier
55
+
52
56
logger = logging .getLogger (__name__ )
53
57
54
58
@@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
182
186
def __init__ (
183
187
self ,
184
188
db_conn : LoggingDatabaseConnection ,
189
+ notifier : "ReplicationNotifier" ,
185
190
table : str ,
186
191
column : str ,
187
192
extra_tables : Iterable [Tuple [str , str ]] = (),
@@ -205,6 +210,8 @@ def __init__(
205
210
# The key and values are the same, but we never look at the values.
206
211
self ._unfinished_ids : OrderedDict [int , int ] = OrderedDict ()
207
212
213
+ self ._notifier = notifier
214
+
208
215
def advance (self , instance_name : str , new_id : int ) -> None :
209
216
# Advance should never be called on a writer instance, only over replication
210
217
if self ._is_writer :
@@ -227,6 +234,8 @@ def manager() -> Generator[int, None, None]:
227
234
with self ._lock :
228
235
self ._unfinished_ids .pop (next_id )
229
236
237
+ self ._notifier .notify_replication ()
238
+
230
239
return _AsyncCtxManagerWrapper (manager ())
231
240
232
241
def get_next_mult (self , n : int ) -> AsyncContextManager [Sequence [int ]]:
@@ -250,6 +259,8 @@ def manager() -> Generator[Sequence[int], None, None]:
250
259
for next_id in next_ids :
251
260
self ._unfinished_ids .pop (next_id )
252
261
262
+ self ._notifier .notify_replication ()
263
+
253
264
return _AsyncCtxManagerWrapper (manager ())
254
265
255
266
def get_current_token (self ) -> int :
@@ -296,6 +307,7 @@ def __init__(
296
307
self ,
297
308
db_conn : LoggingDatabaseConnection ,
298
309
db : DatabasePool ,
310
+ notifier : "ReplicationNotifier" ,
299
311
stream_name : str ,
300
312
instance_name : str ,
301
313
tables : List [Tuple [str , str , str ]],
@@ -304,6 +316,7 @@ def __init__(
304
316
positive : bool = True ,
305
317
) -> None :
306
318
self ._db = db
319
+ self ._notifier = notifier
307
320
self ._stream_name = stream_name
308
321
self ._instance_name = instance_name
309
322
self ._positive = positive
@@ -535,7 +548,9 @@ def get_next(self) -> AsyncContextManager[int]:
535
548
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
536
549
# controls the return type. If `None` or omitted, the context manager yields
537
550
# a single integer stream_id; otherwise it yields a list of stream_ids.
538
- return cast (AsyncContextManager [int ], _MultiWriterCtxManager (self ))
551
+ return cast (
552
+ AsyncContextManager [int ], _MultiWriterCtxManager (self , self ._notifier )
553
+ )
539
554
540
555
def get_next_mult (self , n : int ) -> AsyncContextManager [List [int ]]:
541
556
# If we have a list of instances that are allowed to write to this
@@ -544,7 +559,10 @@ def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
544
559
raise Exception ("Tried to allocate stream ID on non-writer" )
545
560
546
561
# Cast safety: see get_next.
547
- return cast (AsyncContextManager [List [int ]], _MultiWriterCtxManager (self , n ))
562
+ return cast (
563
+ AsyncContextManager [List [int ]],
564
+ _MultiWriterCtxManager (self , self ._notifier , n ),
565
+ )
548
566
549
567
def get_next_txn (self , txn : LoggingTransaction ) -> int :
550
568
"""
@@ -563,6 +581,7 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:
563
581
564
582
txn .call_after (self ._mark_id_as_finished , next_id )
565
583
txn .call_on_exception (self ._mark_id_as_finished , next_id )
584
+ txn .call_after (self ._notifier .notify_replication )
566
585
567
586
# Update the `stream_positions` table with newly updated stream
568
587
# ID (unless self._writers is not set in which case we don't
@@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
787
806
"""Async context manager returned by MultiWriterIdGenerator"""
788
807
789
808
id_gen : MultiWriterIdGenerator
809
+ notifier : "ReplicationNotifier"
790
810
multiple_ids : Optional [int ] = None
791
811
stream_ids : List [int ] = attr .Factory (list )
792
812
@@ -814,6 +834,8 @@ async def __aexit__(
814
834
for i in self .stream_ids :
815
835
self .id_gen ._mark_id_as_finished (i )
816
836
837
+ self .notifier .notify_replication ()
838
+
817
839
if exc_type is not None :
818
840
return False
819
841
0 commit comments