Skip to content

Commit

Permalink
Add cursor at current function
Browse files Browse the repository at this point in the history
  • Loading branch information
another-rex committed Aug 27, 2024
1 parent 4afea58 commit 94d0515
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
7 changes: 3 additions & 4 deletions gcp/api/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ class QueryCursor:
Attributes:
query_number: This cursor is specifically for the Nth ndb datastore
query in the current query request.
query in the current query request. (Starts from 1)
ndb_cursor: Get the internal ndb_cursor. This could be None.
ended: Whether this cursor is for a query that has finished returning data.
"""

_ndb_cursor: ndb.Cursor | None = None
_cursor_state: _QueryCursorState = _QueryCursorState.ENDED
# The first query is numbered 1. This is because the query counter is
# incremented **before** the query and the query number being used.
query_number: int = 1

@classmethod
Expand Down Expand Up @@ -142,7 +144,4 @@ def url_safe_encode(self) -> str | None:
# a token in the response
return None

if self.query_number == 0:
return cursor_part

return str(self.query_number) + _METADATA_SEPARATOR + cursor_part
6 changes: 3 additions & 3 deletions gcp/api/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,22 +775,22 @@ def test_query_pagination_no_ecosystem(self):
'package': {
'name': 'django',
},
'version': '5.0',
'version': '5.0.1',
}),
timeout=_TIMEOUT)

result = response.json()
vulns_first = set(v['id'] for v in result['vulns'])
self.assertIn('next_page_token', result)
self.assertTrue(str.startswith(result['next_page_token'], '1:'))
self.assertTrue(str.startswith(result['next_page_token'], '2:'))

response = requests.post(
_api() + _BASE_QUERY,
data=json.dumps({
'package': {
'name': 'django',
},
'version': '5.0',
'version': '5.0.1',
'page_token': result['next_page_token'],
}),
timeout=_TIMEOUT)
Expand Down
22 changes: 16 additions & 6 deletions gcp/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,16 @@ def should_skip_query(self):
return (self.query_counter < self.input_cursor.query_number or
not self.output_cursor.ended)

def cursor_at_current(self) -> ndb.Cursor | None:
"""
Return the cursor if the stored cursor is for the current query.
"""
if self.input_cursor.query_number == self.query_counter:
return self.input_cursor.ndb_cursor
else:
return None


def save_cursor_at_page_break(self, it: ndb.QueryIterator):
"""
Saves the cursor at the current page break position
Expand Down Expand Up @@ -853,7 +863,7 @@ def query_by_commit(context: QueryContext,

bug_ids = []
it: ndb.QueryIterator = query.iter(
keys_only=True, start_cursor=context.input_cursor.ndb_cursor)
keys_only=True, start_cursor=context.cursor_at_current())

while (yield it.has_next_async()):
if context.should_break_page(len(bug_ids)):
Expand Down Expand Up @@ -1025,7 +1035,7 @@ def _query_by_semver(context: QueryContext, query: ndb.Query,
return []

it: ndb.QueryIterator = query.iter(
start_cursor=context.input_cursor.ndb_cursor)
start_cursor=context.cursor_at_current())

while (yield it.has_next_async()):
if context.should_break_page(len(results)):
Expand Down Expand Up @@ -1107,7 +1117,7 @@ def query_by_generic_helper(context: QueryContext, base_query: ndb.Query,
return []

it: ndb.QueryIterator = query.iter(
start_cursor=context.input_cursor.ndb_cursor)
start_cursor=context.cursor_at_current())

while (yield it.has_next_async()):
if context.should_break_page(len(results)):
Expand Down Expand Up @@ -1235,7 +1245,7 @@ def _query_by_comparing_versions(context: QueryContext, query: ndb.Query,
context: QueryContext for the current query.
query: A partially completed ndb.Query object which only needs
version filters to be added before query is performed.
ecosystem: Optional ecosystem of the package to query.
ecosystem: Required ecosystem of the package to query.
version: The version str to query by.
Returns:
Expand All @@ -1248,7 +1258,7 @@ def _query_by_comparing_versions(context: QueryContext, query: ndb.Query,
return []

it: ndb.QueryIterator = query.iter(
start_cursor=context.input_cursor.ndb_cursor)
start_cursor=context.cursor_at_current())

# Checks if the query specifies a release (e.g., "Debian:12")
has_release = ':' in ecosystem
Expand Down Expand Up @@ -1331,7 +1341,7 @@ def query_by_package(context: QueryContext, package_name: str | None,
return []

it: ndb.QueryIterator = query.iter(
start_cursor=context.input_cursor.ndb_cursor)
start_cursor=context.cursor_at_current())

while (yield it.has_next_async()):
if context.should_break_page(len(bugs)):
Expand Down

0 comments on commit 94d0515

Please sign in to comment.