diff --git a/integration/test_graphql.py b/integration/test_graphql.py index 760e844ca..4c257c030 100644 --- a/integration/test_graphql.py +++ b/integration/test_graphql.py @@ -1,7 +1,7 @@ import json import os import uuid -from typing import Optional, List +from typing import Optional, List, Union import pytest from pytest import FixtureRequest @@ -20,6 +20,7 @@ {"dataType": ["string"], "description": "name", "name": "name"}, {"dataType": ["string"], "description": "description", "name": "description"}, {"dataType": ["int"], "description": "size", "name": "size"}, + {"dataType": ["number"], "description": "rating", "name": "rating"}, ], "vectorizer": "text2vec-contextionary", } @@ -28,30 +29,50 @@ SHIPS = [ { - "props": {"name": "HMS British Name", "size": 5, "description": "Super long description"}, + "props": { + "name": "HMS British Name", + "size": 5, + "rating": 0.0, + "description": "Super long description", + }, "id": uuid.uuid4(), }, { "props": { "name": "The dragon ship", + "rating": 6.66, "size": 20, "description": "Interesting things about dragons", }, "id": uuid.uuid4(), }, { - "props": {"name": "Blackbeard", "size": 43, "description": "Background info about movies"}, + "props": { + "name": "Blackbeard", + "size": 43, + "rating": 7.2, + "description": "Background info about movies", + }, "id": uuid.uuid4(), }, - {"props": {"name": "Titanic", "size": 1, "description": "Everyone knows"}, "id": uuid.uuid4()}, { - "props": {"name": "Artemis", "size": 34, "description": "Name from some story"}, + "props": {"name": "Titanic", "size": 1, "rating": 4.5, "description": "Everyone knows"}, + "id": uuid.uuid4(), + }, + { + "props": { + "name": "Artemis", + "size": 34, + "rating": 9.1, + "description": "Name from some story", + }, "id": uuid.uuid4(), }, { "props": { "name": "The Crusty Crab", "size": 303, + "rating": 10.0, "description": "sponges, sponges, sponges", }, "id": uuid.uuid4(), @@ -113,33 +134,157 @@ def test_get_data(client: weaviate.Client): def test_get_data_with_where_contains_any(client: weaviate.Client): """Test GraphQL's Get clause with where filter.""" - where_filter = {"path": ["size"], "operator": "ContainsAny", "valueIntArray": [5]} + where_filter = {"path": ["size"], "operator": "ContainsAny", "valueInt": [5]} result = client.query.get("Ship", ["name", "size"]).with_where(where_filter).do() objects = get_objects_from_result(result) assert len(objects) == 1 and objects[0]["name"] == "HMS British Name" @pytest.mark.parametrize( - "value_string_list,expected_objects_len", + "path,operator,value_type_key,value_type_value,name,expected_objects_len", [ - (["sponges, sponges, sponges"], 1), - (["sponges, sponges, sponges", "doesn't exist"], 0), + ( + ["description"], + "ContainsAll", + "valueString", + ["sponges, sponges, sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAll", + "valueText", + ["sponges", "sponges", "sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAll", + "valueStringArray", + ["sponges", "sponges", "sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAll", + "valueTextArray", + ["sponges, sponges, sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAll", + "valueStringList", + ["sponges", "sponges", "sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAll", + "valueTextList", + ["sponges", "sponges", "sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAny", + "valueString", + ["sponges, sponges, sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAny", + "valueText", + ["sponges", "sponges", "sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAny", + "valueStringArray", + ["sponges", "sponges", "sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAny", + "valueTextArray", + ["sponges, sponges, sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAny", + "valueStringList", + ["sponges", "sponges", "sponges"], + "The Crusty Crab", + 1, + ), + ( + ["description"], + "ContainsAny", + "valueTextList", + ["sponges", "sponges", "sponges"], + "The Crusty Crab", + 1, + ), + (["size"], "ContainsAll", "valueInt", [5], "HMS British Name", 1), + (["size"], "ContainsAll", "valueIntList", [5], "HMS British Name", 1), + (["size"], "ContainsAll", "valueIntArray", [5], "HMS British Name", 1), + (["size"], "ContainsAny", "valueInt", [5], "HMS British Name", 1), + (["size"], "ContainsAny", "valueIntList", [5], "HMS British Name", 1), + (["size"], "ContainsAny", "valueIntArray", [5], "HMS British Name", 1), + (["rating"], "ContainsAll", "valueNumber", [6.66], "The dragon ship", 1), + (["rating"], "ContainsAll", "valueNumberList", [6.66], "The dragon ship", 1), + (["rating"], "ContainsAll", "valueNumberArray", [6.66], "The dragon ship", 1), + (["rating"], "ContainsAny", "valueNumber", [6.66], "The dragon ship", 1), + (["rating"], "ContainsAny", "valueNumberList", [6.66], "The dragon ship", 1), + (["rating"], "ContainsAny", "valueNumberArray", [6.66], "The dragon ship", 1), + (["size"], "Equal", "valueInt", 5, "HMS British Name", 1), + (["size"], "LessThan", "valueInt", 5, "Titanic", 1), + (["size"], "LessThanEqual", "valueInt", 1, "Titanic", 1), + (["size"], "GreaterThan", "valueInt", 300, "The Crusty Crab", 1), + (["size"], "GreaterThanEqual", "valueInt", 303, "The Crusty Crab", 1), + (["description"], "Like", "valueString", "sponges", "The Crusty Crab", 1), + (["description"], "Like", "valueText", "sponges", "The Crusty Crab", 1), + (["rating"], "IsNull", "valueBoolean", True, "irrelevant", 0), + (["rating"], "NotEqual", "valueNumber", 6.66, "irrelevant", 5), ], ) -def test_get_data_with_where_contains_all( - client: weaviate.Client, value_string_list: List[str], expected_objects_len: int +def test_get_data_with_where( + client: weaviate.Client, + path: List[str], + operator: str, + value_type_key: str, + value_type_value: Union[List[int], List[str]], + name, + expected_objects_len: int, ): """Test GraphQL's Get clause with where filter.""" where_filter = { - "path": ["description"], - "operator": "ContainsAll", - "valueStringArray": value_string_list, + "path": path, + "operator": operator, + value_type_key: value_type_value, } result = client.query.get("Ship", ["name"]).with_where(where_filter).do() objects = get_objects_from_result(result) - assert len(objects) == expected_objects_len - if expected_objects_len == 1: - assert objects[0]["name"] == "The Crusty Crab" + if expected_objects_len == 0: + assert objects is None + else: + assert len(objects) == expected_objects_len + if expected_objects_len == 1: + assert objects[0]["name"] == name def test_get_data_after(client: weaviate.Client): diff --git a/test/gql/test_filter.py b/test/gql/test_filter.py index aedb5d533..f112c58ff 100644 --- a/test/gql/test_filter.py +++ b/test/gql/test_filter.py @@ -626,9 +626,6 @@ def test___str__(self): value_is_not_list_err = ( lambda v, t: f"Must provide a list when constructing where filter for {t} with {v}" ) - value_is_list_err = ( - lambda v, t: f"Cannot provide a list when constructing where filter for {t} with {v}" - ) test_filter = {"path": ["name"], "operator": "Equal", "valueString": "A"} result = str(Where(test_filter)) @@ -761,26 +758,6 @@ def test___str__(self): str(Where(test_filter)) check_error_message(self, error, value_is_not_list_err("A", "valueTextArray")) - test_filter = { - "path": ["name"], - "operator": "GreaterThan", - "valueInt": [1, 2], - } - with self.assertRaises(TypeError) as error: - str(Where(test_filter)) - check_error_message(self, error, value_is_list_err([1, 2], "valueInt")) - - test_filter = { - "path": ["name"], - "operator": "Equal", - "valueDate": ["test-2021-02-02", "test-2021-02-03"], - } - with self.assertRaises(TypeError) as error: - str(Where(test_filter)) - check_error_message( - self, error, value_is_list_err(["test-2021-02-02", "test-2021-02-03"], "valueDate") - ) - test_filter = { "path": ["name"], "operator": "Equal", diff --git a/weaviate/gql/filter.py b/weaviate/gql/filter.py index 1e5d965fe..71d785cfa 100644 --- a/weaviate/gql/filter.py +++ b/weaviate/gql/filter.py @@ -835,33 +835,53 @@ def _parse_operator(self, content: dict) -> None: def __str__(self): if self.is_filter: gql = f"where: {{path: {self.path} operator: {self.operator} {_convert_value_type(self.value_type)}: " - if self.value_type in ["valueInt", "valueNumber"]: - _check_is_not_list(self.value, self.value_type) - gql += f"{self.value}}}" - elif self.value_type in ["valueIntArray", "valueNumberArray"]: - _check_is_list(self.value, self.value_type) + if self.value_type in [ + "valueInt", + "valueNumber", + "valueIntArray", + "valueNumberArray", + "valueIntList", + "valueNumberList", + ]: + if self.value_type in [ + "valueIntList", + "valueNumberList", + "valueIntList", + "valueNumberList", + ]: + _check_is_list(self.value, self.value_type) gql += f"{self.value}}}" - elif self.value_type in ["valueText", "valueString"]: - _check_is_not_list(self.value, self.value_type) - gql += f"{_sanitize_str(self.value)}}}" - elif self.value_type in ["valueTextArray", "valueStringArray"]: - _check_is_list(self.value, self.value_type) - val = [_sanitize_str(v) for v in self.value] - gql += f"{_render_list(val)}}}" - elif self.value_type == "valueBoolean": - _check_is_not_list(self.value, self.value_type) - gql += f"{_bool_to_str(self.value)}}}" - elif self.value_type == "valueBooleanArray": - _check_is_list(self.value, self.value_type) - gql += f"{_render_list(self.value)}}}" - elif self.value_type == "valueDateArray": - _check_is_list(self.value, self.value_type) - gql += f"{_render_list(self.value)}}}" + elif self.value_type in [ + "valueText", + "valueString", + "valueTextList", + "valueStringList", + "valueTextArray", + "valueStringArray", + ]: + if self.value_type in [ + "valueTextList", + "valueStringList", + "valueTextArray", + "valueStringArray", + ]: + _check_is_list(self.value, self.value_type) + if isinstance(self.value, list): + val = [_sanitize_str(v) for v in self.value] + gql += f"{_render_list(val)}}}" + else: + gql += f"{_sanitize_str(self.value)}}}" + elif self.value_type in ["valueBoolean", "valueBooleanArray", "valueBooleanList"]: + if self.value_type in ["valueBooleanArray", "valueBooleanList"]: + _check_is_list(self.value, self.value_type) + if isinstance(self.value, list): + gql += f"{_render_list(self.value)}}}" + else: + gql += f"{_bool_to_str(self.value)}}}" elif self.value_type == "valueGeoRange": _check_is_not_list(self.value, self.value_type) gql += f"{_geo_range_to_str(self.value)}}}" else: - _check_is_not_list(self.value, self.value_type) gql += f'"{self.value}"}}' return gql + " "