Skip to content

Commit

Permalink
Improve multi-sv query combination computation (#2522)
Browse files Browse the repository at this point in the history
Fold in the great feedback from @jehangiramjad on
#2510.

This also fixes a bug in the previous implementation.
  • Loading branch information
pradh authored Mar 31, 2023
1 parent 66fd7f9 commit 2bcd7ef
Show file tree
Hide file tree
Showing 3 changed files with 426 additions and 195 deletions.
46 changes: 26 additions & 20 deletions nl_server/query_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Query related helpers"""

from dataclasses import dataclass
import itertools
import logging
import re
from typing import List
Expand Down Expand Up @@ -60,26 +61,31 @@ class QuerySet:
# Builds the QuerySet holding the ways to cut `query_parts` into `nsplits`
# contiguous, order-preserving groups of words. Each group is the words joined
# with a single space, and each way of cutting becomes one QuerySplit.
#
# NOTE(review): this span is a rendered diff whose +/- markers and indentation
# were lost. Going by the hunk header (-60,26 +61,31) and the page's
# "20 deletions" count, the recursive body (from `if nsplits == 1:` through the
# duplicate check) appears to be the pre-change code, and everything from the
# two asserts down to the final append its replacement -- confirm against the
# repository before relying on this.
def _prepare_queryset(nsplits: int, query_parts: List[str]) -> QuerySet:
result = QuerySet(nsplits=nsplits, delim_based=False, combinations=[])

# Recursive version, base case: one split is all the words joined together
# (nothing is added when there are no words).
if nsplits == 1:
if query_parts:
result.combinations.append(QuerySplit(parts=[' '.join(query_parts)]))
return result

# For example, take [A B C D E] split in 3 parts: the first part may be
# query_parts[:1], [:2] or [:3] (3 choices), which always leaves at least
# one word for each of the remaining 2 parts.
num = len(query_parts) - (nsplits - 1)
# Continuing with that example, end ranges from 1 to 3.
for end in range(1, num + 1):
# For a given first part, recurse to enumerate every split of the rest.
first = ' '.join(query_parts[:end])
rest = _prepare_queryset(nsplits - 1, query_parts[end:])
for qs in rest.combinations:
parts = [first] + qs.parts
if len(parts) != len(set(parts)):
# Two parts are the same string; this version drops the combination.
continue
result.combinations.append(QuerySplit(parts=parts))

# Combination-based version: requires at least 2 splits and at least one
# word per split.
assert nsplits >= 2
assert nsplits <= len(query_parts)
#
# For M nsplits on N query_parts, we compute the different
# combinations, each of which is a tuple of (M-1) "split indexes"
# with values ranging from 0 to (N-2).
# A split index is the index of the last word of a part.
#
# For example, for 3 nsplits of "hispanic poor male population",
# we do combinations(range(3), 2) which gives
# [(0,1), (0,2), (1,2)], which refers to 3 QuerySplits:
# ['hispanic', 'poor', 'male population']
# ['hispanic', 'poor male', 'population']
# ['hispanic poor', 'male', 'population']
#
split_index_combos = itertools.combinations(range(len(query_parts) - 1),
nsplits - 1)
for split_index in split_index_combos:
qs = QuerySplit(parts=[])
# `start` is the index of the first word of the part being built.
start = 0
for last in split_index:
qs.parts.append(' '.join(query_parts[start:last + 1]))
start = last + 1
# Whatever follows the final split index forms the last part.
qs.parts.append(' '.join(query_parts[start:]))
# Unlike the recursive version, there is no duplicate-part check here.
result.combinations.append(qs)
return result


Expand Down
195 changes: 188 additions & 7 deletions nl_server/tests/query_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,204 @@

import unittest

from nl_server import query_util
from parameterized import parameterized

from nl_server.query_util import get_parts_via_delimiters
from nl_server.query_util import prepare_multivar_querysets
from nl_server.query_util import QuerySet
from nl_server.query_util import QuerySplit

# Unit tests for get_parts_via_delimiters().
#
# NOTE(review): this span is a rendered diff whose +/- markers and indentation
# were lost. The `TestQueryUtil` header and the `query_util.`-qualified calls
# look like the pre-change lines; the `TestGetPartsViaDelimiters` header and
# the directly imported `get_parts_via_delimiters(...)` calls look like their
# replacements -- confirm against the repository.
class TestQueryUtil(unittest.TestCase):

class TestGetPartsViaDelimiters(unittest.TestCase):

# A query containing "vs" is expected to yield the text on either side of it.
def test_get_parts_via_delimiters_versus(self):
self.assertEqual(['compare male population', 'female population'],
query_util.get_parts_via_delimiters(
'compare male population vs female population'))
self.assertEqual([
'compare male population', 'female population'
], get_parts_via_delimiters('compare male population vs female population'))

# A query using ",", "and" and "&" is expected to yield four separate parts.
def test_get_parts_via_delimiters_list(self):
self.assertEqual(
['male population', 'female population', 'poor people', 'rich people'],
query_util.get_parts_via_delimiters(
get_parts_via_delimiters(
'male population, female population and poor people & rich people'))

# With double-quoted phrases, only the quoted phrases are expected back;
# the unquoted words ("compare", "with") are not in the result.
def test_get_parts_via_delimiters_doublequotes(self):
self.assertEqual(['male population', 'female population'],
query_util.get_parts_via_delimiters(
get_parts_via_delimiters(
'compare "male population" with "female population"'))


# Table-driven test for prepare_multivar_querysets(): each case is a
# [query, expected] pair, where `expected` is the complete list of QuerySets
# (for 2, 3 and 4 splits) that the function must return for that query.
class TestPrepareMultivarQuerysets(unittest.TestCase):

@parameterized.expand([
# No delimiter in the query. The expected parts are built from five words
# (number, poor, hispanic, women, phd); "of" and "with" do not appear.
[
'number of poor hispanic women with phd',
[
QuerySet(
nsplits=2,
delim_based=False,
combinations=[
QuerySplit(parts=['number', 'poor hispanic women phd']),
QuerySplit(parts=['number poor', 'hispanic women phd']),
QuerySplit(parts=['number poor hispanic', 'women phd']),
QuerySplit(parts=['number poor hispanic women', 'phd'])
]),
QuerySet(
nsplits=3,
delim_based=False,
combinations=[
QuerySplit(
parts=['number', 'poor', 'hispanic women phd']),
QuerySplit(
parts=['number', 'poor hispanic', 'women phd']),
QuerySplit(
parts=['number', 'poor hispanic women', 'phd']),
QuerySplit(
parts=['number poor', 'hispanic', 'women phd']),
QuerySplit(
parts=['number poor', 'hispanic women', 'phd']),
QuerySplit(parts=['number poor hispanic', 'women', 'phd'])
]),
QuerySet(
nsplits=4,
delim_based=False,
combinations=[
QuerySplit(
parts=['number', 'poor', 'hispanic', 'women phd']),
QuerySplit(
parts=['number', 'poor', 'hispanic women', 'phd']),
QuerySplit(
parts=['number', 'poor hispanic', 'women', 'phd']),
QuerySplit(
parts=['number poor', 'hispanic', 'women', 'phd'])
])
],
],
# "vs." gives a single delimiter-based 2-way split (delim_based=True).
# The 3- and 4-way sets are word-based and keep "vs" as an ordinary word.
[
'compare obesity vs. poverty',
[
QuerySet(nsplits=2,
delim_based=True,
combinations=[
QuerySplit(parts=['compare obesity', 'poverty'])
]),
QuerySet(
nsplits=3,
delim_based=False,
combinations=[
QuerySplit(parts=['compare', 'obesity', 'vs poverty']),
QuerySplit(parts=['compare', 'obesity vs', 'poverty']),
QuerySplit(parts=['compare obesity', 'vs', 'poverty'])
]),
QuerySet(
nsplits=4,
delim_based=False,
combinations=[
QuerySplit(parts=['compare', 'obesity', 'vs', 'poverty'])
])
],
],
# No delimiter in the query. "me", "the", "of" and "on" are absent from
# the expected parts, leaving five words to split.
[
'show me the impact of climate change on drought',
[
QuerySet(nsplits=2,
delim_based=False,
combinations=[
QuerySplit(
parts=['show', 'impact climate change drought']),
QuerySplit(
parts=['show impact', 'climate change drought']),
QuerySplit(
parts=['show impact climate', 'change drought']),
QuerySplit(
parts=['show impact climate change', 'drought'])
]),
QuerySet(
nsplits=3,
delim_based=False,
combinations=[
QuerySplit(
parts=['show', 'impact', 'climate change drought']),
QuerySplit(
parts=['show', 'impact climate', 'change drought']),
QuerySplit(
parts=['show', 'impact climate change', 'drought']),
QuerySplit(
parts=['show impact', 'climate', 'change drought']),
QuerySplit(
parts=['show impact', 'climate change', 'drought']),
QuerySplit(
parts=['show impact climate', 'change', 'drought'])
]),
QuerySet(
nsplits=4,
delim_based=False,
combinations=[
QuerySplit(
parts=['show', 'impact', 'climate', 'change drought'
]),
QuerySplit(
parts=['show', 'impact', 'climate change', 'drought'
]),
QuerySplit(
parts=['show', 'impact climate', 'change', 'drought'
]),
QuerySplit(
parts=['show impact', 'climate', 'change', 'drought'])
])
]
],
# Double-quoted phrases give the delimiter-based 2-way split, and every
# expected part is lower-case. The word-based 4-way set includes a split
# in which the part 'population' occurs twice.
[
'Compare "Male population" with "Female Population"',
[
QuerySet(
nsplits=2,
delim_based=True,
combinations=[
QuerySplit(parts=['male population', 'female population'])
]),
QuerySet(nsplits=3,
delim_based=False,
combinations=[
QuerySplit(parts=[
'compare', 'male', 'population female population'
]),
QuerySplit(parts=[
'compare', 'male population', 'female population'
]),
QuerySplit(parts=[
'compare', 'male population female', 'population'
]),
QuerySplit(parts=[
'compare male', 'population', 'female population'
]),
QuerySplit(parts=[
'compare male', 'population female', 'population'
]),
QuerySplit(parts=[
'compare male population', 'female', 'population'
])
]),
QuerySet(
nsplits=4,
delim_based=False,
combinations=[
QuerySplit(parts=[
'compare', 'male', 'population', 'female population'
]),
QuerySplit(parts=[
'compare', 'male', 'population female', 'population'
]),
QuerySplit(parts=[
'compare', 'male population', 'female', 'population'
]),
QuerySplit(parts=[
'compare male', 'population', 'female', 'population'
])
])
]
]
])
# Runs once per case above with that case's (query, expected) pair.
def test_prepare_multivar_querysets(self, query, expected):
# Lift unittest's diff-length limit so a mismatch shows the full diff.
self.maxDiff = None
self.assertEqual(prepare_multivar_querysets(query), expected)
Loading

0 comments on commit 2bcd7ef

Please sign in to comment.