Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…n-client into implement-async-client
  • Loading branch information
tsmith023 committed Jul 3, 2024
2 parents 706d525 + 9f1f3af commit b82a79c
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 84 deletions.
110 changes: 69 additions & 41 deletions mock_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from concurrent import futures
from typing import Generator
from typing import Generator, Mapping

import grpc
import pytest
Expand All @@ -12,7 +12,7 @@

import weaviate
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.proto.v1 import tenants_pb2, weaviate_pb2_grpc
from weaviate.proto.v1 import properties_pb2, tenants_pb2, search_get_pb2, weaviate_pb2_grpc

MOCK_IP = "127.0.0.1"
MOCK_PORT = 23536
Expand All @@ -25,7 +25,6 @@
http=ProtocolParams(host=MOCK_IP, port=MOCK_PORT, secure=False),
grpc=ProtocolParams(host=MOCK_IP, port=MOCK_PORT + 1, secure=False),
)
TENANTS_GET_COLLECTION_NAME = "TenantsGetCollectionName"

# pytest_httpserver 'Authorization' HeaderValueMatcher does not work with Bearer tokens.
# Hence, overwrite it with the default header value matcher that just compares for equality.
Expand Down Expand Up @@ -77,48 +76,20 @@ def weaviate_auth_mock(weaviate_mock: HTTPServer):
yield weaviate_mock


# Implement the health check service
class MockHealthServicer(HealthServicer):
def Check(self, request: HealthCheckRequest, context: ServicerContext) -> HealthCheckResponse:
return HealthCheckResponse(status=HealthCheckResponse.SERVING)


class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def TenantsGet(
self, request: tenants_pb2.TenantsGetRequest, context: ServicerContext
) -> tenants_pb2.TenantsGetReply:
return tenants_pb2.TenantsGetReply(
tenants=[
tenants_pb2.Tenant(
name="tenant1", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_HOT
),
tenants_pb2.Tenant(
name="tenant2", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_COLD
),
tenants_pb2.Tenant(
name="tenant3", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_FROZEN
),
tenants_pb2.Tenant(
name="tenant4", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_FREEZING
),
tenants_pb2.Tenant(
name="tenant5", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_UNFREEZING
),
tenants_pb2.Tenant(
name="tenant6", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_UNFROZEN
),
]
)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def start_grpc_server() -> Generator[grpc.Server, None, None]:
# Create a gRPC server
server: grpc.Server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))

# Implement the health check service
class MockHealthServicer(HealthServicer):
def Check(
self, request: HealthCheckRequest, context: ServicerContext
) -> HealthCheckResponse:
return HealthCheckResponse(status=HealthCheckResponse.SERVING)

# Add the health check service to the server
add_HealthServicer_to_server(MockHealthServicer(), server)
weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), server)

# Listen on a specific port
server.add_insecure_port(f"[::]:{MOCK_PORT_GRPC}")
Expand All @@ -140,5 +111,62 @@ def weaviate_client(


@pytest.fixture(scope="function")
def tenants_collection(weaviate_client: weaviate.WeaviateClient) -> weaviate.collections.Collection:
return weaviate_client.collections.get(TENANTS_GET_COLLECTION_NAME)
def tenants_collection(
weaviate_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server
) -> weaviate.collections.Collection:
class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def TenantsGet(
self, request: tenants_pb2.TenantsGetRequest, context: ServicerContext
) -> tenants_pb2.TenantsGetReply:
return tenants_pb2.TenantsGetReply(
tenants=[
tenants_pb2.Tenant(
name="tenant1", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_HOT
),
tenants_pb2.Tenant(
name="tenant2", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_COLD
),
tenants_pb2.Tenant(
name="tenant3", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_FROZEN
),
tenants_pb2.Tenant(
name="tenant4", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_FREEZING
),
tenants_pb2.Tenant(
name="tenant5",
activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_UNFREEZING,
),
tenants_pb2.Tenant(
name="tenant6", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_UNFROZEN
),
]
)

weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
return weaviate_client.collections.get("TenantsGetCollectionName")


@pytest.fixture(scope="function")
def year_zero_collection(
weaviate_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server
) -> weaviate.collections.Collection:
class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def Search(
self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext
) -> search_get_pb2.SearchReply:
zero_date: properties_pb2.Value.date_value = properties_pb2.Value(
date_value="0000-01-30T00:00:00Z"
)
date_prop: Mapping[str, properties_pb2.Value.date_value] = {"date": zero_date}
return search_get_pb2.SearchReply(
results=[
search_get_pb2.SearchResult(
properties=search_get_pb2.PropertiesResult(
non_ref_props=properties_pb2.Properties(fields=date_prop)
)
),
]
)

weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
return weaviate_client.collections.get("YearZeroCollection")
41 changes: 3 additions & 38 deletions mock_tests/test_collection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import json
import time
from typing import Any, Dict, Mapping
from typing import Any, Dict

import grpc
import pytest
Expand All @@ -28,7 +28,6 @@
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.connect.integrations import _IntegrationConfig
from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateStartUpError
from weaviate.proto.v1 import weaviate_pb2_grpc, search_get_pb2, properties_pb2

ACCESS_TOKEN = "HELLO!IamAnAccessToken"
REFRESH_TOKEN = "UseMeToRefreshYourAccessToken"
Expand Down Expand Up @@ -343,43 +342,9 @@ def test_integration_config(
weaviate_no_auth_mock.check_assertions()


def test_year_zero(weaviate_no_auth_mock: HTTPServer, start_grpc_server: grpc.Server) -> None:
zero_date: properties_pb2.Value.date_value = properties_pb2.Value(
date_value="0000-01-30T00:00:00Z"
)
date_prop: Mapping[str, properties_pb2.Value.date_value] = {"date": zero_date}

class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def Search(
self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext
) -> search_get_pb2.SearchReply:
return search_get_pb2.SearchReply(
results=[
search_get_pb2.SearchResult(
properties=search_get_pb2.PropertiesResult(
non_ref_props=properties_pb2.Properties(fields=date_prop)
)
),
]
)

weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
schema = {
"class": "Test",
"properties": [],
"vectorizer": "none",
}
weaviate_no_auth_mock.expect_request("/v1/schema/Test").respond_with_json(
response_json=schema, status=200
)

client = weaviate.connect_to_local(
port=MOCK_PORT,
host=MOCK_IP,
grpc_port=MOCK_PORT_GRPC,
)
def test_year_zero(year_zero_collection: weaviate.collections.Collection) -> None:
with pytest.warns(UserWarning) as recwarn:
objs = client.collections.get("Test").query.fetch_objects().objects
objs = year_zero_collection.query.fetch_objects().objects
assert objs[0].properties["date"] == datetime.datetime.min

assert str(recwarn[0].message).startswith("Con004")
21 changes: 18 additions & 3 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ class _RerankerCohereConfig(_RerankerConfigCreate):
model: Optional[Union[RerankerCohereModel, str]] = Field(default=None)


class _RerankerCohereCustom(_RerankerConfigCreate):
class _RerankerCustomConfig(_RerankerConfigCreate):
module_config: Dict[str, Any]

def _to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -572,6 +572,14 @@ def custom(
module_name: str,
module_config: Dict[str, Any],
) -> _GenerativeConfigCreate:
"""Create a `_GenerativeCustom` object for use when generating using a custom module.
Arguments:
`module_name`
The name of the custom module to use, REQUIRED.
`module_config`
The configuration to use for the custom module. Defaults to `None`, which uses the server-defined default.
"""
return _GenerativeCustom(generative=_EnumLikeStr(module_name), module_config=module_config)

@staticmethod
Expand Down Expand Up @@ -863,8 +871,15 @@ def transformers() -> _RerankerConfigCreate:

@staticmethod
def custom(module_name: str, module_config: Dict[str, Any]) -> _RerankerConfigCreate:
"""Create a `_RerankerCohereCustom` object for use when reranking using a custom module."""
return _RerankerCohereCustom(
"""Create a `_RerankerCustomConfig` object for use when reranking using a custom module.
Arguments:
`module_name`
The name of the custom module to use, REQUIRED.
`module_config`
The configuration to use for the custom module. Defaults to `None`, which uses the server-defined default.
"""
return _RerankerCustomConfig(
reranker=_EnumLikeStr(module_name), module_config=module_config
)

Expand Down
9 changes: 8 additions & 1 deletion weaviate/collections/classes/config_vectorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,14 @@ def text2vec_contextionary(vectorize_collection_name: bool = True) -> _Vectorize
def custom(
module_name: str, module_config: Optional[Dict[str, Any]] = None
) -> _VectorizerConfigCreate:
"""Create a `_VectorizerCustomConfig` object for use when vectorizing using a custom model."""
"""Create a `_VectorizerCustomConfig` object for use when vectorizing using a custom module.
Arguments:
`module_name`
The name of the custom module to use, REQUIRED.
`module_config`
The configuration to use for the custom module. Defaults to `None`, which uses the server-defined default.
"""
return _VectorizerCustomConfig(
vectorizer=_EnumLikeStr(module_name), module_config=module_config
)
Expand Down
2 changes: 1 addition & 1 deletion weaviate/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def text2vec_huggingface_endpoint_url_and_model_set_together() -> None:
@staticmethod
def datetime_year_zero(date: str) -> None:
warnings.warn(
message=f"""Con004: Received a date {date} with year 0. The year 0 does not exist in the Gregorian calendar.
message=f"""Con004: Received a date {date} with year 0. The year 0 does not exist in the Gregorian calendar
and cannot be parsed by the datetime library. The year will be set to {datetime.min}.
See https://en.wikipedia.org/wiki/Year_zero for more information.""",
category=UserWarning,
Expand Down

0 comments on commit b82a79c

Please sign in to comment.