From e38be0c3c9d72648f6b0566b45e7ac58cdccd879 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Tue, 5 Sep 2023 14:43:13 +0100 Subject: [PATCH 01/18] begin fixing compile errors --- weaviate/backup/backup.py | 3 +- weaviate/batch/crud_batch.py | 89 ++++++++++++++++------- weaviate/batch/requests.py | 7 +- weaviate/classification/classification.py | 4 +- weaviate/classification/config_builder.py | 12 +-- weaviate/connect/connection.py | 4 +- weaviate/util.py | 9 ++- 7 files changed, 84 insertions(+), 44 deletions(-) diff --git a/weaviate/backup/backup.py b/weaviate/backup/backup.py index d5a33a532..c76dd7ed6 100644 --- a/weaviate/backup/backup.py +++ b/weaviate/backup/backup.py @@ -118,6 +118,7 @@ def create( ) from conn_err create_status = _decode_json_response_dict(response, "Backup creation") + assert create_status is not None if wait_for_completion: while True: status: dict = self.get_create_status( @@ -246,7 +247,7 @@ def restore( "Backup restore failed due to connection error." ) from conn_err restore_status = _decode_json_response_dict(response, "Backup restore") - + assert restore_status is not None if wait_for_completion: while True: status: dict = self.get_restore_status( diff --git a/weaviate/batch/crud_batch.py b/weaviate/batch/crud_batch.py index de2db3cb7..fa950fedc 100644 --- a/weaviate/batch/crud_batch.py +++ b/weaviate/batch/crud_batch.py @@ -7,10 +7,24 @@ import time import warnings from collections import deque -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor, as_completed, Future from dataclasses import dataclass from numbers import Real -from typing import Tuple, Callable, Optional, Sequence, Union, List +from typing import ( + Any, + Callable, + Deque, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) from requests import ReadTimeout, Response from requests.exceptions import ConnectionError as RequestsConnectionError @@ -257,17 
+271,17 @@ def __init__(self, connection: Connection): self._objects_batch = ObjectsBatchRequest() self._reference_batch = ReferenceBatchRequest() # do not keep too many past values, so it is a better estimation of the throughput is computed for 1 second - self._objects_throughput_frame = deque(maxlen=5) - self._references_throughput_frame = deque(maxlen=5) - self._future_pool = [] - self._reference_batch_queue = [] + self._objects_throughput_frame: Deque[float] = deque(maxlen=5) + self._references_throughput_frame: Deque[float] = deque(maxlen=5) + self._future_pool: List[Future[Tuple[Response | None, int]]] = [] + self._reference_batch_queue: List[ReferenceBatchRequest] = [] self._callback_lock = threading.Lock() # user configurable, need to be public should implement a setter/getter self._callback: Optional[Callable[[BatchResponse], None]] = check_batch_result self._weaviate_error_retry: Optional[WeaviateErrorRetryConf] = None self._batch_size: Optional[int] = 50 - self._creation_time = min(self._connection.timeout_config[1] / 10, 2) + self._creation_time = cast(Real, min(self._connection.timeout_config[1] / 10, 2)) self._timeout_retries = 3 self._connection_error_retries = 3 self._batching_type: Optional[str] = "dynamic" @@ -275,11 +289,11 @@ def __init__(self, connection: Connection): self._recommended_num_references = self._batch_size self._num_workers = 1 - self._consistency_level = None + self._consistency_level: Optional[Literal["ALL", "ONE", "QUORUM"]] = None # thread pool executor self._executor: Optional[BatchExecutor] = None - def __call__(self, **kwargs) -> "Batch": + def __call__(self, **kwargs: Any) -> "Batch": """ WARNING: This method will be deprecated in the next major release. Use `configure` instead. 
@@ -334,7 +348,7 @@ def configure( timeout_retries: int = 3, connection_error_retries: int = 3, weaviate_error_retries: Optional[WeaviateErrorRetryConf] = None, - callback: Optional[Callable[[dict], None]] = check_batch_result, + callback: Optional[Callable[[List[dict]], None]] = check_batch_result, dynamic: bool = True, num_workers: int = 1, consistency_level: Optional[ConsistencyLevel] = None, @@ -393,7 +407,7 @@ def configure( _check_positive_num(creation_time, "creation_time", Real) self._creation_time = creation_time else: - self._creation_time = min(self._connection.timeout_config[1] / 10, 2) + self._creation_time = cast(Real, min(self._connection.timeout_config[1] / 10, 2)) _check_non_negative(timeout_retries, "timeout_retries", int) _check_non_negative(connection_error_retries, "connection_error_retries", int) @@ -433,13 +447,16 @@ def configure( self._auto_create() return self - def _update_recommended_batch_size(self): + def _update_recommended_batch_size(self) -> None: """Create a background thread that periodically checks how congested the batch queue is.""" self._shutdown_background_event = threading.Event() - def periodic_check(): + def periodic_check() -> None: cluster = Cluster(self._connection) - while not self._shutdown_background_event.is_set(): + while ( + self._shutdown_background_event is not None + and not self._shutdown_background_event.is_set() + ): try: status = cluster.get_nodes_status() if "stats" not in status[0] or "ratePerSecond" not in status[0]["stats"]: @@ -468,7 +485,7 @@ def periodic_check(): else: # way too high, stop sending new batches self._recommended_num_objects = 0 - refresh_time = 2 + refresh_time: Union[float, int] = 2 except (RequestsHTTPError, ReadTimeout): refresh_time = 0.1 @@ -652,7 +669,7 @@ def _create_data( weaviate.UnexpectedStatusCodeException If weaviate reports a none OK status. 
""" - params = {} + params: Dict[str, str] = {} if self._consistency_level is not None: params["consistency_level"] = self._consistency_level @@ -691,6 +708,7 @@ def _create_data( connection_count += 1 else: response_json = _decode_json_response_list(response, "batch response") + assert response_json is not None if ( self._weaviate_error_retry is not None and batch_error_count < self._weaviate_error_retry.number_retries @@ -721,7 +739,7 @@ def _create_data( return response raise UnexpectedStatusCodeException(f"Create {data_type} in batch", response) - def _run_callback(self, response: BatchResponse): + def _run_callback(self, response: BatchResponse) -> None: if self._callback is None: return # We don't know if user-supplied functions are threadsafe @@ -748,8 +766,10 @@ def _batch_retry_after_timeout( """ if data_type == "objects": + assert isinstance(batch_request, ObjectsBatchRequest) return self._readd_objects_after_timeout(batch_request) else: + assert isinstance(batch_request, ReferenceBatchRequest) return self._readd_references_after_timeout(batch_request) def _readd_objects_after_timeout( @@ -792,6 +812,7 @@ def _readd_objects_after_timeout( ) obj_weav = _decode_json_response_dict(response, "Re-add objects") + assert obj_weav is not None if obj_weav["properties"] != obj["properties"] or obj.get( "vector", None ) != obj_weav.get("vector", None): @@ -935,7 +956,9 @@ class NonExistingClass not present" self._recommended_num_objects = max(round(obj_per_second * self._creation_time), 1) - return _decode_json_response_list(response, "batch add objects") + res = _decode_json_response_list(response, "batch add objects") + assert res is not None + return res return [] def create_references(self) -> list: @@ -1030,7 +1053,9 @@ def create_references(self) -> list: self._recommended_num_references = round(ref_per_sec * self._creation_time) - return _decode_json_response_list(response, "Create references") + res = _decode_json_response_list(response, "Create references") 
+ assert res is not None + return res return [] def _flush_in_thread( @@ -1092,6 +1117,7 @@ def _send_batch_requests(self, force_wait: bool) -> None: ) self.start() + assert self._executor is not None future = self._executor.submit( self._flush_in_thread, data_type="objects", @@ -1188,6 +1214,7 @@ def _auto_create(self) -> None: # greater or equal in case the self._batch_size is changed manually if self._batching_type == "fixed": + assert self._batch_size is not None if sum(self.shape) >= self._batch_size: self._send_batch_requests(force_wait=False) return @@ -1302,7 +1329,7 @@ def delete_objects( if not isinstance(dry_run, bool): raise TypeError(f"'dry_run' must be of type bool. Given type: {type(dry_run)}.") - params = {} + params: Dict[str, str] = {} if self._consistency_level is not None: params["consistency_level"] = self._consistency_level if tenant is not None: @@ -1325,7 +1352,9 @@ def delete_objects( ) except RequestsConnectionError as conn_err: raise RequestsConnectionError("Batch delete was not successful.") from conn_err - return _decode_json_response_dict(response, "Delete in batch") + res = _decode_json_response_dict(response, "Delete in batch") + assert res is not None + return res def num_objects(self) -> int: """ @@ -1534,11 +1563,11 @@ def dynamic(self, value: bool) -> None: self._auto_create() @property - def consistency_level(self, value: Optional[Union[ConsistencyLevel, None]]) -> Union[str, None]: + def consistency_level(self) -> Union[str, None]: return self._consistency_level @consistency_level.setter - def consistency_level(self, x: Optional[Union[ConsistencyLevel, None]]) -> None: + def consistency_level(self, x: Optional[ConsistencyLevel]) -> None: self._consistency_level = ConsistencyLevel(x).value if x else None @property @@ -1600,7 +1629,7 @@ def shutdown(self) -> None: def __enter__(self) -> "Batch": return self.start() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> 
None: self.flush() self.shutdown() @@ -1710,9 +1739,10 @@ def _retry_on_error( self, response: BatchResponse, data_type: str ) -> Tuple[BatchRequestType, BatchResponse]: if data_type == "objects": - new_batch = ObjectsBatchRequest() + new_batch: Union[ObjectsBatchRequest, ReferenceBatchRequest] = ObjectsBatchRequest() else: new_batch = ReferenceBatchRequest() + assert self._weaviate_error_retry is not None successful_responses = new_batch.add_failed_objects_from_response( response, self._weaviate_error_retry.errors_to_exclude, @@ -1721,17 +1751,20 @@ def _retry_on_error( return new_batch, successful_responses -def _check_non_negative(value: Real, arg_name: str, data_type: type) -> None: +N = TypeVar("N", bound=Union[int, float, Real]) + + +def _check_non_negative(value: N, arg_name: str, data_type: Type[N]) -> None: """ Check if the `value` of the `arg_name` is a non-negative number. Parameters ---------- - value : Union[int, float] + value : N (int, float, Real) The value to check. arg_name : str The name of the variable from the original function call. Used for error message. - data_type : type + data_type : Type[N] The data type to check for. 
Raises diff --git a/weaviate/batch/requests.py b/weaviate/batch/requests.py index 955c108f9..357653205 100644 --- a/weaviate/batch/requests.py +++ b/weaviate/batch/requests.py @@ -7,6 +7,7 @@ from uuid import uuid4 from weaviate.util import get_valid_uuid, get_vector +from weaviate.types import UUID BatchResponse = List[Dict[str, Any]] @@ -131,9 +132,9 @@ class ReferenceBatchRequest(BatchRequest): def add( self, from_object_class_name: str, - from_object_uuid: str, + from_object_uuid: UUID, from_property_name: str, - to_object_uuid: str, + to_object_uuid: UUID, to_object_class_name: Optional[str] = None, tenant: Optional[str] = None, ) -> None: @@ -238,7 +239,7 @@ def add( self, data_object: dict, class_name: str, - uuid: Optional[str] = None, + uuid: Optional[UUID] = None, vector: Optional[Sequence] = None, tenant: Optional[str] = None, ) -> str: diff --git a/weaviate/classification/classification.py b/weaviate/classification/classification.py index 0dc74531f..9ee886f0a 100644 --- a/weaviate/classification/classification.py +++ b/weaviate/classification/classification.py @@ -74,7 +74,9 @@ def get(self, classification_uuid: str) -> dict: "Classification status could not be retrieved." ) from conn_err - return _decode_json_response_dict(response, "Get classification status") + res = _decode_json_response_dict(response, "Get classification status") + assert res is not None + return res def is_complete(self, classification_uuid: str) -> bool: """ diff --git a/weaviate/classification/config_builder.py b/weaviate/classification/config_builder.py index d8c8d7d68..4beb87a55 100644 --- a/weaviate/classification/config_builder.py +++ b/weaviate/classification/config_builder.py @@ -2,13 +2,13 @@ ConfigBuilder class definition. 
""" import time -from typing import Dict, Any +from typing import Dict, Any, cast from requests.exceptions import ConnectionError as RequestsConnectionError from weaviate.connect import Connection from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.util import _capitalize_first_letter +from weaviate.util import _capitalize_first_letter, _decode_json_response_dict class ConfigBuilder: @@ -16,7 +16,7 @@ class ConfigBuilder: ConfigBuild class that is used to configure a classification process. """ - def __init__(self, connection: Connection, classification: "Classification"): # noqa + def __init__(self, connection: Connection, classification: "Classification"): # type: ignore # noqa """ Initialize a ConfigBuilder class instance. @@ -270,7 +270,9 @@ def _start(self) -> dict: except RequestsConnectionError as conn_err: raise RequestsConnectionError("Classification may not started.") from conn_err if response.status_code == 201: - return response.json() + res = _decode_json_response_dict(response, "Start classification") + assert res is not None + return res raise UnexpectedStatusCodeException("Start classification", response) def do(self) -> dict: @@ -294,4 +296,4 @@ def do(self) -> dict: # print(classification_uuid) while self._classification.is_running(classification_uuid): time.sleep(2.0) - return self._classification.get(classification_uuid) + return cast(dict, self._classification.get(classification_uuid)) diff --git a/weaviate/connect/connection.py b/weaviate/connect/connection.py index 4f22196fb..6ad338806 100644 --- a/weaviate/connect/connection.py +++ b/weaviate/connect/connection.py @@ -8,7 +8,7 @@ import socket import time from threading import Thread, Event -from typing import Any, Dict, Tuple, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse import requests @@ -421,7 +421,7 @@ def patch( def post( self, path: str, - weaviate_object: dict, + weaviate_object: 
Union[List[dict], dict], params: Optional[Dict[str, Any]] = None, ) -> requests.Response: """ diff --git a/weaviate/util.py b/weaviate/util.py index d0b8d8f56..c8e08ec4b 100644 --- a/weaviate/util.py +++ b/weaviate/util.py @@ -7,7 +7,7 @@ import re from enum import Enum, EnumMeta from io import BufferedReader -from typing import Union, Sequence, Any, Optional, List, Dict, Tuple +from typing import Union, Sequence, Any, Optional, List, Dict, Tuple, TypeVar import requests import uuid as uuid_lib @@ -582,9 +582,10 @@ def check_batch_result( print(result["result"]["errors"]) -def _check_positive_num( - value: NUMBERS, arg_name: str, data_type: type, include_zero: bool = False -) -> None: +N = TypeVar("N") + + +def _check_positive_num(value: N, arg_name: str, data_type: N, include_zero: bool = False) -> None: """ Check if the `value` of the `arg_name` is a positive number. From 1912e8a98a8402eac2b079485253c2dfa37b67f1 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Tue, 5 Sep 2023 17:00:47 +0100 Subject: [PATCH 02/18] fix majority of mypy errors --- test/gql/test_query.py | 2 +- weaviate/batch/requests.py | 2 +- weaviate/classification/config_builder.py | 7 ++-- weaviate/cluster/cluster.py | 5 +-- weaviate/connect/authentication.py | 7 ++-- weaviate/connect/connection.py | 38 +++++++++++--------- weaviate/contextionary/crud_contextionary.py | 4 ++- weaviate/data/crud_data.py | 22 ++++++------ weaviate/embedded.py | 1 + weaviate/exceptions.py | 2 +- weaviate/gql/filter.py | 32 +++++++++-------- weaviate/gql/get.py | 30 ++++++++-------- weaviate/gql/query.py | 4 ++- weaviate/schema/crud_schema.py | 17 +++++---- weaviate/util.py | 28 ++++++++------- weaviate/warnings.py | 6 ++-- 16 files changed, 117 insertions(+), 90 deletions(-) diff --git a/test/gql/test_query.py b/test/gql/test_query.py index 71a999689..f009332a8 100644 --- a/test/gql/test_query.py +++ b/test/gql/test_query.py @@ -35,7 +35,7 @@ def test_raw(self): """ # valid calls - connection_mock = 
mock_connection_func("post") + connection_mock = mock_connection_func("post", return_json={}) query = Query(connection_mock) gql_query = "{Get {Group {name Members {... on Person {name}}}}}" diff --git a/weaviate/batch/requests.py b/weaviate/batch/requests.py index 357653205..e5937eb99 100644 --- a/weaviate/batch/requests.py +++ b/weaviate/batch/requests.py @@ -65,7 +65,7 @@ def pop(self, index: int = -1) -> dict: return self._items.pop(index) @abstractmethod - def add(self, *args, **kwargs): + def add(self, *args, **kwargs): # type: ignore """Add objects to BatchRequest.""" @abstractmethod diff --git a/weaviate/classification/config_builder.py b/weaviate/classification/config_builder.py index 4beb87a55..7deed5cb4 100644 --- a/weaviate/classification/config_builder.py +++ b/weaviate/classification/config_builder.py @@ -2,7 +2,7 @@ ConfigBuilder class definition. """ import time -from typing import Dict, Any, cast +from typing import Dict, Any, cast, TYPE_CHECKING from requests.exceptions import ConnectionError as RequestsConnectionError @@ -10,13 +10,16 @@ from weaviate.exceptions import UnexpectedStatusCodeException from weaviate.util import _capitalize_first_letter, _decode_json_response_dict +if TYPE_CHECKING: + from .classification import Classification + class ConfigBuilder: """ ConfigBuild class that is used to configure a classification process. """ - def __init__(self, connection: Connection, classification: "Classification"): # type: ignore # noqa + def __init__(self, connection: Connection, classification: "Classification"): """ Initialize a ConfigBuilder class instance. diff --git a/weaviate/cluster/cluster.py b/weaviate/cluster/cluster.py index 20024a3c8..63b95063c 100644 --- a/weaviate/cluster/cluster.py +++ b/weaviate/cluster/cluster.py @@ -1,7 +1,7 @@ """ Cluster class definition. 
""" -from typing import Optional +from typing import Optional, cast from requests.exceptions import ConnectionError as RequestsConnectionError @@ -64,7 +64,8 @@ def get_nodes_status(self, class_name: Optional[str] = None) -> list: ) from conn_err response_typed = _decode_json_response_dict(response, "Nodes status") + assert response_typed is not None nodes = response_typed.get("nodes") if nodes is None or nodes == []: raise EmptyResponseException("Nodes status response returned empty") - return nodes + return cast(list, nodes) diff --git a/weaviate/connect/authentication.py b/weaviate/connect/authentication.py index f9f2f5624..cb50e33bd 100644 --- a/weaviate/connect/authentication.py +++ b/weaviate/connect/authentication.py @@ -17,7 +17,7 @@ from ..warnings import _Warnings if TYPE_CHECKING: - from .connection import BaseConnection + from .connection import Connection AUTH_DEFAULT_TIMEOUT = 5 OIDC_CONFIG = Dict[str, Union[str, List[str]]] @@ -28,10 +28,10 @@ def __init__( self, oidc_config: OIDC_CONFIG, credentials: AuthCredentials, - connection: BaseConnection, + connection: Connection, ) -> None: self._credentials: AuthCredentials = credentials - self._connection: BaseConnection = connection + self._connection: Connection = connection config_url = oidc_config["href"] client_id = oidc_config["clientId"] assert isinstance(config_url, str) and isinstance(client_id, str) @@ -67,6 +67,7 @@ def _validate(self, oidc_config: OIDC_CONFIG) -> None: def _get_token_endpoint(self) -> str: response_auth = requests.get(self._open_id_config_url, proxies=self._connection.proxies) response_auth_json = _decode_json_response_dict(response_auth, "Get token endpoint") + assert response_auth_json is not None token_endpoint = response_auth_json["token_endpoint"] assert isinstance(token_endpoint, str) return token_endpoint diff --git a/weaviate/connect/connection.py b/weaviate/connect/connection.py index 6ad338806..3be4a2bc2 100644 --- a/weaviate/connect/connection.py +++ 
b/weaviate/connect/connection.py @@ -8,7 +8,7 @@ import socket import time from threading import Thread, Event -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from urllib.parse import urlparse import requests @@ -46,7 +46,7 @@ except ImportError: has_grpc = False - +JSONPayload = Union[dict, list] Session = Union[requests.sessions.Session, OAuth2Session] TIMEOUT_TYPE_RETURN = Tuple[NUMBERS, NUMBERS] PYPI_TIMEOUT = 1 @@ -255,11 +255,11 @@ def get_current_bearer_token(self) -> str: if "authorization" in self._headers: return self._headers["authorization"] elif isinstance(self._session, OAuth2Session): - return "Bearer " + self._session.token["access_token"] + return f"Bearer {self._session.token['access_token']}" return "" - def _add_adapter_to_session(self, connection_config: ConnectionConfig): + def _add_adapter_to_session(self, connection_config: ConnectionConfig) -> None: adapter = HTTPAdapter( pool_connections=connection_config.session_pool_connections, pool_maxsize=connection_config.session_pool_maxsize, @@ -267,12 +267,13 @@ def _add_adapter_to_session(self, connection_config: ConnectionConfig): self._session.mount("http://", adapter) self._session.mount("https://", adapter) - def _create_background_token_refresh(self, _auth: Optional[_Auth] = None): + def _create_background_token_refresh(self, _auth: Optional[_Auth] = None) -> None: """Create a background thread that periodically refreshes access and refresh tokens. While the underlying library refreshes tokens, it does not have an internal cronjob that checks every X-seconds if a token has expired. If there is no activity for longer than the refresh tokens lifetime, it will expire. 
Therefore, refresh manually shortly before expiration time is up.""" + assert isinstance(self._session, OAuth2Session) if "refresh_token" not in self._session.token and _auth is None: return @@ -281,9 +282,12 @@ def _create_background_token_refresh(self, _auth: Optional[_Auth] = None): ) # use 1minute as token lifetime if not supplied self._shutdown_background_event = Event() - def periodic_refresh_token(refresh_time: int, _auth: Optional[_Auth]): + def periodic_refresh_token(refresh_time: int, _auth: Optional[_Auth]) -> None: time.sleep(max(refresh_time - 30, 1)) - while not self._shutdown_background_event.is_set(): + while ( + self._shutdown_background_event is not None + and not self._shutdown_background_event.is_set() + ): # use refresh token when available try: if "refresh_token" in self._session.token: @@ -313,7 +317,7 @@ def periodic_refresh_token(refresh_time: int, _auth: Optional[_Auth]): ) demon.start() - def close(self): + def close(self) -> None: """Shutdown connection class gracefully.""" # in case an exception happens before definition of these members if ( @@ -338,7 +342,7 @@ def _get_request_header(self) -> dict: def delete( self, path: str, - weaviate_object: dict = None, + weaviate_object: Optional[JSONPayload] = None, params: Optional[Dict[str, Any]] = None, ) -> requests.Response: """ @@ -380,7 +384,7 @@ def delete( def patch( self, path: str, - weaviate_object: dict, + weaviate_object: JSONPayload, params: Optional[Dict[str, Any]] = None, ) -> requests.Response: """ @@ -421,7 +425,7 @@ def patch( def post( self, path: str, - weaviate_object: Union[List[dict], dict], + weaviate_object: JSONPayload, params: Optional[Dict[str, Any]] = None, ) -> requests.Response: """ @@ -464,7 +468,7 @@ def post( def put( self, path: str, - weaviate_object: dict, + weaviate_object: JSONPayload, params: Optional[Dict[str, Any]] = None, ) -> requests.Response: """ @@ -605,7 +609,7 @@ def timeout_config(self) -> TIMEOUT_TYPE_RETURN: return self._timeout_config 
@timeout_config.setter - def timeout_config(self, timeout_config: TIMEOUT_TYPE_RETURN): + def timeout_config(self, timeout_config: TIMEOUT_TYPE_RETURN) -> None: """ Setter for `timeout_config`. (docstring should be only in the Getter) """ @@ -616,7 +620,7 @@ def timeout_config(self, timeout_config: TIMEOUT_TYPE_RETURN): def proxies(self) -> dict: return self._proxies - def wait_for_weaviate(self, startup_period: Optional[int]): + def wait_for_weaviate(self, startup_period: Optional[int]) -> None: """ Waits until weaviate is ready or the timelimit given in 'startup_period' has passed. @@ -632,7 +636,7 @@ def wait_for_weaviate(self, startup_period: Optional[int]): """ ready_url = self.url + self._api_version_path + "/.well-known/ready" - for _i in range(startup_period): + for _i in range(startup_period or 30): try: requests.get(ready_url, headers=self._get_request_header()).raise_for_status() return @@ -663,7 +667,9 @@ def get_meta(self) -> Dict[str, str]: Returns the meta endpoint. """ response = self.get(path="/meta") - return _decode_json_response_dict(response, "Meta endpoint") + res = _decode_json_response_dict(response, "Meta endpoint") + assert res is not None + return res def _get_epoch_time() -> int: diff --git a/weaviate/contextionary/crud_contextionary.py b/weaviate/contextionary/crud_contextionary.py index e28db060d..fa083a6dc 100644 --- a/weaviate/contextionary/crud_contextionary.py +++ b/weaviate/contextionary/crud_contextionary.py @@ -152,4 +152,6 @@ def get_concept_vector(self, concept: str) -> dict: "text2vec-contextionary vector was not retrieved." 
) from conn_err else: - return _decode_json_response_dict(response, "text2vec-contextionary vector") + res = _decode_json_response_dict(response, "text2vec-contextionary vector") + assert res is not None + return res diff --git a/weaviate/data/crud_data.py b/weaviate/data/crud_data.py index 418cd9eed..e3aa72ab9 100644 --- a/weaviate/data/crud_data.py +++ b/weaviate/data/crud_data.py @@ -3,7 +3,7 @@ """ import uuid as uuid_lib import warnings -from typing import Union, Optional, List, Sequence, Dict, Any +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from requests.exceptions import ConnectionError as RequestsConnectionError @@ -369,7 +369,7 @@ def _create_object_for_update( class_name: str, uuid: Union[str, uuid_lib.UUID], vector: Optional[Sequence] = None, - ): + ) -> Tuple[Dict[str, Any], str]: if not isinstance(class_name, str): raise TypeError("Class must be type str") @@ -397,7 +397,7 @@ def _create_object_for_update( def get_by_id( self, uuid: Union[str, uuid_lib.UUID], - additional_properties: List[str] = None, + additional_properties: Optional[List[str]] = None, with_vector: bool = False, class_name: Optional[str] = None, node_name: Optional[str] = None, @@ -475,7 +475,7 @@ def get_by_id( def get( self, uuid: Union[str, uuid_lib.UUID, None] = None, - additional_properties: List[str] = None, + additional_properties: Optional[List[str]] = None, with_vector: bool = False, class_name: Optional[str] = None, node_name: Optional[str] = None, @@ -628,15 +628,15 @@ def get( raise TypeError( f"'sort['order_asc']' must be of type boolean or list[bool]. Given type: {type(sort['order_asc'])}." ) - if len(sort["properties"]) != len(sort["order_asc"]): + if len(sort["properties"]) != len(sort["order_asc"]): # type: ignore raise ValueError( - f"'sort['order_asc']' must be the same length as 'sort['properties']' or a boolean (not in a list). 
Current length is sort['properties']:{len(sort['properties'])} and sort['order_asc']:{len(sort['order_asc'])}." + f"'sort['order_asc']' must be the same length as 'sort['properties']' or a boolean (not in a list). Current length is sort['properties']:{len(sort['properties'])} and sort['order_asc']:{len(sort['order_asc'])}." # type: ignore ) - if len(sort["order_asc"]) == 0: + if len(sort["order_asc"]) == 0: # type: ignore raise ValueError("'sort['order_asc']' cannot be an empty list.") - params["sort"] = ",".join(sort["properties"]) - order = ["asc" if x else "desc" for x in sort["order_asc"]] + params["sort"] = ",".join(sort["properties"]) # type: ignore + order = ["asc" if x else "desc" for x in sort["order_asc"]] # type: ignore params["order"] = ",".join(order) try: @@ -647,7 +647,7 @@ def get( except RequestsConnectionError as conn_err: raise RequestsConnectionError("Could not get object/s.") from conn_err if response.status_code == 200: - return response.json() + return cast(Dict[str, Any], response.json()) if response.status_code == 404: return None raise UnexpectedStatusCodeException("Get object/s", response) @@ -864,7 +864,7 @@ def validate( data_object: Union[dict, str], class_name: str, uuid: Union[str, uuid_lib.UUID, None] = None, - vector: Sequence = None, + vector: Optional[Sequence] = None, ) -> dict: """ Validate an object against Weaviate. 
diff --git a/weaviate/embedded.py b/weaviate/embedded.py index b7ae6d58a..9161187ff 100644 --- a/weaviate/embedded.py +++ b/weaviate/embedded.py @@ -84,6 +84,7 @@ def __init__(self, options: EmbeddedOptions) -> None: "https://api.github.com/repos/weaviate/weaviate/releases/latest" ) latest = _decode_json_response_dict(response, "get tag of latest weaviate release") + assert latest is not None self._set_download_url_from_version_tag(latest["tag_name"]) else: raise exceptions.WeaviateEmbeddedInvalidVersion(self.options.version) diff --git a/weaviate/exceptions.py b/weaviate/exceptions.py index c217e33fd..702224d4b 100644 --- a/weaviate/exceptions.py +++ b/weaviate/exceptions.py @@ -84,7 +84,7 @@ def __init__(self, location: str, response: Response): response: requests.Response The request response of which the status code was unexpected. """ - msg = f"Cannot decode response from weaviate {response} with content {response.content} for request from {location}" + msg = f"Cannot decode response from weaviate {response} with content {response.text} for request from {location}" super().__init__(msg) self._status_code: int = response.status_code diff --git a/weaviate/gql/filter.py b/weaviate/gql/filter.py index e75a23645..6d4f78ca7 100644 --- a/weaviate/gql/filter.py +++ b/weaviate/gql/filter.py @@ -7,7 +7,7 @@ from copy import deepcopy from enum import Enum from json import dumps -from typing import Any, Union +from typing import Any, Tuple, Union from requests.exceptions import ConnectionError as RequestsConnectionError @@ -121,7 +121,9 @@ def do(self) -> dict: except RequestsConnectionError as conn_err: raise RequestsConnectionError("Query was not successful.") from conn_err - return _decode_json_response_dict(response, "Query was not successful") + res = _decode_json_response_dict(response, "Query was not successful") + assert res is not None + return res class Filter(ABC): @@ -153,7 +155,7 @@ def __str__(self) -> str: """ @property - def content(self): + def 
content(self) -> dict: return self._content @@ -203,7 +205,7 @@ def __init__(self, content: dict): if "autocorrect" in self._content: _check_type(var_name="autocorrect", value=self._content["autocorrect"], dtype=bool) - def __str__(self): + def __str__(self) -> str: near_text = f'nearText: {{concepts: {dumps(self._content["concepts"])}' if "certainty" in self._content: near_text += f' certainty: {self._content["certainty"]}' @@ -276,7 +278,7 @@ def __init__(self, content: dict): self._content["vector"] = get_vector(self._content["vector"]) - def __str__(self): + def __str__(self) -> str: near_vector = f'nearVector: {{vector: {dumps(self._content["vector"])}' if "certainty" in self._content: near_vector += f' certainty: {self._content["certainty"]}' @@ -339,7 +341,7 @@ def __init__(self, content: dict, is_server_version_14: bool): if "distance" in self._content: _check_type(var_name="distance", value=self._content["distance"], dtype=float) - def __str__(self): + def __str__(self) -> str: near_object = f'nearObject: {{{self.obj_id}: "{self._content[self.obj_id]}"' if "certainty" in self._content: near_object += f' certainty: {self._content["certainty"]}' @@ -399,7 +401,7 @@ def __init__(self, content: dict): if isinstance(self._content["properties"], str): self._content["properties"] = [self._content["properties"]] - def __str__(self): + def __str__(self) -> str: ask = f'ask: {{question: {dumps(self._content["question"])}' if "certainty" in self._content: ask += f' certainty: {self._content["certainty"]}' @@ -458,7 +460,7 @@ def __init__( if "distance" in self._content: _check_type(var_name="distance", value=self._content["distance"], dtype=float) - def __str__(self): + def __str__(self) -> str: media = self._media_type.value.capitalize() if self._media_type == MediaType.IMU: media = self._media_type.value.upper() @@ -830,7 +832,7 @@ def _parse_operator(self, content: dict) -> None: for operand in _content["operands"]: self.operands.append(Where(operand)) - def 
__str__(self): + def __str__(self) -> str: if self.is_filter: gql = f"where: {{path: {self.path} operator: {self.operator} {_convert_value_type(self.value_type)}: " if self.value_type in [ @@ -938,7 +940,7 @@ def _render_list(value: list) -> str: return f'[{",".join(value)}]' -def _check_is_list(value: Any, _type: str): +def _check_is_list(value: Any, _type: str) -> None: """Checks whether the provided value is a list to match the given `value_type`. Parameters @@ -959,7 +961,7 @@ def _check_is_list(value: Any, _type: str): ) -def _check_is_not_list(value: Any, _type: str): +def _check_is_not_list(value: Any, _type: str) -> None: """Checks whether the provided value is a list to match the given `value_type`. Parameters @@ -1020,7 +1022,7 @@ def _bool_to_str(value: bool) -> str: return "false" -def _check_direction_clause(direction: dict) -> dict: +def _check_direction_clause(direction: dict) -> None: """ Validate the direction sub clause. @@ -1113,7 +1115,7 @@ def _check_objects(content: dict) -> None: ) -def _check_type(var_name: str, value: Any, dtype: type) -> None: +def _check_type(var_name: str, value: Any, dtype: Union[Tuple[type, type], type]) -> None: """ Check key-value type. @@ -1121,9 +1123,9 @@ def _check_type(var_name: str, value: Any, dtype: type) -> None: ---------- var_name : str The variable name for which to check the type (used for error message)! - value : Any + value : T The value for which to check the type. - dtype : type + dtype : T The expected data type of the `value`. 
Raises diff --git a/weaviate/gql/get.py b/weaviate/gql/get.py index e502c6d0e..0f46f5fbb 100644 --- a/weaviate/gql/get.py +++ b/weaviate/gql/get.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, Field, fields from enum import Enum from json import dumps -from typing import List, Union, Optional, Dict, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, Union from weaviate import util from weaviate.connect import Connection @@ -204,9 +204,9 @@ def __init__(self, class_name: str, properties: Optional[PROPERTIES], connection self._alias: Optional[str] = None self._tenant: Optional[str] = None self._autocut: Optional[int] = None - self._consistency_level: Optional[ConsistencyLevel] = None + self._consistency_level: Optional[str] = None - def with_autocut(self, autocut: int): + def with_autocut(self, autocut: int) -> "GetBuilder": """Cuts off irrelevant results based on "jumps" in scores.""" if not isinstance(autocut, int): raise TypeError("autocut must be of type int") @@ -215,7 +215,7 @@ def with_autocut(self, autocut: int): self._contains_filter = True return self - def with_tenant(self, tenant: str): + def with_tenant(self, tenant: str) -> "GetBuilder": """Sets a tenant for the query.""" if not isinstance(tenant, str): raise TypeError("tenant must be of type str") @@ -224,12 +224,12 @@ def with_tenant(self, tenant: str): self._contains_filter = True return self - def with_after(self, after_uuid: UUID): + def with_after(self, after_uuid: UUID) -> "GetBuilder": """Can be used to extract all elements by giving the last ID from the previous "page". Requires limit to be set but cannot be combined with any other filters or search. Part of the Cursor API. 
""" - if not isinstance(after_uuid, UUID.__args__): # __args__ is workaround for python 3.8 + if not isinstance(after_uuid, UUID.__args__): # type: ignore # __args__ is workaround for python 3.8 raise TypeError("after_uuid must be of type UUID (str or uuid.UUID)") self._after = f'after: "{get_valid_uuid(after_uuid)}"' @@ -1607,7 +1607,7 @@ def with_hybrid( vector: Optional[List[float]] = None, properties: Optional[List[str]] = None, fusion_type: Optional[HybridFusion] = None, - ): + ) -> "GetBuilder": """Get objects using bm25 and vector, then combine the results using a reciprocal ranking algorithm. Parameters @@ -1709,7 +1709,7 @@ def with_generate( def with_alias( self, alias: str, - ): + ) -> "GetBuilder": """Gives an alias for the query. Needs to be used if 'multi_get' requests the same 'class_name' twice. Parameters @@ -1721,7 +1721,7 @@ def with_alias( self._alias = alias return self - def with_consistency_level(self, consistency_level: ConsistencyLevel): + def with_consistency_level(self, consistency_level: ConsistencyLevel) -> "GetBuilder": """Set the consistency level for the request.""" self._consistency_level = f"consistencyLevel: {consistency_level.value} " @@ -1831,13 +1831,13 @@ def do(self) -> dict: ) # no ref props as strings ) if grpc_enabled: - metadata = () + metadata: Union[Tuple, Tuple[Tuple[Literal["authorization"], str]]] = () access_token = self._connection.get_current_bearer_token() if len(access_token) > 0: metadata = (("authorization", access_token),) try: - res, _ = self._connection.grpc_stub.Search.with_call( + res, _ = self._connection.grpc_stub.Search.with_call( # type: ignore weaviate_pb2.SearchRequest( class_name=self._class_name, limit=self._limit, @@ -1894,7 +1894,9 @@ def do(self) -> dict: obj["_additional"] = additional objects.append(obj) - results = {"data": {"Get": {self._class_name: objects}}} + results: Union[Dict[str, Dict[str, Dict[str, List]]], Dict[str, List]] = { + "data": {"Get": {self._class_name: objects}} + } 
except grpc.RpcError as e: results = {"errors": [e.details()]} @@ -1905,7 +1907,7 @@ def do(self) -> dict: def _extract_additional_properties( self, props: "weaviate_pb2.ResultAdditionalProps" ) -> Dict[str, str]: - additional_props = {} + additional_props: Dict[str, Any] = {} if self._additional_dataclass is None: return additional_props @@ -1938,7 +1940,7 @@ def _extract_additional_properties( def _convert_references_to_grpc_result( self, properties: "weaviate_pb2.ResultProperties" ) -> Dict: - result = {} + result: Dict[str, Any] = {} for name, non_ref_prop in properties.non_ref_properties.items(): result[name] = non_ref_prop diff --git a/weaviate/gql/query.py b/weaviate/gql/query.py index 159553332..a17ca4652 100644 --- a/weaviate/gql/query.py +++ b/weaviate/gql/query.py @@ -169,4 +169,6 @@ def raw(self, gql_query: str) -> Dict[str, Any]: except RequestsConnectionError as conn_err: raise RequestsConnectionError("Query not executed.") from conn_err - return _decode_json_response_dict(response, "GQL query failed") + res = _decode_json_response_dict(response, "GQL query failed") + assert res is not None + return res diff --git a/weaviate/schema/crud_schema.py b/weaviate/schema/crud_schema.py index dd32a2f21..34d3760d5 100644 --- a/weaviate/schema/crud_schema.py +++ b/weaviate/schema/crud_schema.py @@ -3,7 +3,7 @@ """ from dataclasses import dataclass from enum import Enum -from typing import Union, Optional, List, Dict +from typing import Union, Optional, List, Dict, cast from requests.exceptions import ConnectionError as RequestsConnectionError @@ -84,7 +84,7 @@ def _to_weaviate_object(self) -> Dict[str, str]: def _from_weaviate_object(cls, weaviate_object: Dict[str, str]) -> "Tenant": return cls( name=weaviate_object["name"], - activity_status=weaviate_object.get("activityStatus", "HOT"), + activity_status=TenantActivityStatus(weaviate_object.get("activityStatus", "HOT")), ) @@ -449,7 +449,7 @@ def update_config(self, class_name: str, config: dict) -> None: if 
response.status_code != 200: raise UnexpectedStatusCodeException("Update class schema configuration", response) - def get(self, class_name: str = None) -> dict: + def get(self, class_name: Optional[str] = None) -> dict: """ Get the schema from Weaviate. @@ -554,7 +554,9 @@ def get(self, class_name: str = None) -> dict: except RequestsConnectionError as conn_err: raise RequestsConnectionError("Schema could not be retrieved.") from conn_err - return _decode_json_response_dict(response, "Get schema") + res = _decode_json_response_dict(response, "Get schema") + assert res is not None + return res def get_class_shards(self, class_name: str) -> list: """ @@ -598,7 +600,9 @@ def get_class_shards(self, class_name: str) -> list: "Class shards' status could not be retrieved due to connection error." ) from conn_err - return _decode_json_response_list(response, "Get shards' status") + res = _decode_json_response_list(response, "Get shards' status") + assert res is not None + return res def update_class_shard( self, @@ -698,7 +702,7 @@ def update_class_shard( if shard_name is None: return to_return - return to_return[0] + return cast(list, to_return[0]) def _create_complex_properties_from_class(self, schema_class: dict) -> None: """ @@ -908,6 +912,7 @@ def get_class_tenants(self, class_name: str) -> List[Tenant]: raise RequestsConnectionError("Could not get class tenants.") from conn_err tenant_resp = _decode_json_response_list(response, "Get class tenants") + assert tenant_resp is not None return [Tenant._from_weaviate_object(tenant) for tenant in tenant_resp] def update_class_tenants(self, class_name: str, tenants: List[Tenant]) -> None: diff --git a/weaviate/util.py b/weaviate/util.py index c8e08ec4b..49803e355 100644 --- a/weaviate/util.py +++ b/weaviate/util.py @@ -7,7 +7,7 @@ import re from enum import Enum, EnumMeta from io import BufferedReader -from typing import Union, Sequence, Any, Optional, List, Dict, Tuple, TypeVar +from typing import Union, Sequence, Any, 
Optional, List, Dict, Tuple, cast import requests import uuid as uuid_lib @@ -239,7 +239,7 @@ def _get_dict_from_object(object_: Union[str, dict]) -> dict: # Object is URL response = requests.get(object_) if response.status_code == 200: - return response.json() + return cast(dict, response.json()) raise ValueError("Could not download file " + object_) if not os.path.isfile(object_): @@ -247,7 +247,7 @@ def _get_dict_from_object(object_: Union[str, dict]) -> dict: raise ValueError("No file found at location " + object_) # Object is file with open(object_, "r") as file: - return json.load(file) + return cast(dict, json.load(file)) raise TypeError( "Argument is not of the supported types. Supported types are " "url or file path as string or schema as dict." @@ -582,10 +582,9 @@ def check_batch_result( print(result["result"]["errors"]) -N = TypeVar("N") - - -def _check_positive_num(value: N, arg_name: str, data_type: N, include_zero: bool = False) -> None: +def _check_positive_num( + value: Any, arg_name: str, data_type: type, include_zero: bool = False +) -> None: """ Check if the `value` of the `arg_name` is a positive number. @@ -611,10 +610,10 @@ def _check_positive_num(value: N, arg_name: str, data_type: N, include_zero: boo if not isinstance(value, data_type) or isinstance(value, bool): raise TypeError(f"'{arg_name}' must be of type {data_type}.") if include_zero: - if value < 0: + if value < 0: # type: ignore raise ValueError(f"'{arg_name}' must be positive, i.e. greater or equal to zero (>=0).") else: - if value <= 0: + if value <= 0: # type: ignore raise ValueError(f"'{arg_name}' must be positive, i.e. greater that zero (>0).") @@ -752,10 +751,13 @@ def _get_valid_timeout_config( If 'timeout_config' is/contains negative number/s. 
""" - def check_number(num: NUMBERS): + def check_number(num: Union[NUMBERS, Tuple[NUMBERS, NUMBERS], None]) -> bool: return isinstance(num, float) or isinstance(num, int) - if check_number(timeout_config) and not isinstance(timeout_config, bool): + if (isinstance(timeout_config, float) or isinstance(timeout_config, int)) and not isinstance( + timeout_config, bool + ): + assert timeout_config is not None if timeout_config <= 0.0: raise ValueError("'timeout_config' cannot be non-positive number/s!") return timeout_config, timeout_config @@ -788,7 +790,7 @@ def _decode_json_response_dict( if 200 <= response.status_code < 300: try: - json_response = response.json() + json_response = cast(dict, response.json()) return json_response except JSONDecodeError: raise ResponseCannotBeDecodedException(location, response) @@ -805,7 +807,7 @@ def _decode_json_response_list( if 200 <= response.status_code < 300: try: json_response = response.json() - return json_response + return cast(list, json_response) except JSONDecodeError: raise ResponseCannotBeDecodedException(location, response) raise UnexpectedStatusCodeException(location, response) diff --git a/weaviate/warnings.py b/weaviate/warnings.py index a064078a6..aab6cf6f5 100644 --- a/weaviate/warnings.py +++ b/weaviate/warnings.py @@ -108,7 +108,7 @@ def token_refresh_failed(exc: Exception) -> None: ) @staticmethod - def weaviate_too_old_vs_latest(server_version: str): + def weaviate_too_old_vs_latest(server_version: str) -> None: warnings.warn( message=f"""Dep004: You are connected to Weaviate {server_version}. Please consider upgrading to the latest version. 
See https://www.weaviate.io/developers/weaviate for details.""", @@ -117,7 +117,7 @@ def weaviate_too_old_vs_latest(server_version: str): ) @staticmethod - def weaviate_client_too_old_vs_latest(client_version: str, latest_version: str): + def weaviate_client_too_old_vs_latest(client_version: str, latest_version: str) -> None: warnings.warn( message=f"""Dep005: You are using weaviate-client version {client_version}. The latest version is {latest_version}. Please consider upgrading to the latest version. See https://weaviate.io/developers/weaviate/client-libraries/python for details.""", @@ -126,7 +126,7 @@ def weaviate_client_too_old_vs_latest(client_version: str, latest_version: str): ) @staticmethod - def use_of_client_batch_will_be_removed_in_next_major_release(): + def use_of_client_batch_will_be_removed_in_next_major_release() -> None: warnings.warn( message="""Dep006: You are using the `client.batch()` method, which will be removed in the next major release. Please instead use the `client.batch.configure()` method to configure your batch and `client.batch` to enter the context manager. 
From e4e818364ab47dbceead566b5f4e0a0093052acc Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 09:37:15 +0100 Subject: [PATCH 03/18] respond to review --- weaviate/batch/crud_batch.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/weaviate/batch/crud_batch.py b/weaviate/batch/crud_batch.py index fa950fedc..a2da595cd 100644 --- a/weaviate/batch/crud_batch.py +++ b/weaviate/batch/crud_batch.py @@ -16,7 +16,6 @@ Deque, Dict, List, - Literal, Optional, Sequence, Tuple, @@ -289,7 +288,7 @@ def __init__(self, connection: Connection): self._recommended_num_references = self._batch_size self._num_workers = 1 - self._consistency_level: Optional[Literal["ALL", "ONE", "QUORUM"]] = None + self._consistency_level: Optional[ConsistencyLevel] = None # thread pool executor self._executor: Optional[BatchExecutor] = None @@ -485,7 +484,7 @@ def periodic_check() -> None: else: # way too high, stop sending new batches self._recommended_num_objects = 0 - refresh_time: Union[float, int] = 2 + refresh_time: float = 2 except (RequestsHTTPError, ReadTimeout): refresh_time = 0.1 @@ -671,7 +670,7 @@ def _create_data( """ params: Dict[str, str] = {} if self._consistency_level is not None: - params["consistency_level"] = self._consistency_level + params["consistency_level"] = self._consistency_level.value try: timeout_count = connection_count = batch_error_count = 0 @@ -1331,7 +1330,7 @@ def delete_objects( params: Dict[str, str] = {} if self._consistency_level is not None: - params["consistency_level"] = self._consistency_level + params["consistency_level"] = self._consistency_level.value if tenant is not None: params["tenant"] = tenant @@ -1564,11 +1563,11 @@ def dynamic(self, value: bool) -> None: @property def consistency_level(self) -> Union[str, None]: - return self._consistency_level + return self._consistency_level.value if self._consistency_level is not None else None @consistency_level.setter - def consistency_level(self, x: 
Optional[ConsistencyLevel]) -> None: - self._consistency_level = ConsistencyLevel(x).value if x else None + def consistency_level(self, x: Optional[Union[ConsistencyLevel, str]]) -> None: + self._consistency_level = ConsistencyLevel(x) @property def recommended_num_objects(self) -> Optional[int]: From ac0e74da158e919072d225fd2cdf9731672f4e50 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 09:45:56 +0100 Subject: [PATCH 04/18] ignore errors in grpc files --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index e2ae1923d..47fcc026d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,10 @@ warn_return_any = true warn_unused_ignores = true exclude = ["weaviate_grpc", "docs", "mock_tests", "test", "integration"] +[[tool.mypy.overrides]] +module = "weaviate_grpc.*" +ignore_errors = true + [[tool.mypy.overrides]] module = "grpc.*" ignore_missing_imports = true From 97bf08e2e4bc1a8d87aa274f462b4fb024dd8f4d Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 09:46:22 +0100 Subject: [PATCH 05/18] type: ignore complex implementation --- weaviate/connect/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/weaviate/connect/connection.py b/weaviate/connect/connection.py index 3be4a2bc2..b3566d6f9 100644 --- a/weaviate/connect/connection.py +++ b/weaviate/connect/connection.py @@ -8,7 +8,7 @@ import socket import time from threading import Thread, Event -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse import requests @@ -290,7 +290,7 @@ def periodic_refresh_token(refresh_time: int, _auth: Optional[_Auth]) -> None: ): # use refresh token when available try: - if "refresh_token" in self._session.token: + if "refresh_token" in cast(OAuth2Session, self._session).token: assert isinstance(self._session, OAuth2Session) self._session.token = 
self._session.refresh_token( self._session.metadata["token_endpoint"] @@ -301,7 +301,7 @@ def periodic_refresh_token(refresh_time: int, _auth: Optional[_Auth]) -> None: # saved credentials assert _auth is not None new_session = _auth.get_auth_session() - self._session.token = new_session.fetch_token() + self._session.token = new_session.fetch_token() # type: ignore except (RequestsHTTPError, ReadTimeout) as exc: # retry again after one second, might be an unstable connection refresh_time = 1 From f1e7e469e5b620038e880864d850564cba663a45 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 09:56:48 +0100 Subject: [PATCH 06/18] add mypy to pre-commit --- .pre-commit-config.yaml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e8085f08d..dcc64b1bb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,4 +19,11 @@ repos: 'flake8-bugbear==22.10.27', 'flake8-comprehensions==3.10.1', 'flake8-builtins==2.0.1' - ] \ No newline at end of file + ] + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.5.1" + hooks: + - id: mypy + entry: mypy ./weaviate + pass_filenames: false From 1e9a69ee8f36663532596595d58072d0c2e2f994 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 10:07:04 +0100 Subject: [PATCH 07/18] setup mypy in precommit --- .pre-commit-config.yaml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dcc64b1bb..788422d29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,3 +27,23 @@ repos: - id: mypy entry: mypy ./weaviate pass_filenames: false + additional_dependencies: [ + 'Authlib==1.2.1', + 'certifi==2023.7.22', + 'cffi==1.15.1', + 'charset-normalizer==3.2.0', + 'cryptography==41.0.3', + 'grpcio==1.57.0', + 'grpcio-tools==1.57.0', + 'idna==3.4', + 'mypy-extensions==1.0.0', + 'protobuf==4.24.2', + 'pycparser==2.21', + 
'requests==2.31.0', + 'tomli==2.0.1', + 'types-requests==2.31.0.2', + 'types-urllib3==1.26.25.14', + 'typing_extensions==4.7.1', + 'urllib3==2.0.4', + 'validators==0.22.0', + ] \ No newline at end of file From e51ae9ed14ce1da262c464e657b7e9adb01634c2 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 10:39:01 +0100 Subject: [PATCH 08/18] add type-checking job to actions --- .github/workflows/main.yaml | 48 +++++++++++++++++++++++++------------ requirements-mypy.txt | 19 +++++++++++++++ 2 files changed, 52 insertions(+), 15 deletions(-) create mode 100644 requirements-mypy.txt diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index d737f715f..c1ad11e00 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -14,7 +14,7 @@ on: pull_request: jobs: - lint-and_format: + lint-and-format: name: Run linter and formatter runs-on: ubuntu-latest steps: @@ -34,6 +34,24 @@ jobs: python -m build python -m twine check dist/* + type-checking: + name: Run mypy static type checking + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["3.8", "3.9", "3.10", "3.11"] + folder: ["weaviate"] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.version }} + cache: 'pip' # caching pip dependencies + - run: pip install -r requirements-mypy.txt + - name: Run mypy + run: mypy ${{ matrix.folder }} + unit-tests: name: Run Unit Tests runs-on: ubuntu-latest @@ -43,20 +61,20 @@ jobs: version: ["3.8", "3.9", "3.10", "3.11"] folder: ["test", "mock_tests"] steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.version }} - cache: 'pip' # caching pip dependencies - - run: pip install -r requirements-devel.txt - - name: Run unittests - run: pytest --cov -v --cov-report=term-missing --cov=weaviate --cov-report xml:coverage-${{ matrix.folder }}.xml ${{ matrix.folder }} - - name: Archive code coverage results - if: 
matrix.version == '3.10' && (github.ref_name != 'main') - uses: actions/upload-artifact@v3 - with: - name: coverage-report-${{ matrix.folder }} - path: coverage-${{ matrix.folder }}.xml + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.version }} + cache: 'pip' # caching pip dependencies + - run: pip install -r requirements-devel.txt + - name: Run unittests + run: pytest --cov -v --cov-report=term-missing --cov=weaviate --cov-report xml:coverage-${{ matrix.folder }}.xml ${{ matrix.folder }} + - name: Archive code coverage results + if: matrix.version == '3.10' && (github.ref_name != 'main') + uses: actions/upload-artifact@v3 + with: + name: coverage-report-${{ matrix.folder }} + path: coverage-${{ matrix.folder }}.xml integration-tests: name: Run Integration Tests diff --git a/requirements-mypy.txt b/requirements-mypy.txt new file mode 100644 index 000000000..237c4af75 --- /dev/null +++ b/requirements-mypy.txt @@ -0,0 +1,19 @@ +Authlib==1.2.1 +certifi==2023.7.22 +cffi==1.15.1 +charset-normalizer==3.2.0 +cryptography==41.0.3 +grpcio==1.57.0 +grpcio-tools==1.57.0 +idna==3.4 +mypy==1.5.1 +mypy-extensions==1.0.0 +protobuf==4.24.2 +pycparser==2.21 +requests==2.31.0 +tomli==2.0.1 +types-requests==2.31.0.2 +types-urllib3==1.26.25.14 +typing_extensions==4.7.1 +urllib3==2.0.4 +validators==0.22.0 From ba4a798f4469def173df41373e4595a3ba065903 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 10:39:59 +0100 Subject: [PATCH 09/18] fix syntax error --- .github/workflows/main.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index c1ad11e00..c61bba8eb 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -203,7 +203,7 @@ jobs: build-and-publish: name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI - needs: [integration-tests, unit-tests, lint-and_format, test-package] + needs: [integration-tests, 
unit-tests, lint-and-format, test-package] runs-on: ubuntu-latest steps: - name: Checkout From 3aa731e4b92b8b37e4eb306877cd1a22400ce584 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 10:43:00 +0100 Subject: [PATCH 10/18] change names of jobs --- .github/workflows/main.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index c61bba8eb..c3973cbb3 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -15,7 +15,7 @@ on: jobs: lint-and-format: - name: Run linter and formatter + name: Run Linter and Formatter runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -35,7 +35,7 @@ jobs: python -m twine check dist/* type-checking: - name: Run mypy static type checking + name: Run Type Checking runs-on: ubuntu-latest strategy: fail-fast: false From 92332b872e075f1fe0f4612a70cc6b7e165be4bc Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 10:44:17 +0100 Subject: [PATCH 11/18] fix breaking tests --- weaviate/batch/crud_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaviate/batch/crud_batch.py b/weaviate/batch/crud_batch.py index a2da595cd..ea4ae5b04 100644 --- a/weaviate/batch/crud_batch.py +++ b/weaviate/batch/crud_batch.py @@ -1567,7 +1567,7 @@ def consistency_level(self) -> Union[str, None]: @consistency_level.setter def consistency_level(self, x: Optional[Union[ConsistencyLevel, str]]) -> None: - self._consistency_level = ConsistencyLevel(x) + self._consistency_level = ConsistencyLevel(x) if x is not None else None @property def recommended_num_objects(self) -> Optional[int]: From f32a847c96a102a60fcd2b37a1300d39d984fceb Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 11:05:23 +0100 Subject: [PATCH 12/18] change type not implementation --- weaviate/connect/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/weaviate/connect/connection.py 
b/weaviate/connect/connection.py index b3566d6f9..01e37c083 100644 --- a/weaviate/connect/connection.py +++ b/weaviate/connect/connection.py @@ -620,13 +620,13 @@ def timeout_config(self, timeout_config: TIMEOUT_TYPE_RETURN) -> None: def proxies(self) -> dict: return self._proxies - def wait_for_weaviate(self, startup_period: Optional[int]) -> None: + def wait_for_weaviate(self, startup_period: int) -> None: """ Waits until weaviate is ready or the timelimit given in 'startup_period' has passed. Parameters ---------- - startup_period : Optional[int] + startup_period : int Describes how long the client will wait for weaviate to start in seconds. Raises @@ -636,7 +636,7 @@ def wait_for_weaviate(self, startup_period: Optional[int]) -> None: """ ready_url = self.url + self._api_version_path + "/.well-known/ready" - for _i in range(startup_period or 30): + for _i in range(startup_period): try: requests.get(ready_url, headers=self._get_request_header()).raise_for_status() return From 937e550c508b3e437eecd5a5375dad41606f1f32 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 11:06:18 +0100 Subject: [PATCH 13/18] fix docstring --- weaviate/gql/filter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/weaviate/gql/filter.py b/weaviate/gql/filter.py index 6d4f78ca7..fc3755ce1 100644 --- a/weaviate/gql/filter.py +++ b/weaviate/gql/filter.py @@ -1123,9 +1123,9 @@ def _check_type(var_name: str, value: Any, dtype: Union[Tuple[type, type], type] ---------- var_name : str The variable name for which to check the type (used for error message)! - value : T + value : Any The value for which to check the type. - dtype : T + dtype : Union[Tuple[type, type], type] The expected data type of the `value`. 
Raises From dae0dd96ec59ad9228c4973526a4f1db66eba2f1 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 11:16:50 +0100 Subject: [PATCH 14/18] merge requirements files --- requirements-devel.txt | 7 +++++++ requirements-mypy.txt | 19 ------------------- 2 files changed, 7 insertions(+), 19 deletions(-) delete mode 100644 requirements-mypy.txt diff --git a/requirements-devel.txt b/requirements-devel.txt index 594d957ff..48c21b058 100644 --- a/requirements-devel.txt +++ b/requirements-devel.txt @@ -20,6 +20,13 @@ coverage==7.3.0 werkzeug>=2.3.7 pytest-httpserver>=1.0.8 +mypy>=1.5.1,<2.0.0 +mypy-extensions==1.0.0 +tomli>=2.0.1,<3.0.0 +types-requests>=2.31.0.2,<3.0.0 +types-urllib3>=1.26.25.14,<2.0.0 +typing_extensions>=4.7.1,<5.0.0 + pre-commit flake8 diff --git a/requirements-mypy.txt b/requirements-mypy.txt deleted file mode 100644 index 237c4af75..000000000 --- a/requirements-mypy.txt +++ /dev/null @@ -1,19 +0,0 @@ -Authlib==1.2.1 -certifi==2023.7.22 -cffi==1.15.1 -charset-normalizer==3.2.0 -cryptography==41.0.3 -grpcio==1.57.0 -grpcio-tools==1.57.0 -idna==3.4 -mypy==1.5.1 -mypy-extensions==1.0.0 -protobuf==4.24.2 -pycparser==2.21 -requests==2.31.0 -tomli==2.0.1 -types-requests==2.31.0.2 -types-urllib3==1.26.25.14 -typing_extensions==4.7.1 -urllib3==2.0.4 -validators==0.22.0 From 277d1cd7cb90fd568df08dde73743c4fde49f80c Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 11:17:24 +0100 Subject: [PATCH 15/18] correctly test mypy with differing py vers --- .github/workflows/main.yaml | 6 +++--- weaviate/batch/crud_batch.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index c3973cbb3..a0d27b52b 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -46,11 +46,11 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: ${{ matrix.version }} + python-version: "3.11" cache: 'pip' # caching pip dependencies - -
run: pip install -r requirements-mypy.txt + - run: pip install -r requirements-devel.txt - name: Run mypy - run: mypy ${{ matrix.folder }} + run: mypy --python-version ${{matrix.version}} ${{ matrix.folder }} unit-tests: name: Run Unit Tests diff --git a/weaviate/batch/crud_batch.py b/weaviate/batch/crud_batch.py index ea4ae5b04..4d76cba44 100644 --- a/weaviate/batch/crud_batch.py +++ b/weaviate/batch/crud_batch.py @@ -272,7 +272,7 @@ def __init__(self, connection: Connection): # do not keep too many past values, so it is a better estimation of the throughput is computed for 1 second self._objects_throughput_frame: Deque[float] = deque(maxlen=5) self._references_throughput_frame: Deque[float] = deque(maxlen=5) - self._future_pool: List[Future[Tuple[Response | None, int]]] = [] + self._future_pool: List[Future[Tuple[Union[Response, None], int]]] = [] self._reference_batch_queue: List[ReferenceBatchRequest] = [] self._callback_lock = threading.Lock() From ced06f8e62de691821cde5d9de617b39441c6ab2 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 14:03:46 +0100 Subject: [PATCH 16/18] only specify stubs in pre-commit mypy config --- .pre-commit-config.yaml | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 788422d29..1a72aa038 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,22 +28,9 @@ repos: entry: mypy ./weaviate pass_filenames: false additional_dependencies: [ - 'Authlib==1.2.1', - 'certifi==2023.7.22', - 'cffi==1.15.1', - 'charset-normalizer==3.2.0', - 'cryptography==41.0.3', - 'grpcio==1.57.0', - 'grpcio-tools==1.57.0', - 'idna==3.4', 'mypy-extensions==1.0.0', - 'protobuf==4.24.2', - 'pycparser==2.21', - 'requests==2.31.0', 'tomli==2.0.1', 'types-requests==2.31.0.2', 'types-urllib3==1.26.25.14', 'typing_extensions==4.7.1', - 'urllib3==2.0.4', - 'validators==0.22.0', ] \ No newline at end of file From 0ba525b7a50abd9b11535276b25ea686ded5a71c Mon Sep 17 
00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 15:05:25 +0100 Subject: [PATCH 17/18] remove more redundant stubs --- .pre-commit-config.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a72aa038..27d8b5cf0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,8 +28,6 @@ repos: entry: mypy ./weaviate pass_filenames: false additional_dependencies: [ - 'mypy-extensions==1.0.0', - 'tomli==2.0.1', 'types-requests==2.31.0.2', 'types-urllib3==1.26.25.14', 'typing_extensions==4.7.1', From 41533c88894224c0d3cc6c91ee603f1f48a92718 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Wed, 6 Sep 2023 15:52:51 +0100 Subject: [PATCH 18/18] update job needs --- .github/workflows/main.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index a0d27b52b..89d48b71f 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -203,7 +203,7 @@ jobs: build-and-publish: name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI - needs: [integration-tests, unit-tests, lint-and-format, test-package] + needs: [integration-tests, unit-tests, lint-and-format, type-checking, test-package] runs-on: ubuntu-latest steps: - name: Checkout