Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional props for grpc #326

Merged
merged 2 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ services:
- --scheme
- http
- --write-timeout=600s
image: semitechnologies/weaviate:preview-hybrid-bm25-for-grpc-f0ed9dc
image: semitechnologies/weaviate:preview-add-additional-properties-to-grpc-39c4c45
ports:
- "8080:8080"
- "50051:50051"
Expand Down
2 changes: 1 addition & 1 deletion integration/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import weaviate

GIT_HASH = "f0ed9dc"
GIT_HASH = "39c4c45"
SERVER_VERSION = "1.19.2"
NODE_NAME = "node1"
NUM_OBJECT = 10
Expand Down
37 changes: 36 additions & 1 deletion integration/test_grcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
{"name": "ref", "dataType": ["Test"]},
],
}
VECTOR = [1, 2, 3] * 100 # match with vectorizer vector length
VECTOR = [1.5, 2.5, 3.5] * 100 # match with vectorizer vector length


UUID1 = "577887c1-4c6b-5594-aa62-f0c17883d9bf"
Expand Down Expand Up @@ -100,3 +100,38 @@ def test_grcp(
result = query.do()
assert "Test2" in result["data"]["Get"]
assert "test" in result["data"]["Get"]["Test2"][0]


def test_additional():
client_grpc = weaviate.Client(
"http://localhost:8080",
additional_config=Config(grpc_port_experimental=50051),
)
client_grpc.schema.delete_all()

client_grpc.schema.create_class(CLASS1)
client_grpc.data_object.create({"test": "test"}, "Test", vector=VECTOR)
client_gql = weaviate.Client("http://localhost:8080")

results = []
for client in [client_gql, client_grpc]:
query = client.query.get("Test").with_additional(
weaviate.AdditionalProperties(
uuid=True,
vector=True,
creationTimeUnix=True,
lastUpdateTimeUnix=True,
distance=True,
)
)
result = query.do()
assert "Test" in result["data"]["Get"]

results.append(result)

result_gql = results[0]["data"]["Get"]["Test"][0]["_additional"]
result_grpc = results[1]["data"]["Get"]["Test"][0]["_additional"]

assert sorted(result_gql.keys()) == sorted(result_grpc.keys())
for key in result_gql.keys():
assert result_gql[key] == result_grpc[key]
20 changes: 19 additions & 1 deletion test/gql/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,30 @@

from test.util import check_error_message
from weaviate.data.replication import ConsistencyLevel
from weaviate.gql.get import GetBuilder, BM25, Hybrid, Reference, GroupBy
from weaviate.gql.get import GetBuilder, BM25, Hybrid, Reference, GroupBy, AdditionalProperties

mock_connection_v117 = Mock()
mock_connection_v117.server_version = "1.17.4"


@pytest.mark.parametrize(
"props,expected",
[
(AdditionalProperties(uuid=True), "_additional{id}"),
(
AdditionalProperties(uuid=True, vector=True, explainScore=True),
"_additional{id vector explainScore}",
),
(
AdditionalProperties(uuid=False, vector=True, explainScore=True, score=True),
"_additional{vector score explainScore}",
),
],
)
def test_additional_props(props: AdditionalProperties, expected: str):
assert str(props) == expected


@pytest.mark.parametrize(
"query,properties,expected",
[
Expand Down
4 changes: 3 additions & 1 deletion weaviate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
"EmbeddedOptions",
"Config",
"ConnectionConfig",
"AdditionalProperties",
"Reference",
]

import sys
Expand All @@ -72,7 +74,7 @@
WeaviateStartUpError,
)
from .config import Config, ConnectionConfig

from .gql.get import AdditionalProperties, Reference

if not sys.warnoptions:
import warnings
Expand Down
16 changes: 15 additions & 1 deletion weaviate/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ class MissingScopeException(WeaviateBaseError):
"""Scope was not provided with client credential flow."""


class AdditionalPropertiesException(WeaviateBaseError):
"""Additional properties were provided multiple times."""

def __init__(self, additional_dict: str, additional_dataclass: str):
msg = f"""
Cannot add AdditionalProperties class together with string-additional properties. Did you call
.with_additional() multiple times?.
Current additional properties already present:
- strings: {additional_dict}
- AdditionalProperties class: {additional_dataclass}
"""
super().__init__(msg)


class WeaviateStartUpError(WeaviateBaseError):
"""Is raised if weaviate does not start up in time."""

Expand All @@ -115,7 +129,7 @@ class WeaviateEmbeddedInvalidVersion(WeaviateBaseError):
"""Invalid version provided to Weaviate embedded."""

def __init__(self, url: str):
msg = """Invalid version provided to Weaviate embedded. It must be either:
msg = f"""Invalid version provided to Weaviate embedded. It must be either:
- a url to a tar.gz file that contains a Weaviate binary
- a version number, eg "1.18.2"
- the string "latest" to download the latest non-beta version
Expand Down
104 changes: 96 additions & 8 deletions weaviate/gql/get.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""
GraphQL `Get` command.
"""
from dataclasses import dataclass
from dataclasses import dataclass, Field, fields
from json import dumps
from typing import List, Union, Optional, Dict, Tuple

from weaviate import util
from weaviate.connect import Connection
from weaviate.data.replication import ConsistencyLevel
from weaviate.exceptions import AdditionalPropertiesException
from weaviate.gql.filter import (
Where,
NearText,
Expand Down Expand Up @@ -87,6 +88,34 @@ def __str__(self) -> str:
PROPERTIES = Union[List[Union[str, Reference]], str]


@dataclass
class AdditionalProperties:
uuid: bool = False
vector: bool = False
creationTimeUnix: bool = False
lastUpdateTimeUnix: bool = False
distance: bool = False
certainty: bool = False
score: bool = False
explainScore: bool = False

def __str__(self) -> str:
additional_props: List[str] = []
cls_fields: Tuple[Field, ...] = fields(self.__class__)
for field in cls_fields:
if issubclass(field.type, bool):
enabled: bool = getattr(self, field.name)
if enabled:
name = field.name
if field.name == "uuid": # id is reserved python name
name = "id"
additional_props.append(name)
if len(additional_props) > 0:
return "_additional{" + " ".join(additional_props) + "}"
else:
return ""


class GetBuilder(GraphQL):
"""
GetBuilder class used to create GraphQL queries.
Expand Down Expand Up @@ -134,6 +163,7 @@ def __init__(self, class_name: str, properties: Optional[PROPERTIES], connection
self._additional: dict = {"__one_level": set()}
# '__one_level' refers to the additional properties that are just a single word, not a dict
# thus '__one_level', only one level of complexity
self._additional_dataclass: Optional[AdditionalProperties] = None
self._where: Optional[Where] = None # To store the where filter if it is added
self._limit: Optional[int] = None # To store the limit filter if it is added
self._offset: Optional[str] = None # To store the offset filter if it is added
Expand Down Expand Up @@ -657,7 +687,10 @@ def with_ask(self, content: dict) -> "GetBuilder":
return self

def with_additional(
self, properties: Union[List, str, Dict[str, Union[List[str], str]], Tuple[dict, dict]]
self,
properties: Union[
List, str, Dict[str, Union[List[str], str]], Tuple[dict, dict], AdditionalProperties
],
) -> "GetBuilder":
"""
Add additional properties (i.e. properties from `_additional` clause). See Examples below.
Expand Down Expand Up @@ -832,6 +865,17 @@ def with_additional(
TypeError
If one of the property is not of a correct data type.
"""
if isinstance(properties, AdditionalProperties):
if len(self._additional) > 1 or len(self._additional["__one_level"]) > 0:
raise AdditionalPropertiesException(
str(self._additional), str(self._additional_dataclass)
)
self._additional_dataclass = properties
return self
elif self._additional_dataclass is not None:
raise AdditionalPropertiesException(
str(self._additional), str(self._additional_dataclass)
)

if isinstance(properties, str):
self._additional["__one_level"].add(properties)
Expand Down Expand Up @@ -1197,7 +1241,17 @@ def do(self) -> dict:
if self._near_ask is not None and isinstance(self._near_ask, NearObject)
else None,
properties=self._convert_references_to_grpc(self._properties),
additional_properties=self._additional["__one_level"],
additional_properties=weaviate_pb2.AdditionalProperties(
uuid=self._additional_dataclass.uuid,
vector=self._additional_dataclass.vector,
creationTimeUnix=self._additional_dataclass.creationTimeUnix,
lastUpdateTimeUnix=self._additional_dataclass.lastUpdateTimeUnix,
distance=self._additional_dataclass.distance,
explainScore=self._additional_dataclass.explainScore,
score=self._additional_dataclass.score,
)
if self._additional_dataclass is not None
else None,
bm25_search=weaviate_pb2.BM25SearchParams(
properties=self._bm25.properties, query=self._bm25.query
)
Expand All @@ -1218,17 +1272,49 @@ def do(self) -> dict:
objects = []
for result in res.results:
obj = self._convert_references_to_grpc_result(result.properties)
if len(self._additional["__one_level"]) > 0:
obj["_additional"] = {}
if "id" in self._additional["__one_level"]:
obj["_additional"]["id"] = result.additional_properties.id
additional = self._extract_additional_properties(result.additional_properties)
if len(additional) > 0:
obj["_additional"] = additional
objects.append(obj)

results = {"data": {"Get": {self._class_name: objects}}}
return results
else:
return super().do()

def _extract_additional_properties(
self, props: weaviate_pb2.ResultAdditionalProps
) -> Dict[str, str]:
additional_props = {}
if self._additional_dataclass is None:
return additional_props

if self._additional_dataclass.uuid:
additional_props["id"] = props.id
if self._additional_dataclass.vector:
additional_props["vector"] = (
[float(num) for num in props.vector] if len(props.vector) > 0 else None
)
if self._additional_dataclass.distance:
additional_props["distance"] = props.distance if props.distance_present else None
if self._additional_dataclass.certainty:
additional_props["certainty"] = props.certainty if props.certainty_present else None
if self._additional_dataclass.creationTimeUnix:
additional_props["creationTimeUnix"] = (
str(props.creation_time_unix) if props.creation_time_unix_present else None
)
if self._additional_dataclass.lastUpdateTimeUnix:
additional_props["lastUpdateTimeUnix"] = (
str(props.last_update_time_unix) if props.last_update_time_unix_present else None
)
if self._additional_dataclass.score:
additional_props["score"] = props.score if props.score_present else None
if self._additional_dataclass.explainScore:
additional_props["explainScore"] = (
props.explain_score if props.explain_score_present else None
)
return additional_props

def _convert_references_to_grpc_result(self, properties: weaviate_pb2.ResultProperties) -> Dict:
result = {}
for name, non_ref_prop in properties.non_ref_properties.items():
Expand All @@ -1248,7 +1334,7 @@ def _convert_references_to_grpc(
non_ref_properties=[prop for prop in properties if isinstance(prop, str)],
ref_properties=[
weaviate_pb2.RefProperties(
class_name=prop.linked_class,
linked_class=prop.linked_class,
reference_property=prop.reference_property,
linked_properties=self._convert_references_to_grpc(prop.properties),
)
Expand All @@ -1266,6 +1352,8 @@ def _additional_to_str(self) -> str:
str
The converted self._additional.
"""
if self._additional_dataclass is not None:
return str(self._additional_dataclass)

str_to_return = " _additional {"

Expand Down
Loading