Skip to content

Commit

Permalink
Merge pull request #452 from weaviate/hotfix-where-filtering
Browse files Browse the repository at this point in the history
ensure client functionality is aligned with docs usage
  • Loading branch information
dirkkul authored Aug 29, 2023
2 parents 0d3344d + 397a794 commit 54053fb
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 62 deletions.
179 changes: 162 additions & 17 deletions integration/test_graphql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import uuid
from typing import Optional, List
from typing import Optional, List, Union

import pytest
from pytest import FixtureRequest
Expand All @@ -20,6 +20,7 @@
{"dataType": ["string"], "description": "name", "name": "name"},
{"dataType": ["string"], "description": "description", "name": "description"},
{"dataType": ["int"], "description": "size", "name": "size"},
{"dataType": ["number"], "description": "rating", "name": "rating"},
],
"vectorizer": "text2vec-contextionary",
}
Expand All @@ -28,30 +29,50 @@

SHIPS = [
{
"props": {"name": "HMS British Name", "size": 5, "description": "Super long description"},
"props": {
"name": "HMS British Name",
"size": 5,
"rating": 0.0,
"description": "Super long description",
},
"id": uuid.uuid4(),
},
{
"props": {
"name": "The dragon ship",
"rating": 6.66,
"size": 20,
"description": "Interesting things about dragons",
},
"id": uuid.uuid4(),
},
{
"props": {"name": "Blackbeard", "size": 43, "description": "Background info about movies"},
"props": {
"name": "Blackbeard",
"size": 43,
"rating": 7.2,
"description": "Background info about movies",
},
"id": uuid.uuid4(),
},
{"props": {"name": "Titanic", "size": 1, "description": "Everyone knows"}, "id": uuid.uuid4()},
{
"props": {"name": "Artemis", "size": 34, "description": "Name from some story"},
"props": {"name": "Titanic", "size": 1, "rating": 4.5, "description": "Everyone knows"},
"id": uuid.uuid4(),
},
{
"props": {
"name": "Artemis",
"size": 34,
"rating": 9.1,
"description": "Name from some story",
},
"id": uuid.uuid4(),
},
{
"props": {
"name": "The Crusty Crab",
"size": 303,
"rating": 10.0,
"description": "sponges, sponges, sponges",
},
"id": uuid.uuid4(),
Expand Down Expand Up @@ -113,33 +134,157 @@ def test_get_data(client: weaviate.Client):

def test_get_data_with_where_contains_any(client: weaviate.Client):
"""Test GraphQL's Get clause with where filter."""
where_filter = {"path": ["size"], "operator": "ContainsAny", "valueIntArray": [5]}
where_filter = {"path": ["size"], "operator": "ContainsAny", "valueInt": [5]}
result = client.query.get("Ship", ["name", "size"]).with_where(where_filter).do()
objects = get_objects_from_result(result)
assert len(objects) == 1 and objects[0]["name"] == "HMS British Name"


@pytest.mark.parametrize(
"value_string_list,expected_objects_len",
"path,operator,value_type_key,value_type_value,name,expected_objects_len",
[
(["sponges, sponges, sponges"], 1),
(["sponges, sponges, sponges", "doesn't exist"], 0),
(
["description"],
"ContainsAll",
"valueString",
["sponges, sponges, sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAll",
"valueText",
["sponges", "sponges", "sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAll",
"valueStringArray",
["sponges", "sponges", "sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAll",
"valueTextArray",
["sponges, sponges, sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAll",
"valueStringList",
["sponges", "sponges", "sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAll",
"valueTextList",
["sponges", "sponges", "sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAny",
"valueString",
["sponges, sponges, sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAny",
"valueText",
["sponges", "sponges", "sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAny",
"valueStringArray",
["sponges", "sponges", "sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAny",
"valueTextArray",
["sponges, sponges, sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAny",
"valueStringList",
["sponges", "sponges", "sponges"],
"The Crusty Crab",
1,
),
(
["description"],
"ContainsAny",
"valueTextList",
["sponges", "sponges", "sponges"],
"The Crusty Crab",
1,
),
(["size"], "ContainsAll", "valueInt", [5], "HMS British Name", 1),
(["size"], "ContainsAll", "valueIntList", [5], "HMS British Name", 1),
(["size"], "ContainsAll", "valueIntArray", [5], "HMS British Name", 1),
(["size"], "ContainsAny", "valueInt", [5], "HMS British Name", 1),
(["size"], "ContainsAny", "valueIntList", [5], "HMS British Name", 1),
(["size"], "ContainsAny", "valueIntArray", [5], "HMS British Name", 1),
(["rating"], "ContainsAll", "valueNumber", [6.66], "The dragon ship", 1),
(["rating"], "ContainsAll", "valueNumberList", [6.66], "The dragon ship", 1),
(["rating"], "ContainsAll", "valueNumberArray", [6.66], "The dragon ship", 1),
(["rating"], "ContainsAny", "valueNumber", [6.66], "The dragon ship", 1),
(["rating"], "ContainsAny", "valueNumberList", [6.66], "The dragon ship", 1),
(["rating"], "ContainsAny", "valueNumberArray", [6.66], "The dragon ship", 1),
(["size"], "Equal", "valueInt", 5, "HMS British Name", 1),
(["size"], "LessThan", "valueInt", 5, "Titanic", 1),
(["size"], "LessThanEqual", "valueInt", 1, "Titanic", 1),
(["size"], "GreaterThan", "valueInt", 300, "The Crusty Crab", 1),
(["size"], "GreaterThanEqual", "valueInt", 303, "The Crusty Crab", 1),
(["description"], "Like", "valueString", "sponges", "The Crusty Crab", 1),
(["description"], "Like", "valueText", "sponges", "The Crusty Crab", 1),
(["rating"], "IsNull", "valueBoolean", True, "irrelevant", 0),
(["rating"], "NotEqual", "valueNumber", 6.66, "irrelevant", 5),
],
)
def test_get_data_with_where_contains_all(
client: weaviate.Client, value_string_list: List[str], expected_objects_len: int
def test_get_data_with_where(
client: weaviate.Client,
path: List[str],
operator: str,
value_type_key: str,
value_type_value: Union[List[int], List[str]],
name,
expected_objects_len: int,
):
"""Test GraphQL's Get clause with where filter."""
where_filter = {
"path": ["description"],
"operator": "ContainsAll",
"valueStringArray": value_string_list,
"path": path,
"operator": operator,
value_type_key: value_type_value,
}
result = client.query.get("Ship", ["name"]).with_where(where_filter).do()
objects = get_objects_from_result(result)
assert len(objects) == expected_objects_len
if expected_objects_len == 1:
assert objects[0]["name"] == "The Crusty Crab"
if expected_objects_len == 0:
assert objects is None
else:
assert len(objects) == expected_objects_len
if expected_objects_len == 1:
assert objects[0]["name"] == name


def test_get_data_after(client: weaviate.Client):
Expand Down
23 changes: 0 additions & 23 deletions test/gql/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,9 +626,6 @@ def test___str__(self):
value_is_not_list_err = (
lambda v, t: f"Must provide a list when constructing where filter for {t} with {v}"
)
value_is_list_err = (
lambda v, t: f"Cannot provide a list when constructing where filter for {t} with {v}"
)

test_filter = {"path": ["name"], "operator": "Equal", "valueString": "A"}
result = str(Where(test_filter))
Expand Down Expand Up @@ -761,26 +758,6 @@ def test___str__(self):
str(Where(test_filter))
check_error_message(self, error, value_is_not_list_err("A", "valueTextArray"))

test_filter = {
"path": ["name"],
"operator": "GreaterThan",
"valueInt": [1, 2],
}
with self.assertRaises(TypeError) as error:
str(Where(test_filter))
check_error_message(self, error, value_is_list_err([1, 2], "valueInt"))

test_filter = {
"path": ["name"],
"operator": "Equal",
"valueDate": ["test-2021-02-02", "test-2021-02-03"],
}
with self.assertRaises(TypeError) as error:
str(Where(test_filter))
check_error_message(
self, error, value_is_list_err(["test-2021-02-02", "test-2021-02-03"], "valueDate")
)

test_filter = {
"path": ["name"],
"operator": "Equal",
Expand Down
64 changes: 42 additions & 22 deletions weaviate/gql/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,33 +835,53 @@ def _parse_operator(self, content: dict) -> None:
def __str__(self):
if self.is_filter:
gql = f"where: {{path: {self.path} operator: {self.operator} {_convert_value_type(self.value_type)}: "
if self.value_type in ["valueInt", "valueNumber"]:
_check_is_not_list(self.value, self.value_type)
gql += f"{self.value}}}"
elif self.value_type in ["valueIntArray", "valueNumberArray"]:
_check_is_list(self.value, self.value_type)
if self.value_type in [
"valueInt",
"valueNumber",
"valueIntArray",
"valueNumberArray",
"valueIntList",
"valueNumberList",
]:
if self.value_type in [
"valueIntList",
"valueNumberList",
"valueIntList",
"valueNumberList",
]:
_check_is_list(self.value, self.value_type)
gql += f"{self.value}}}"
elif self.value_type in ["valueText", "valueString"]:
_check_is_not_list(self.value, self.value_type)
gql += f"{_sanitize_str(self.value)}}}"
elif self.value_type in ["valueTextArray", "valueStringArray"]:
_check_is_list(self.value, self.value_type)
val = [_sanitize_str(v) for v in self.value]
gql += f"{_render_list(val)}}}"
elif self.value_type == "valueBoolean":
_check_is_not_list(self.value, self.value_type)
gql += f"{_bool_to_str(self.value)}}}"
elif self.value_type == "valueBooleanArray":
_check_is_list(self.value, self.value_type)
gql += f"{_render_list(self.value)}}}"
elif self.value_type == "valueDateArray":
_check_is_list(self.value, self.value_type)
gql += f"{_render_list(self.value)}}}"
elif self.value_type in [
"valueText",
"valueString",
"valueTextList",
"valueStringList",
"valueTextArray",
"valueStringArray",
]:
if self.value_type in [
"valueTextList",
"valueStringList",
"valueTextArray",
"valueStringArray",
]:
_check_is_list(self.value, self.value_type)
if isinstance(self.value, list):
val = [_sanitize_str(v) for v in self.value]
gql += f"{_render_list(val)}}}"
else:
gql += f"{_sanitize_str(self.value)}}}"
elif self.value_type in ["valueBoolean", "valueBooleanArray", "valueBooleanList"]:
if self.value_type in ["valueBooleanArray", "valueBooleanList"]:
_check_is_list(self.value, self.value_type)
if isinstance(self.value, list):
gql += f"{_render_list(self.value)}}}"
else:
gql += f"{_bool_to_str(self.value)}}}"
elif self.value_type == "valueGeoRange":
_check_is_not_list(self.value, self.value_type)
gql += f"{_geo_range_to_str(self.value)}}}"
else:
_check_is_not_list(self.value, self.value_type)
gql += f'"{self.value}"}}'
return gql + " "

Expand Down

0 comments on commit 54053fb

Please sign in to comment.