Enable pagination for the join method (#394)
* enable output_format

* reorder params

* docstring for 'output_format' param

* fix tabulate for non-json

* add headers as a param; fix the message when we do not know the exact number of records in a query

* fix test

* minor change

* wip: draft

* enable pagination for join method

* use getattr instead of direct attribute access

* remove unused params, reorder the docstring, reduce use of getattr

* roll back the error param

* add a comment

* use hasattr instead of getattr

Co-authored-by: David Caplan <[email protected]>
Nikola Maric and davecap authored Mar 3, 2021
1 parent 42d0ef1 commit 46b1e40
Showing 4 changed files with 90 additions and 33 deletions.
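
For orientation, here is a minimal usage sketch of the behavior this commit enables, modeled on the new test_join_pagination test below. The dataset path is a hypothetical placeholder; the gene value mirrors the test.

import solvebio

solvebio.login()

# Hypothetical dataset path; any dataset with a 'gene' field would do.
dataset = solvebio.Dataset.get_by_full_path('solvebio:public:/ClinVar/Variants')

# query_a is capped at 50 records, fetched 10 per page.
query_a = dataset.query(fields=['gene'], limit=50, page_size=10).filter(gene='MAN2B1')
query_b = dataset.query(fields=['gene'])

# Before this commit, iterating a join stopped after the first buffered page;
# now the iterator keeps issuing paged requests until query_a's limit is reached.
join_query = query_a.join(query_b, key='gene', prefix='b_')
for record in join_query:
    print(record['gene'], record['b_gene'])  # 'b_gene' assumes prefixed fields from query_b
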
101 changes: 72 additions & 29 deletions solvebio/query.py
@@ -292,10 +292,29 @@ def __repr__(self):
if len(self) == 0:
return 'Query returned 0 results.'

return '\n%s\n\n... %s more results.' % (
tabulate(list(self._buffer[0].items()), ['Fields', 'Data'],
aligns=['right', 'left'], sort=True),
pretty_int(len(self) - 1))
placeholder = 'many more results (total unknown)'

if getattr(self, '_is_join', False):
return '\n%s\n\n... %s.' % (
tabulate(list(self._buffer[0].items()), ['Fields', 'Data'],
aligns=['right', 'left'], sort=True),
placeholder)
elif not hasattr(self, '_output_format'):
return '\n%s\n\n... %s more results.' % (
tabulate(list(self._buffer[0].items()), ['Fields', 'Data'],
aligns=['right', 'left'], sort=True),
pretty_int(len(self) - 1))
else:
is_tsv = self._output_format == 'tsv'

# this is the only case in which we know the exact total number of records
if len(self) < self._limit:
placeholder = '{} more results'.format(pretty_int(max(len(self) - 9, 0)))

return '\n%s\n\n... %s.' % (
tabulate(list(enumerate(self._buffer[:10])), ['Row', 'Data'],
aligns=['right', 'left'], is_tsv=is_tsv),
placeholder)

def __getattr__(self, key):
if self._response is None:
@@ -415,12 +434,22 @@ def next(self):
# Iterator not initialized yet
self.__iter__()

# Check whether the current object is a join query
_is_join = getattr(self, '_is_join', False)

# len(self) returns `min(limit, total)` results
if self._cursor == len(self):
if not _is_join and self._cursor == len(self):
raise StopIteration

if self._buffer_idx == len(self._buffer):
self.execute(self._page_offset + self._buffer_idx)
if _is_join:
if self._next_offset >= self._limit:
# Since joins can return more results than we expect (due to `explode`),
# manually ensure that we haven't gone above the requested limit (default inf)
raise StopIteration
self.execute(self._next_offset)
else:
self.execute(self._page_offset + self._buffer_idx)
self._buffer_idx = 0

if not self._buffer:
@@ -512,7 +541,6 @@ def __init__(
target_fields=None,
annotator_params=None,
debug=False,
error=None,
**kwargs):
"""
Creates a new Query object.
@@ -521,14 +549,14 @@
- `dataset_id`: Unique ID of dataset to query.
- `query` (optional): An optional query string.
- `genome_build`: The genome build to use for the query.
- `result_class` (optional): Class of object returned by query.
- `filters` (optional): Filter or List of filter objects.
- `fields` (optional): List of specific fields to retrieve.
- `exclude_fields` (optional): List of specific fields to exclude.
- `entities` (optional): List of entity tuples to filter on.
- `ordering` (optional): List of fields to order the results by.
- `filters` (optional): Filter or List of filter objects.
- `limit` (optional): Maximum number of query results to return.
- `page_size` (optional): Number of results to fetch per query page.
- `result_class` (optional): Class of object returned by query.
- `target_fields` (optional): Add target fields to annotate the query results.
- `annotator_params` (optional): For use with `target_fields` to adjust annotator parameters.
- `debug` (optional): Sends debug information to the API.
@@ -537,15 +565,17 @@
self._data_url = '/v2/datasets/{0}/data'.format(dataset_id)
self._query = query
self._genome_build = genome_build
self._result_class = result_class
self._fields = fields
self._exclude_fields = exclude_fields
self._entities = entities
self._ordering = ordering
self._debug = debug
self._error = error
self._result_class = result_class
self._target_fields = target_fields
self._annotator_params = annotator_params
self._debug = debug
self._error = None
self._is_join = False

if filters:
if isinstance(filters, Filter):
filters = [filters]
@@ -659,7 +689,7 @@ def __len__(self):
SELECT * FROM <table> [WHERE condition] [LIMIT number]
)
"""
if getattr(self, '_is_join', False):
if self._is_join:
return len(self._buffer)

return super(Query, self).__len__()
@@ -722,6 +752,14 @@ def execute(self, offset=0, **query):
limit=min(self._page_size, self._limit)
)

if self._is_join:
# We do not know the exact total number of records in a join, because the
# total is calculated dynamically by an internal expression in target_fields
# in the join() method. We therefore adjust the limit on the final request
# so that we get the requested number of records from query_a.
_params['limit'] = min(self._page_size, abs(self._limit - self._page_offset))
self._next_offset = self._page_offset + min(self._page_size, self._limit)

logger.debug('executing query. from/limit: %6d/%d' %
(_params['offset'], _params['limit']))

@@ -1031,9 +1069,8 @@ def __init__(
filters=None,
limit=QueryBase.INF,
page_size=DEFAULT_PAGE_SIZE,
result_class=dict,
debug=False,
error=None,
output_format='json',
header=True,
**kwargs):
"""
Creates a new QueryFile object.
@@ -1043,20 +1080,19 @@
- `fields` (optional): List of specific fields to retrieve.
- `exclude_fields` (optional): List of specific fields to exclude.
- `filters` (optional): Filter or List of filter objects.
- `result_class` (optional): Class of object returned by query.
- `limit` (optional): Maximum number of query results to return.
- `page_size` (optional): Number of results to fetch per query page.
- `debug` (optional): Sends debug information to the API.
- `output_format` (optional): Format of query results (json, csv, or tsv).
- `header` (optional): Include a header row in the response when output_format is 'csv' or 'tsv'.
"""
self._file_id = file_id
self._data_url = '/v2/objects/{0}/data'.format(file_id)
self._fields_url = '/v2/objects/{0}/fields'.format(file_id)
self._result_class = result_class
self._debug = debug
self._error = error
self._fields = fields
self._exclude_fields = exclude_fields
self._filters = filters
self._output_format = output_format
self._header = header
self._error = None

if filters:
if isinstance(filters, Filter):
@@ -1099,9 +1135,9 @@ def _clone(self, filters=None, limit=None):
fields=self._fields,
exclude_fields=self._exclude_fields,
page_size=self._page_size,
result_class=self._result_class,
debug=self._debug,
client=self._client)
output_format=self._output_format,
header=self._header,
client=self._client)

new._filters += self._filters

@@ -1114,7 +1150,7 @@ def _clone(self, filters=None, limit=None):
return new

def _build_query(self, **kwargs):
q = {}
q = {'output_format': self._output_format}

if self._filters:
filters = self._process_filters(self._filters)
@@ -1129,9 +1165,6 @@
if self._exclude_fields is not None:
q['exclude_fields'] = self._exclude_fields

if self._debug:
q['debug'] = 'True'

# Add or modify query parameters
# (used by BatchQuery and facets)
q.update(**kwargs)
@@ -1159,6 +1192,16 @@ def execute(self, offset=0, **query):
# If the request results in a SolveError (i.e. a bad filter), set the error.
try:
self._response = self._client.post(self._data_url, _params)

if getattr(self, '_header', None) and self._output_format in ('csv', 'tsv') \
and not getattr(self, '_header_fields', None):
self._header_fields = self.fields()

separator_mappings = {'csv': ',', 'tsv': '\t'}
sep = separator_mappings[self._output_format]

self._response['results'].insert(0, sep.join(self._header_fields))

except SolveError as e:
self._error = e
raise
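
As a side note, the new output_format and header parameters can be exercised directly with a sketch like this. The file ID is a placeholder, and this assumes a delimited file stored in a vault; treat it as illustrative rather than the canonical API.

from solvebio.query import QueryFile

# Placeholder file ID; substitute the ID of a real vault file object.
query = QueryFile(file_id='1234567890', output_format='tsv', header=True)

# With header=True and a csv/tsv output_format, execute() now prepends a
# header row (the file's field names joined by the separator) to the results.
for line in query:
    print(line)
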
2 changes: 1 addition & 1 deletion solvebio/test/test_dataset.py
@@ -64,7 +64,7 @@ def test_dataset_fields(self):
'is_list', 'entity_type', 'expression',
'name', 'updated_at', 'is_read_only',
'depends_on',
'id', 'url', 'vault_id'])
'id', 'url', 'vault_id', 'url_template'])
self.assertSetEqual(set(dataset_field.keys()), check_fields)

def test_dataset_facets(self):
12 changes: 12 additions & 0 deletions solvebio/test/test_query.py
@@ -453,3 +453,15 @@ def test_join_with_list_values(self):
self.assertFalse(isinstance(i['clinical_significance'], list))
self.assertFalse(isinstance(i['b_clinical_significance'], list))
self.assertTrue('_errors' not in i)

def test_join_pagination(self):
# 50 records total
query_a = self.dataset2.query(fields=['gene'], limit=50, page_size=10).filter(gene='MAN2B1')
# 367 records total which have gene='MAN2B1'
query_b = self.dataset2.query(fields=['gene'])

join_query = query_a.join(query_b, key='gene', prefix='b_')

self.assertEqual(len(query_a), 50)
self.assertEqual(len(query_b.filter(gene='MAN2B1')), 367)
self.assertEqual(len(list(join_query)), 50 * 367)
8 changes: 5 additions & 3 deletions solvebio/utils/tabulate.py
@@ -561,7 +561,7 @@ def _format_table(fmt, headers, rows, colwidths, colaligns):


def tabulate(tabular_data, headers=[], tablefmt="orgmode",
floatfmt="g", aligns=[], missingval="", sort=True):
floatfmt="g", aligns=[], missingval="", sort=True, is_tsv=False):
list_of_lists, headers = _normalize_tabular_data(tabular_data, headers,
sort=sort)

@@ -608,9 +608,11 @@ def tabulate(tabular_data, headers=[], tablefmt="orgmode",
if not isinstance(tablefmt, TableFormat):
tablefmt = _table_formats.get(tablefmt, _table_formats["orgmode"])

# make sure values don't have newlines or tabs in them
rows = [[str(c).replace('\n', '').replace('\t', '').replace('\r', '')
# make sure values don't have newlines or tabs in them; for the tsv
# output_format, replace embedded tabs with spaces so they don't break
# the tab-separated columns
rows = [[str(c).replace('\n', '').replace('\t', ' ' if is_tsv else '').replace('\r', '')
for c in r] for r in rows]

return _format_table(tablefmt, headers, rows, minwidths, aligns)


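The effect of the new is_tsv flag can be seen in isolation with a small sketch (values are illustrative); the calls mirror how __repr__ invokes tabulate above.

from solvebio.utils.tabulate import tabulate

# A cell value containing an embedded tab.
rows = [['gene', 'BRCA2\tpathogenic']]

# Default behavior strips embedded tabs entirely.
print(tabulate(rows, ['Fields', 'Data'], aligns=['right', 'left']))

# With is_tsv=True, embedded tabs become spaces, so a value can't be
# mistaken for a tab column separator in tsv output.
print(tabulate(rows, ['Fields', 'Data'], aligns=['right', 'left'], is_tsv=True))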
