@@ -98,6 +98,11 @@ async def get_partial_current_state_deltas(
9898 """
9999 prev_stream_id = int (prev_stream_id )
100100
101+ if limit <= 0 :
102+ raise ValueError (
103+ "Invalid `limit` passed to `get_partial_current_state_deltas"
104+ )
105+
101106 # check we're not going backwards
102107 assert prev_stream_id <= max_stream_id , (
103108 f"New stream id { max_stream_id } is smaller than prev stream id { prev_stream_id } "
@@ -114,46 +119,23 @@ async def get_partial_current_state_deltas(
114119 def get_current_state_deltas_txn (
115120 txn : LoggingTransaction ,
116121 ) -> Tuple [int , List [StateDelta ]]:
117- # First we calculate the max stream id that will give us less than
118- # N results.
119- # We limit the number of returned stream_id entries to ensure we
120- # don't select toooo many.
121122 sql = """
122- SELECT stream_id, count(*)
123+ SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
123124 FROM current_state_delta_stream
124- WHERE stream_id > ? AND stream_id <= ?
125- GROUP BY stream_id
125+ WHERE ? < stream_id AND stream_id <= ?
126126 ORDER BY stream_id ASC
127127 LIMIT ?
128128 """
129129 txn .execute (sql , (prev_stream_id , max_stream_id , limit ))
130+ rows = txn .fetchall ()
131+
132+ # In the case that we hit the given `limit` rather than fetching the
133+ # most recent rows, return the `stream_id` of the last row.
134+ #
135+ # With this, the caller knows from what stream_id to call this
136+ # function again with.
137+ clipped_stream_id = rows [- 1 ][0 ]
130138
131- total = 0
132-
133- for stream_id , count in txn :
134- total += count
135-
136- if total >= limit :
137- # We limit the number of returned entries to ensure we don't
138- # select toooo many.
139- logger .debug (
140- "Clipping current_state_delta_stream rows to stream_id %i" ,
141- stream_id ,
142- )
143- clipped_stream_id = stream_id
144- break
145- else :
146- # if there's no problem, we may as well go right up to the max_stream_id
147- clipped_stream_id = max_stream_id
148-
149- # Now actually get the deltas
150- sql = """
151- SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
152- FROM current_state_delta_stream
153- WHERE ? < stream_id AND stream_id <= ?
154- ORDER BY stream_id ASC
155- """
156- txn .execute (sql , (prev_stream_id , clipped_stream_id ))
157139 return clipped_stream_id , [
158140 StateDelta (
159141 stream_id = row [0 ],
@@ -163,7 +145,7 @@ def get_current_state_deltas_txn(
163145 event_id = row [4 ],
164146 prev_event_id = row [5 ],
165147 )
166- for row in txn . fetchall ()
148+ for row in rows
167149 ]
168150
169151 return await self .db_pool .runInteraction (
0 commit comments