diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 1e8081a2a..12b7e8f0f 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -114,65 +114,6 @@ jobs: name: coverage-report-integration-embedded path: coverage-integration-embedded.xml - integration-tests-v3: - name: Run Integration Tests v3 - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - versions: [ - { py: "3.9", weaviate: $WEAVIATE_126}, - { py: "3.10", weaviate: $WEAVIATE_126}, - { py: "3.11", weaviate: $WEAVIATE_123}, - { py: "3.11", weaviate: $WEAVIATE_124}, - { py: "3.11", weaviate: $WEAVIATE_125}, - { py: "3.11", weaviate: $WEAVIATE_126}, - { py: "3.12", weaviate: $WEAVIATE_127} - ] - optional_dependencies: [false] - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - fetch-tags: true - - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.versions.py }} - cache: 'pip' # caching pip dependencies - - name: Login to Docker Hub - uses: docker/login-action@v3 - if: ${{ !github.event.pull_request.head.repo.fork && github.triggering_actor != 'dependabot[bot]' }} - with: - username: ${{secrets.DOCKER_USERNAME}} - password: ${{secrets.DOCKER_PASSWORD}} - - run: | - pip install -r requirements-devel.txt - pip install . - - name: free space - run: sudo rm -rf /usr/local/lib/android - - name: start weaviate - run: /bin/bash ci/start_weaviate.sh ${{ matrix.versions.weaviate }} - - name: Run integration tests with auth secrets - if: ${{ !github.event.pull_request.head.repo.fork }} - env: - AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }} - OKTA_CLIENT_SECRET: ${{ secrets.OKTA_CLIENT_SECRET }} - WCS_DUMMY_CI_PW: ${{ secrets.WCS_DUMMY_CI_PW }} - OKTA_DUMMY_CI_PW: ${{ secrets.OKTA_DUMMY_CI_PW }} -# OPENAI_APIKEY: ${{ secrets.OPENAI_APIKEY }} disabled until we have a working key - run: pytest -v --cov --cov-report=term-missing --cov=weaviate --cov-report xml:coverage-integration-v3.xml integration_v3 - - name: Run integration tests without auth secrets (for forks) - if: ${{ github.event.pull_request.head.repo.fork }} - run: pytest -v --cov --cov-report=term-missing --cov=weaviate --cov-report xml:coverage-integration-v3.xml integration_v3 - - name: Archive code coverage results - if: matrix.versions.py == '3.10' && (github.ref_name != 'main') - uses: actions/upload-artifact@v4 - with: - name: coverage-report-integration-v3 - path: coverage-integration-v3.xml - - name: stop weaviate - run: /bin/bash ci/stop_weaviate.sh ${{ matrix.versions.weaviate }} - integration-tests: name: Run Integration Tests runs-on: ubuntu-latest @@ -219,7 +160,7 @@ jobs: WCS_DUMMY_CI_PW: ${{ secrets.WCS_DUMMY_CI_PW }} OKTA_DUMMY_CI_PW: ${{ secrets.OKTA_DUMMY_CI_PW }} # OPENAI_APIKEY: ${{ secrets.OPENAI_APIKEY }} disabled until we have a working key - run: pytest -n auto -v --cov --cov-report=term-missing --cov=weaviate --cov-report xml:coverage-integration.xml integration + run: pytest -n auto --dist loadgroup -v --cov --cov-report=term-missing --cov=weaviate --cov-report xml:coverage-integration.xml integration - name: Run integration tests without auth secrets (for forks) if: ${{ github.event.pull_request.head.repo.fork }} run: pytest -v --cov --cov-report=term-missing --cov=weaviate --cov-report xml:coverage-integration.xml integration @@ -269,7 +210,7 @@ jobs: run: /bin/bash ci/stop_weaviate.sh ${{ matrix.versions.weaviate }} Codecov: - needs: [Unit-Tests, Integration-Tests, Integration-Tests-v3] + needs: [Unit-Tests, Integration-Tests] runs-on: ubuntu-latest if: github.ref_name != 'main' steps: @@ -370,7 +311,7 @@ jobs: OKTA_CLIENT_SECRET: ${{ secrets.OKTA_CLIENT_SECRET }} WCS_DUMMY_CI_PW: ${{ secrets.WCS_DUMMY_CI_PW }} OKTA_DUMMY_CI_PW: ${{ secrets.OKTA_DUMMY_CI_PW }} - run: pytest -v -n auto integration + run: pytest -v -n auto --dist loadgroup integration - name: Run integration tests without auth secrets (for forks) if: ${{ github.event.pull_request.head.repo.fork }} run: pytest -v -n auto integration diff --git a/integration_v3/test_backup_v4.py b/integration/test_backup_v4.py similarity index 99% rename from integration_v3/test_backup_v4.py rename to integration/test_backup_v4.py index bf7c3785d..f2338b2bf 100644 --- a/integration_v3/test_backup_v4.py +++ b/integration/test_backup_v4.py @@ -19,6 +19,8 @@ BackupFailedException, ) +pytestmark = pytest.mark.xdist_group(name="backup") + BACKEND = BackupStorage.FILESYSTEM PARAGRAPHS_IDS = [ diff --git a/integration_v3/__init__.py b/integration_v3/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/integration_v3/people_schema.json b/integration_v3/people_schema.json deleted file mode 100644 index 617d322cd..000000000 --- a/integration_v3/people_schema.json +++ /dev/null @@ -1,133 +0,0 @@ -{ - "classes": [ - { - "class": "Call", - "description": "A call between two Persons", - "properties": [ - { - "name": "start", - "description": "The start of the call", - "dataType": [ - "date" - ] - }, - { - "name": "caller", - "description": "The person who calls", - "dataType": [ - "Person" - ] - }, - { - "name": "recipient", - "description": "The person who was called", - "dataType": [ - "Person" - ] - } - ] - }, - { - "class": "Exercise", - "description": "A sport exercise", - "properties": [ - { - "name": "name", - "description": "The name of the exercise", - "dataType": [ - "string" - ] - } - ] - }, - { - "class": "Person", - "description": "A person such as humans or personality known through culture", - "properties": [ - { - "name": "name", - "description": "The name of this person", - "dataType": [ - "string" - ] - } - ] - }, - { - "class": "Group", - "description": "A set of persons who are associated with each other over some common properties", - "properties": [ - { - "name": "name", - "description": "The name under which this group is known", - "dataType": [ - "text" - ], - "indexInverted": true, - "moduleConfig": { - "text2vec-contextionary": { - "skip": false, - "vectorizePropertyName": false - } - } - }, - { - "name": "members", - "description": "The persons that are part of this group", - "dataType": [ - "Person" - ] - } - ] - }, - { - "class": "Association", - "description": "A legal body that consists of a set of juristic persons. All juristic persons in a group are deliberate members.", - "moduleConfig": - { - "text2vec-contextionary": - { - "vectorizeClassName": false - } - }, - "properties": [ - { - "name": "name", - "description": "The name of the association", - "dataType": [ - "text" - ] - }, - { - "name": "members", - "description": "The juristic persons that are part of this association", - "dataType": [ - "Person", "Association" - ], - "moduleConfig": - { - "text2vec-contextionary": - { - "vectorizePropertyName": true - } - } - }, - { - "name": "registrationId", - "description": "The id under which the association is legally registered", - "dataType": [ - "string" - ], - "moduleConfig": - { - "text2vec-contextionary": - { - "vectorizePropertyName": false - } - }, - "indexInverted": false - } - ] - } - ] -} diff --git a/integration_v3/test_authentication.py b/integration_v3/test_authentication.py deleted file mode 100644 index 7ac155a14..000000000 --- a/integration_v3/test_authentication.py +++ /dev/null @@ -1,268 +0,0 @@ -import os -import time -import warnings -from typing import Optional, Dict - -import pytest -import requests -from requests.exceptions import ConnectionError as RequestsConnectionError -from requests.exceptions import HTTPError as RequestsHTTPError - -import weaviate -from weaviate import ( - AuthenticationFailedException, - AuthClientCredentials, - AuthClientPassword, - AuthBearerToken, -) -from weaviate.auth import AuthApiKey -from weaviate.exceptions import UnexpectedStatusCodeException - -ANON_PORT = 8080 -AZURE_PORT = 8081 -OKTA_PORT_CC = 8082 -OKTA_PORT_USERS = 8083 -WCS_PORT = 8085 - - -def wait_for_weaviate(url: str) -> None: - ready_url = url + "/v1/.well-known/ready" - for _i in range(5): - try: - requests.get(ready_url).raise_for_status() - return - except (RequestsHTTPError, RequestsConnectionError): - time.sleep(1) - - try: - requests.get(ready_url).raise_for_status() - return - except (RequestsHTTPError, RequestsConnectionError): - pass - raise ValueError - - -def is_auth_enabled(url: str): - wait_for_weaviate(url) - response = requests.get(url + "/v1/.well-known/openid-configuration") - return response.status_code == 200 - - -def test_no_auth_provided(): - """Test exception when trying to access a weaviate that requires authentication.""" - url = f"http://localhost:{AZURE_PORT}" - assert is_auth_enabled(url) - with pytest.raises(AuthenticationFailedException): - weaviate.Client(url) - - -@pytest.mark.parametrize( - "name,env_variable_name,port,scope", - [ - # ("azure", "AZURE_CLIENT_SECRET", AZURE_PORT, None), expired - # ( - # "azure", - # "AZURE_CLIENT_SECRET", - # AZURE_PORT, - # "4706508f-30c2-469b-8b12-ad272b3de864/.default", - # ), - ("okta", "OKTA_CLIENT_SECRET", OKTA_PORT_CC, "some_scope"), - ], -) -def test_authentication_client_credentials( - name: str, env_variable_name: str, port: str, scope: Optional[str] -): - """Test client credential flow with various providers.""" - client_secret = os.environ.get(env_variable_name) - if client_secret is None: - pytest.skip(f"No {name} login data found.") - - url = f"http://localhost:{port}" - assert is_auth_enabled(url) - client = weaviate.Client( - url, - auth_client_secret=AuthClientCredentials(client_secret=client_secret, scope=scope), - ) - client.schema.delete_all() # no exception - - -@pytest.mark.parametrize( - "name,user,env_variable_name,port,scope,warning", - [ - # ( # WCS keycloak times out too often - # "WCS", - # "ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net", - # "WCS_DUMMY_CI_PW", - # WCS_PORT, - # None, - # False, - # ), - ( - "okta", - "test@test.de", - "OKTA_DUMMY_CI_PW", - OKTA_PORT_USERS, - "some_scope offline_access", - False, - ), - ( - "okta - no refresh", - "test@test.de", - "OKTA_DUMMY_CI_PW", - OKTA_PORT_USERS, - "some_scope", - True, - ), - ], -) -def test_authentication_user_pw( - recwarn, name: str, user: str, env_variable_name: str, port: str, scope: str, warning: bool -): - """Test authentication using Resource Owner Password Credentials Grant (User + PW).""" - # testing for warnings can be flaky without this as there are open SSL conections - warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning) - warnings.filterwarnings(action="ignore", message="Dep005", category=DeprecationWarning) - - url = f"http://localhost:{port}" - assert is_auth_enabled(url) - - pw = os.environ.get(env_variable_name) - if pw is None: - pytest.skip(f"No login data for {name} found.") - - if scope is not None: - auth = AuthClientPassword(username=user, password=pw, scope=scope) - else: - auth = AuthClientPassword(username=user, password=pw) - - client = weaviate.Client(url, auth_client_secret=auth) - client.schema.delete_all() # no exception - if warning: - assert any([str(w.message).startswith("Auth002") for w in recwarn]) - else: - assert not any([str(w.message).startswith("Auth002") for w in recwarn]) - - -def _get_access_token(url: str, user: str, pw: str) -> Dict[str, str]: - # get an access token with WCS user and pw. - weaviate_open_id_config = requests.get(url + "/v1/.well-known/openid-configuration") - response_json = weaviate_open_id_config.json() - client_id = response_json["clientId"] - href = response_json["href"] - - # Get the token issuer's OIDC configuration - response_auth = requests.get(href) - - # Construct the POST request to send to 'token_endpoint' - auth_body = { - "grant_type": "password", - "client_id": client_id, - "username": user, - "password": pw, - "scope": "openid offline_access", - } - response_post = requests.post(response_auth.json()["token_endpoint"], auth_body) - return response_post.json() - - -@pytest.mark.parametrize( - "name,user,env_variable_name,port", - [ - # ( - # "WCS", - # "ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net", - # "WCS_DUMMY_CI_PW", - # WCS_PORT, - # ), - ( - "okta", - "test@test.de", - "OKTA_DUMMY_CI_PW", - OKTA_PORT_USERS, - ), - ], -) -def test_authentication_with_bearer_token(name: str, user: str, env_variable_name: str, port: str): - """Test authentication using existing bearer token.""" - url = f"http://localhost:{port}" - assert is_auth_enabled(url) - pw = os.environ.get(env_variable_name) - if pw is None: - pytest.skip(f"No login data for {name} found.") - - # use token to authenticate - token = _get_access_token(url, user, pw) - - client = weaviate.Client( - url, - auth_client_secret=AuthBearerToken( - access_token=token["access_token"], - expires_in=int(token["expires_in"]), - refresh_token=token["refresh_token"], - ), - ) - client.schema.delete_all() # no exception - - -def test_client_with_authentication_with_anon_weaviate(recwarn): - """Test that we warn users when their client has auth enabled, but weaviate has only anon access.""" - # testing for warnings can be flaky without this as there are open SSL conections - warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning) - warnings.filterwarnings(action="ignore", message="Dep005", category=DeprecationWarning) - - url = f"http://localhost:{ANON_PORT}" - assert not is_auth_enabled(url) - - client = weaviate.Client( - url, - auth_client_secret=AuthClientPassword(username="someUser", password="SomePw"), - ) - - # only one warning - assert any([str(w.message).startswith("Auth001") for w in recwarn]) - - client.schema.delete_all() # no exception, client works - - -def test_bearer_token_without_refresh(recwarn): - """Test that the client warns users when only supplying an access token without refresh.""" - - pytest.skip("WCS keycloak times out too often") - - # testing for warnings can be flaky without this as there are open SSL conections - warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning) - warnings.filterwarnings(action="ignore", message="Dep005", category=DeprecationWarning) - - url = f"http://localhost:{WCS_PORT}" - assert is_auth_enabled(url) - pw = os.environ.get("WCS_DUMMY_CI_PW") - if pw is None: - pytest.skip("No login data for WCS found.") - - token = _get_access_token(url, "ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net", pw) - client = weaviate.Client( - url, - auth_client_secret=AuthBearerToken( - access_token=token["access_token"], - ), - ) - client.schema.delete_all() # no exception, client works - - assert any([str(w.message).startswith("Auth002") for w in recwarn]) - - -def test_api_key(): - url = f"http://localhost:{WCS_PORT}" - assert is_auth_enabled(url) - - client = weaviate.Client(url, auth_client_secret=AuthApiKey(api_key="my-secret-key")) - client.schema.delete_all() # no exception, client works - - -def test_api_key_wrong_key(): - url = f"http://localhost:{WCS_PORT}" - assert is_auth_enabled(url) - - with pytest.raises(UnexpectedStatusCodeException) as e: - weaviate.Client(url, auth_client_secret=AuthApiKey(api_key="wrong_key")) - assert e.value.status_code == 401 diff --git a/integration_v3/test_backup.py b/integration_v3/test_backup.py deleted file mode 100644 index 1c6aa1501..000000000 --- a/integration_v3/test_backup.py +++ /dev/null @@ -1,336 +0,0 @@ -import datetime -import time -from typing import Dict, Any, List - -import pytest - -import weaviate -from weaviate.exceptions import UnexpectedStatusCodeException, BackupFailedException - -BACKUP_FILESYSTEM_PATH = "/tmp/backups" # must be the same location as in the docker-compose file -BACKEND = "filesystem" - -schema = { - "classes": [ - { - "class": "Paragraph", - "properties": [ - {"dataType": ["text"], "name": "contents"}, - {"dataType": ["Paragraph"], "name": "hasParagraphs"}, - ], - }, - { - "class": "Article", - "properties": [ - {"dataType": ["string"], "name": "title"}, - {"dataType": ["Paragraph"], "name": "hasParagraphs"}, - {"dataType": ["date"], "name": "datePublished"}, - ], - }, - ] -} - -paragraphs = [ - {"id": "fd34ccf4-1a2a-47ad-8446-231839366c3f", "properties": {"contents": "paragraph 1"}}, - {"id": "2653442b-05d8-4fa3-b46a-d4a152eb63bc", "properties": {"contents": "paragraph 2"}}, - {"id": "55374edb-17de-487f-86cb-9a9fbc30823f", "properties": {"contents": "paragraph 3"}}, - {"id": "124ff6aa-597f-44d0-8c13-62fbb1e66888", "properties": {"contents": "paragraph 4"}}, - {"id": "f787386e-7d1c-481f-b8c3-3dbfd8bbad85", "properties": {"contents": "paragraph 5"}}, -] - -articles = [ - { - "id": "2fd68cbc-21ff-4e19-aaef-62531dade974", - "properties": { - "title": "article a", - "datePublished": datetime.datetime.now(datetime.timezone.utc).isoformat(), - }, - }, - { - "id": "7ea3f7b8-65fd-4318-a842-ae9ba38ffdca", - "properties": { - "title": "article b", - "datePublished": datetime.datetime.now(datetime.timezone.utc).isoformat(), - }, - }, - { - "id": "769a4280-4b85-4e67-b685-07796c49a764", - "properties": { - "title": "article c", - "datePublished": datetime.datetime.now(datetime.timezone.utc).isoformat(), - }, - }, - { - "id": "97fcc234-fd16-4a40-82bb-d614e9bddf8b", - "properties": { - "title": "article d", - "datePublished": datetime.datetime.now(datetime.timezone.utc).isoformat(), - }, - }, - { - "id": "3fa435d3-6ab2-489d-abed-c25ec526c9f4", - "properties": { - "title": "article e", - "datePublished": datetime.datetime.now(datetime.timezone.utc).isoformat(), - }, - }, -] - - -@pytest.fixture(scope="module") -def client(): - client = weaviate.Client("http://localhost:8080") - client.schema.create(schema) - for para in paragraphs: - client.data_object.create(para["properties"], "Paragraph", para["id"]) - for i, art in enumerate(articles): - client.data_object.create(art["properties"], "Article", art["id"]) - client.data_object.reference.add( - from_uuid=art["id"], - from_class_name="Article", - from_property_name="hasParagraphs", - to_uuid=paragraphs[i]["id"], - to_class_name="Paragraph", - ) - yield client - client.schema.delete_all() - - -def _assert_objects_exist(local_client: weaviate.Client, class_name: str, expected_count: int): - result = local_client.query.aggregate(class_name).with_meta_count().do() - count = result["data"]["Aggregate"][class_name][0]["meta"]["count"] - assert ( - expected_count == count - ), f"{class_name}: expected count: {expected_count}, received: {count}" - - -def _create_backup_id() -> str: - return str(round(time.time() * 1000)) - - -def _check_response( - response: Dict[str, Any], backup_id: str, status: List[str], classes_include: List[str] = None -) -> None: - assert response["id"] == backup_id - if classes_include is not None: - assert len(response["classes"]) == len(classes_include) - assert sorted(response["classes"]) == sorted(classes_include) - assert response["backend"] == "filesystem" - assert response["path"] == f"{BACKUP_FILESYSTEM_PATH}/{backup_id}" - assert response["status"] in status - - -def test_create_and_restore_backup_with_waiting(client, tmp_path) -> None: - """Create and restore backup with waiting.""" - backup_id = _create_backup_id() - # check data exists - _assert_objects_exist(client, "Article", len(articles)) - _assert_objects_exist(client, "Paragraph", len(paragraphs)) - - # create backup - classes = ["Article", "Paragraph"] - resp = client.backup.create(backup_id=backup_id, backend=BACKEND, wait_for_completion=True) - _check_response(resp, backup_id, ["SUCCESS"], classes) - - # check data still exists - _assert_objects_exist(client, "Article", len(articles)) - _assert_objects_exist(client, "Paragraph", len(paragraphs)) - - # check create status - resp = client.backup.get_create_status(backup_id, BACKEND) - _check_response(resp, backup_id, ["SUCCESS"]) - - # remove existing class - client.schema.delete_class("Article") - client.schema.delete_class("Paragraph") - # restore backup - resp = client.backup.restore(backup_id=backup_id, backend=BACKEND, wait_for_completion=True) - _check_response(resp, backup_id, ["SUCCESS"], classes) - - # check data exists again - _assert_objects_exist(client, "Article", len(articles)) - _assert_objects_exist(client, "Paragraph", len(paragraphs)) - # check restore status - resp = client.backup.get_restore_status(backup_id, BACKEND) - _check_response(resp, backup_id, ["SUCCESS"]) - - -def test_create_and_restore_backup_without_waiting(client: weaviate.Client) -> None: - """Create and restore backup without waiting.""" - backup_id = _create_backup_id() - - # check data exists - _assert_objects_exist(client, "Article", len(articles)) - # create backup - include = ["Article"] - - resp = client.backup.create(backup_id=backup_id, include_classes=include, backend=BACKEND) - _check_response(resp, backup_id, ["STARTED"], include) - - # wait until created - while True: - resp = client.backup.get_create_status(backup_id, BACKEND) - _check_response(resp, backup_id, ["STARTED", "TRANSFERRING", "TRANSFERRED", "SUCCESS"]) - if resp["status"] == "SUCCESS": - break - time.sleep(0.1) - # check data still exists - _assert_objects_exist(client, "Article", len(articles)) - # remove existing class - client.schema.delete_class("Article") - # restore backup - resp = client.backup.restore( - backup_id=backup_id, - include_classes=include, - backend=BACKEND, - ) - _check_response(resp, backup_id, ["STARTED"], include) - # wait until restored - while True: - resp = client.backup.get_restore_status(backup_id, BACKEND) - _check_response(resp, backup_id, ["STARTED", "TRANSFERRING", "TRANSFERRED", "SUCCESS"]) - if resp["status"] == "SUCCESS": - break - time.sleep(0.1) - # check data exists again - _assert_objects_exist(client, "Article", len(articles)) - - -def test_create_and_restore_1_of_2_classes(client: weaviate.Client) -> None: - """Create and restore 1 of 2 classes.""" - backup_id = _create_backup_id() - - # check data exists - _assert_objects_exist(client, "Article", len(articles)) - - # create backup - include = ["Article"] - resp = client.backup.create( - backup_id=backup_id, include_classes=include, backend=BACKEND, wait_for_completion=True - ) - _check_response(resp, backup_id, ["SUCCESS"], include) - - # check data still exists - _assert_objects_exist(client, "Article", len(articles)) - # check create status - resp = client.backup.get_create_status(backup_id, BACKEND) - _check_response(resp, backup_id, ["SUCCESS"]) - - # remove existing class - client.schema.delete_class("Article") - # restore backup - resp = client.backup.restore( - backup_id=backup_id, include_classes=include, backend=BACKEND, wait_for_completion=True - ) - _check_response(resp, backup_id, ["SUCCESS"], include) - - # check data exists again - _assert_objects_exist(client, "Article", len(articles)) - # check restore status - resp = client.backup.get_restore_status(backup_id, BACKEND) - _check_response(resp, backup_id, ["SUCCESS"]) - - -def test_fail_on_non_existing_backend(client: weaviate.Client) -> None: - """Fail backup functions on non-existing backend""" - backup_id = _create_backup_id() - backend = "non-existing-backend" - for func in [client.backup.create, client.backup.get_create_status, client.backup.restore]: - with pytest.raises(ValueError) as excinfo: - func(backup_id=backup_id, backend=backend) - assert backend in str(excinfo.value) - - -def test_fail_on_non_existing_class(client: weaviate.Client) -> None: - """Fail backup functions on non-existing class""" - backup_id = _create_backup_id() - class_name = "NonExistingClass" - for func in [client.backup.create, client.backup.restore]: - with pytest.raises(UnexpectedStatusCodeException) as excinfo: - func(backup_id=backup_id, backend=BACKEND, include_classes=class_name) - assert class_name in str(excinfo.value) - assert "422" in str(excinfo.value) - - -def test_fail_restoring_backup_for_existing_class(client: weaviate.Client): - """Fail restoring backup for existing class.""" - backup_id = _create_backup_id() - class_name = ["Article"] - resp = client.backup.create( - backup_id=backup_id, include_classes=class_name, backend=BACKEND, wait_for_completion=True - ) - _check_response(resp, backup_id, ["SUCCESS"], class_name) - - # fail restore - with pytest.raises(BackupFailedException) as excinfo: - client.backup.restore( - backup_id=backup_id, - include_classes=class_name, - backend=BACKEND, - wait_for_completion=True, - ) - assert class_name[0] in str(excinfo.value) - assert "already exists" in str(excinfo.value) - - -def test_fail_creating_existing_backup(client: weaviate.Client): - """Fail creating existing backup""" - backup_id = _create_backup_id() - class_name = ["Article"] - resp = client.backup.create( - backup_id=backup_id, include_classes=class_name, backend=BACKEND, wait_for_completion=True - ) - _check_response(resp, backup_id, ["SUCCESS"], class_name) - - # fail create - with pytest.raises(UnexpectedStatusCodeException) as excinfo: - client.backup.create( - backup_id=backup_id, - include_classes=class_name, - backend=BACKEND, - wait_for_completion=True, - ) - assert backup_id in str(excinfo.value) - assert "422" in str(excinfo.value) - - -def test_fail_restoring_non_existing_backup(client: weaviate.Client): - """fail restoring non-existing backup""" - backup_id = _create_backup_id() - with pytest.raises(UnexpectedStatusCodeException) as excinfo: - client.backup.restore(backup_id=backup_id, backend=BACKEND, wait_for_completion=True) - assert backup_id in str(excinfo.value) - assert "404" in str(excinfo.value) - - -def test_fail_checking_status_for_non_existing_restore(client: weaviate.Client): - """Fail checking restore status for non-existing restore.""" - backup_id = _create_backup_id() - for func in [client.backup.get_restore_status, client.backup.get_create_status]: - with pytest.raises(UnexpectedStatusCodeException) as excinfo: - func( - backup_id=backup_id, - backend=BACKEND, - ) - assert backup_id in str(excinfo) - assert "404" in str(excinfo) - - -def test_fail_creating_backup_for_both_include_and_exclude_classes(client: weaviate.Client): - """fail creating backup for both include and exclude classes""" - backup_id = _create_backup_id() - - for func in [client.backup.create, client.backup.restore]: - with pytest.raises(TypeError) as excinfo: - include = "Article" - exclude = "Paragraph" - func( - backup_id=backup_id, - include_classes=include, - exclude_classes=exclude, - backend=BACKEND, - wait_for_completion=True, - ) - assert "Either 'include_classes' OR 'exclude_classes' can be set, not both" in str( - excinfo.value - ) diff --git a/integration_v3/test_batch.py b/integration_v3/test_batch.py deleted file mode 100644 index 58511c7d2..000000000 --- a/integration_v3/test_batch.py +++ /dev/null @@ -1,652 +0,0 @@ -import uuid -from dataclasses import dataclass -from typing import List, Union, Sequence, Optional - -import pytest - -import weaviate -from weaviate import Shard, Tenant -from weaviate.gql.filter import VALUE_ARRAY_TYPES, WHERE_OPERATORS - -UUID = Union[str, uuid.UUID] - - -def has_batch_errors(results: dict) -> bool: - """ - Check batch results for errors. - - Parameters - ---------- - results : dict - The Weaviate batch creation return value. - """ - - if results is not None: - for result in results: - if "result" in result and "errors" in result["result"]: - if "error" in result["result"]["errors"]: - return True - return False - - -@dataclass -class MockNumpyTorch: - array: list - - def squeeze(self) -> "MockNumpyTorch": - return self - - def tolist(self) -> list: - return self.array - - -@dataclass -class MockTensorFlow: - array: list - - def numpy(self) -> "MockNumpyTorch": - return MockNumpyTorch(self.array) - - -@pytest.fixture(scope="function") -def client(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create_class( - { - "class": "Test", - "properties": [ - {"name": "test", "dataType": ["Test"]}, - {"name": "name", "dataType": ["string"]}, - {"name": "names", "dataType": ["string[]"]}, - ], - "vectorizer": "none", - } - ) - yield client - client.schema.delete_all() - - -@pytest.mark.parametrize( - "vector", - [None, [1, 2, 3], MockNumpyTorch([1, 2, 3]), MockTensorFlow([1, 2, 3])], -) -@pytest.mark.parametrize("uuid", [None, uuid.uuid4(), str(uuid.uuid4()), uuid.uuid4().hex]) -def test_add_data_object(client: weaviate.Client, uuid: Optional[UUID], vector: Optional[Sequence]): - """Test the `add_data_object` method""" - client.batch.add_data_object( - data_object={}, - class_name="Test", - uuid=uuid, - vector=vector, - ) - response = client.batch.create_objects() - assert has_batch_errors(response) is False, str(response) - - -def test_add_data_object_and_get_class_shards_readiness(client: weaviate.Client): - """Test the `add_data_object` method""" - client.batch.add_data_object( - data_object={}, - class_name="Test", - ) - response = client.batch.create_objects() - assert has_batch_errors(response) is False, str(response) - statuses = client.batch._get_shards_readiness(Shard(class_name="Test")) - assert len(statuses) == 1 - assert statuses[0] - - -def test_add_data_object_with_tenant_and_get_class_shards_readiness(): - """Test the `add_data_object` method""" - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create_class( - { - "class": "Test", - "vectorizer": "none", - "multiTenancyConfig": { - "enabled": True, - }, - } - ) - client.schema.add_class_tenants("Test", [Tenant("tenant1"), Tenant("tenant2")]) - client.batch.add_data_object( - data_object={}, - class_name="Test", - tenant="tenant1", - ) - response = client.batch.create_objects() - assert has_batch_errors(response) is False, str(response) - statuses = client.batch._get_shards_readiness(Shard(class_name="Test", tenant="tenant1")) - assert len(statuses) == 1 - assert statuses[0] - - -@pytest.mark.parametrize( - "objs,where", - [ - ( - [ - {"name": "zero"}, - ], - { - "path": ["name"], - "operator": "NotEqual", - "valueText": "one", - }, - ), - ( - [ - {"name": "one"}, - ], - { - "path": ["name"], - "operator": "Equal", - "valueText": "one", - }, - ), - ( - [{"name": "two"}, {"name": "three"}], - { - "path": ["name"], - "operator": "ContainsAny", - "valueTextArray": ["two", "three"], - }, - ), - ( - [ - {"names": ["Tim", "Tom"], "name": "four"}, - ], - { - "path": ["names"], - "operator": "ContainsAll", - "valueTextArray": ["Tim", "Tom"], - }, - ), - ( - [ - {"names": ["Tim", "Tom"], "name": "five"}, - ], - { - "operator": "And", - "operands": [ - { - "path": ["names"], - "operator": "ContainsAll", - "valueTextArray": ["Tim", "Tom"], - }, - { - "path": ["name"], - "operator": "Equal", - "valueText": "five", - }, - ], - }, - ), - ( - [{"name": "six"}, {"name": "seven"}], - { - "operator": "Or", - "operands": [ - { - "path": ["name"], - "operator": "Equal", - "valueText": "six", - }, - { - "path": ["name"], - "operator": "Equal", - "valueText": "seven", - }, - ], - }, - ), - ( - [ - {"name": "eight"}, - ], - { - "path": ["name"], - "operator": "Like", - "valueText": "eig*", - }, - ), - ], -) -def test_delete_objects_successes(client: weaviate.Client, objs: List[dict], where: dict): - with client.batch as batch: - for obj in objs: - batch.add_data_object(data_object=obj, class_name="Test") - - with client.batch as batch: - batch.delete_objects( - "Test", - where=where, - ) - res = client.data_object.get() - names = [obj["properties"]["name"] for obj in res["objects"]] - for obj in objs: - assert obj.get("name") not in names - - -def test_delete_objects_errors(client: weaviate.Client): - with pytest.raises(ValueError) as error: - with client.batch as batch: - batch.delete_objects( - "test", - where={ - "path": ["name"], - "operator": "ContainsAny", - "valueText": ["four"], - }, - ) - assert ( - error.value.args[0] - == f"Operator 'ContainsAny' is not supported for value type 'valueText'. Supported value types are: {VALUE_ARRAY_TYPES}" - ) - - where = { - "path": ["name"], - "valueTextArray": ["four"], - } - with pytest.raises(ValueError) as error: - with client.batch as batch: - batch.delete_objects( - "Test", - where=where, - ) - assert ( - error.value.args[0] == f"Where filter is missing required field `operator`. Given: {where}" - ) - - with pytest.raises(ValueError) as error: - with client.batch as batch: - batch.delete_objects( - "Test", - where={ - "path": ["name"], - "operator": "Wrong", - "valueText": ["four"], - }, - ) - assert ( - error.value.args[0] - == f"Operator Wrong is not allowed. Allowed operators are: {WHERE_OPERATORS}" - ) - - -@pytest.mark.parametrize("from_object_uuid", [uuid.uuid4(), str(uuid.uuid4()), uuid.uuid4().hex]) -@pytest.mark.parametrize("to_object_uuid", [uuid.uuid4().hex, uuid.uuid4(), str(uuid.uuid4())]) -@pytest.mark.parametrize("to_object_class_name", [None, "Test"]) -def test_add_reference( - client: weaviate.Client, - from_object_uuid: UUID, - to_object_uuid: UUID, - to_object_class_name: Optional[str], -): - """Test the `add_reference` method""" - - # create the 2 objects first - client.data_object.create( - data_object={}, - class_name="Test", - uuid=from_object_uuid, - ) - client.data_object.create( - data_object={}, - class_name="Test", - uuid=to_object_uuid, - ) - - client.batch.add_reference( - from_object_uuid=from_object_uuid, - from_object_class_name="Test", - from_property_name="test", - to_object_uuid=to_object_uuid, - to_object_class_name=to_object_class_name, - ) - - response = client.batch.create_references() - assert has_batch_errors(response) is False, str(response) - - -def test_add_object_batch_with_tenant(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - - # create two classes and add 5 tenants each - class_names = ["BatchTestMultiTenant1", "BatchTestMultiTenant2"] - for name in class_names: - client.schema.create_class( - { - "class": name, - "vectorizer": "none", - "properties": [ - {"name": "tenantAsProp", "dataType": ["text"]}, - ], - "multiTenancyConfig": {"enabled": True}, - }, - ) - client.schema.add_class_tenants(name, [Tenant("tenant" + str(i)) for i in range(5)]) - - nr_objects = 100 - objects = [] - with client.batch() as batch: - for i in range(nr_objects): - obj_uuid = uuid.uuid4() - objects.append((obj_uuid, class_names[i % 2], "tenant" + str(i % 5))) - batch.add_data_object( - class_name=class_names[i % 2], - tenant="tenant" + str(i % 5), - data_object={"tenantAsProp": "tenant" + str(i % 5)}, - uuid=obj_uuid, - ) - - for obj in objects: - retObj = client.data_object.get_by_id(obj[0], class_name=obj[1], tenant=obj[2]) - assert retObj["properties"]["tenantAsProp"] == obj[2] - - # test batch delete with wrong tenant id - with client.batch() as batch: - batch.delete_objects( - class_name=objects[0][1], - where={ - "path": ["tenantAsProp"], - "operator": "Equal", - "valueString": objects[0][2], - }, - tenant=objects[1][2], - ) - - retObj = client.data_object.get_by_id( - objects[0][0], class_name=objects[0][1], tenant=objects[0][2] - ) - assert retObj["properties"]["tenantAsProp"] == objects[0][2] - - # test batch delete with correct tenant id - with client.batch() as batch: - batch.delete_objects( - class_name=objects[0][1], - where={ - "path": ["tenantAsProp"], - "operator": "Equal", - "valueString": objects[0][2], - }, - tenant=objects[0][2], - ) - - retObj = client.data_object.get_by_id( - objects[0][0], class_name=objects[0][1], tenant=objects[0][2] - ) - assert retObj is None - - for name in class_names: - client.schema.delete_class(name) - - -def test_add_ref_batch_with_tenant(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - - # create two classes and add 5 tenants each - class_names = ["BatchRefTestMultiTenant0", "BatchRefTestMultiTenant1"] - client.schema.create_class( - { - "class": class_names[0], - "vectorizer": "none", - "multiTenancyConfig": {"enabled": True}, - }, - ) - - client.schema.create_class( - { - "class": class_names[1], - "vectorizer": "none", - "properties": [ - {"name": "tenantAsProp", "dataType": ["text"]}, - {"name": "ref", "dataType": [class_names[0]]}, - ], - "multiTenancyConfig": {"enabled": True}, - }, - ) - - for name in class_names: - client.schema.add_class_tenants(name, [Tenant("tenant" + str(i)) for i in range(5)]) - - nr_objects = 100 - objects_class0 = [] - objects_class1 = [] - with client.batch() as batch: - for i in range(nr_objects): - tenant = "tenant" + str(i % 5) - obj_uuid0 = uuid.uuid4() - objects_class0.append(obj_uuid0) - batch.add_data_object( - class_name=class_names[0], tenant=tenant, data_object={}, uuid=obj_uuid0 - ) - - obj_uuid1 = uuid.uuid4() - objects_class1.append((obj_uuid1, "tenant" + str(i % 5))) - batch.add_data_object( - class_name=class_names[1], - tenant=tenant, - data_object={"tenantAsProp": tenant}, - uuid=obj_uuid1, - ) - - # add refs between classes for all tenants - batch.add_reference( - from_property_name="ref", - from_object_class_name=class_names[1], - from_object_uuid=obj_uuid1, - to_object_class_name=class_names[0], - to_object_uuid=obj_uuid0, - tenant=tenant, - ) - - for i, obj in enumerate(objects_class1): - ret_obj = client.data_object.get_by_id(obj[0], class_name=class_names[1], tenant=obj[1]) - assert ret_obj["properties"]["tenantAsProp"] == obj[1] - assert ( - ret_obj["properties"]["ref"][0]["beacon"] - == f"weaviate://localhost/{class_names[0]}/{objects_class0[i]}" - ) - - for name in reversed(class_names): - client.schema.delete_class(name) - - -def test_add_nested_object_with_batch(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - - client.schema.create_class( - { - "class": "BatchTestNested", - "vectorizer": "none", - "properties": [ - { - "name": "nested", - "dataType": ["object"], - "nestedProperties": [ - {"name": "name", "dataType": ["text"]}, - {"name": "names", "dataType": ["text[]"]}, - ], - } - ], - }, - ) - - uuid_ = uuid.uuid4() - with client.batch as batch: - batch.add_data_object( - class_name="BatchTestNested", - data_object={"nested": {"name": "nested", "names": ["nested1", "nested2"]}}, - uuid=uuid_, - ) - - obj = client.data_object.get_by_id(uuid_, class_name="BatchTestNested") - assert obj["properties"]["nested"] == {"name": "nested", "names": ["nested1", "nested2"]} - - -def test_add_1000_objects_with_async_indexing_and_wait(): - client = weaviate.Client("http://localhost:8090") - client.schema.delete_all() - - client.schema.create_class( - { - "class": "BatchTestAsync", - "vectorizer": "none", - "properties": [ - { - "name": "text", - "dataType": ["text"], - } - ], - }, - ) - nr_objects = 1000 - - with client.batch as batch: - for i in range(nr_objects): - batch.add_data_object( - class_name="BatchTestAsync", - data_object={"text": "text" + str(i)}, - vector=[float((j + i) % nr_objects) / nr_objects for j in range(nr_objects)], - ) - client.batch.wait_for_vector_indexing() - res = client.query.aggregate("BatchTestAsync").with_meta_count().do() - assert res["data"]["Aggregate"]["BatchTestAsync"][0]["meta"]["count"] == nr_objects - assert client.schema.get_class_shards("BatchTestAsync")[0]["status"] == "READY" - assert client.schema.get_class_shards("BatchTestAsync")[0]["vectorQueueSize"] == 0 - - -def test_add_10000_objects_with_async_indexing_and_dont_wait(): - client = weaviate.Client("http://localhost:8090") - client.schema.delete_all() - - client.schema.create_class( - { - "class": "BatchTestAsync", - "vectorizer": "none", - "properties": [ - { - "name": "text", - "dataType": ["text"], - } - ], - }, - ) - # async batches are 1000 objects big and lower numbers have a one-minute timeout where weaviate just waits. - # Carefully select the numbers so tests don't take forever - nr_objects = 2000 - with client.batch as batch: - for i in range(nr_objects): - batch.add_data_object( - class_name="BatchTestAsync", - data_object={"text": "text" + str(i)}, - vector=[float((j + i) % nr_objects) / nr_objects for j in range(nr_objects)], - ) - shard_status = client.schema.get_class_shards("BatchTestAsync") - assert shard_status[0]["status"] == "INDEXING" - assert shard_status[0]["vectorQueueSize"] >= 0 - res = client.query.aggregate("BatchTestAsync").with_meta_count().do() - assert res["data"]["Aggregate"]["BatchTestAsync"][0]["meta"]["count"] == nr_objects - - -def test_add_2000_tenant_objects_with_async_indexing_and_wait_for_all(): - client = weaviate.Client("http://localhost:8090") - client.schema.delete_all() - - client.schema.create_class( - { - "class": "BatchTestAsync", - "vectorizer": "none", - "multiTenancyConfig": { - "enabled": True, - }, - "properties": [ - { - "name": "text", - "dataType": ["text"], - } - ], - }, - ) - tenants = [Tenant("tenant" + str(i)) for i in range(2)] - client.schema.add_class_tenants("BatchTestAsync", tenants) - nr_objects = 2000 - - with client.batch as batch: - for i in range(nr_objects): - batch.add_data_object( - class_name="BatchTestAsync", - data_object={"text": "text" + str(i)}, - vector=[float((j + i) % nr_objects) / nr_objects for j in range(nr_objects)], - tenant=tenants[i % len(tenants)].name, - ) - - client.batch.wait_for_vector_indexing() - for tenant in tenants: - res = ( - client.query.aggregate("BatchTestAsync").with_meta_count().with_tenant(tenant.name).do() - ) - assert res["data"]["Aggregate"]["BatchTestAsync"][0]["meta"]["count"] == nr_objects / len( - tenants - ) - for shard in client.schema.get_class_shards("BatchTestAsync"): - assert shard["status"] == "READY" - assert shard["vectorQueueSize"] == 0 - - -def test_add_1000_tenant_objects_with_async_indexing_and_wait_for_only_one(): - client = weaviate.Client("http://localhost:8090") - client.schema.delete_all() - - client.schema.create_class( - { - "class": "BatchTestAsync", - "vectorizer": "none", - "multiTenancyConfig": { - "enabled": True, - }, - "properties": [ - { - "name": "text", - "dataType": ["text"], - } - ], - }, - ) - tenants = [Tenant("tenant" + str(i)) for i in range(2)] - client.schema.add_class_tenants("BatchTestAsync", tenants) - nr_objects = 1001 - with client.batch as batch: - for i in range(nr_objects): - batch.add_data_object( - class_name="BatchTestAsync", - data_object={"text": "text" + str(i)}, - vector=[float((j + i) % nr_objects) / nr_objects for j in range(nr_objects)], - tenant=tenants[0].name if i < 1000 else tenants[1].name, - ) - - client.batch.wait_for_vector_indexing( - shards=[Shard(class_name="BatchTestAsync", tenant=tenants[0].name)] - ) - for tenant in tenants: - res = ( - client.query.aggregate("BatchTestAsync").with_meta_count().with_tenant(tenant.name).do() - ) - assert ( - res["data"]["Aggregate"]["BatchTestAsync"][0]["meta"]["count"] == 1000 - if tenant.name == tenants[0].name - else 1 - ) - for shard in client.schema.get_class_shards("BatchTestAsync"): - if shard["name"] == tenants[0].name: - assert shard["status"] == "READY" - assert shard["vectorQueueSize"] == 0 - else: - assert shard["status"] == "INDEXING" - assert shard["vectorQueueSize"] >= 0 diff --git a/integration_v3/test_classification.py b/integration_v3/test_classification.py deleted file mode 100644 index b375b38c0..000000000 --- a/integration_v3/test_classification.py +++ /dev/null @@ -1,91 +0,0 @@ -import pytest - -import weaviate - -schema = { - "classes": [ - { - "class": "Label", - "description": "a label describing a message", - "properties": [ - { - "name": "name", - "description": "The name of this label", - "dataType": ["string"], - }, - { - "name": "description", - "description": "The description of this label", - "dataType": ["text"], - }, - ], - }, - { - "class": "Message", - "description": "a message from written by a person", - "properties": [ - { - "name": "content", - "description": "The content of the message", - "dataType": ["text"], - }, - { - "name": "labeled", - "description": "The label assigned to this message", - "dataType": ["Label"], - }, - ], - }, - ] -} - - -@pytest.fixture(scope="module") -def client(): - client = weaviate.Client("http://localhost:8080") - client.schema.create(schema) - yield client - client.schema.delete_all() - - -def test_contextual(client: weaviate.Client): - # Create labels - client.data_object.create( - {"name": "positive", "description": "A positive, good, happy or supporting message."}, - "Label", - ) - client.data_object.create( - {"name": "negative", "description": "A negative, bad, sad or disrupting message."}, "Label" - ) - - client.data_object.create( - { - "content": "ALERT: So now we find out that the entire opponent “hit squad” illegally wiped their phones clean just prior to the investigation of them, all using the same really dumb reason for this “accident”, just like other people smashing her phones with a hammer, & DELETING THEIR EMAILS!" - }, - "Message", - ) - client.data_object.create( - { - "content": "I'm so happy, proud and excited to be a part of this community for the rest of my days." - }, - "Message", - ) - client.data_object.create( - {"content": "thank you for reminding the world of our cause"}, "Message" - ) - - client.classification.schedule().with_type("text2vec-contextionary-contextual").with_class_name( - "Message" - ).with_based_on_properties(["content"]).with_classify_properties( - ["labeled"] - ).with_wait_for_completion().do() - - result = ( - client.query.get("Message", ["content", "labeled {... on Label {name description}}"]) - .with_additional(["id", "classification{basedOn, id}"]) - .do() - ) - labeled_messages = result["data"]["Get"]["Message"] - for message in labeled_messages: - assert message["labeled"] is not None - assert message["_additional"]["id"] is not None diff --git a/integration_v3/test_cluster.py b/integration_v3/test_cluster.py deleted file mode 100644 index 3852b6130..000000000 --- a/integration_v3/test_cluster.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Dict, Any - -import pytest - -import weaviate -from weaviate.util import parse_version_string - -NODE_NAME = "node1" -NUM_OBJECT = 10 - - -def schema(class_name: str) -> Dict[str, Any]: - return { - "classes": [ - { - "class": class_name, - "properties": [ - {"dataType": ["string"], "name": "stringProp"}, - {"dataType": ["int"], "name": "intProp"}, - ], - } - ] - } - - -@pytest.fixture(scope="module") -def client(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - yield client - client.schema.delete_all() - - -def test_get_nodes_status_without_data(client: weaviate.Client): - """get nodes status without data""" - resp = client.cluster.get_nodes_status(output="verbose") - assert len(resp) == 1 - assert "gitHash" in resp[0] - assert resp[0]["name"] == NODE_NAME - assert resp[0]["shards"] is None # no class given - assert resp[0]["stats"]["objectCount"] == 0 - assert resp[0]["stats"]["shardCount"] == 0 - assert resp[0]["status"] == "HEALTHY" - assert "version" in resp[0] - - -def test_get_nodes_status_with_data(client: weaviate.Client): - """get nodes status with data""" - class_name1 = "ClassA" - uncap_class_name1 = "classA" - client.schema.create(schema(class_name1)) - for i in range(NUM_OBJECT): - client.data_object.create({"stringProp": f"object-{i}", "intProp": i}, class_name1) - - class_name2 = "ClassB" - client.schema.create(schema(class_name2)) - for i in range(NUM_OBJECT * 2): - client.data_object.create({"stringProp": f"object-{i}", "intProp": i}, class_name2) - - # server behaviour of resp.stats.objectCount changed by # server behaviour changed by https://github.com/weaviate/weaviate/pull/4203 - - server_is_at_least_124 = parse_version_string( - client._connection._server_version - ) > parse_version_string("1.24") - - resp = client.cluster.get_nodes_status(output="verbose") - assert len(resp) == 1 - assert "gitHash" in resp[0] - assert resp[0]["name"] == NODE_NAME - assert len(resp[0]["shards"]) == 2 - assert resp[0]["stats"]["objectCount"] == 0 if server_is_at_least_124 else NUM_OBJECT * 3 - assert resp[0]["stats"]["shardCount"] == 2 - assert resp[0]["status"] == "HEALTHY" - assert "version" in resp[0] - - shards = sorted(resp[0]["shards"], key=lambda x: x["class"]) - assert shards[0]["class"] == class_name1 - assert shards[0]["objectCount"] == 0 if server_is_at_least_124 else NUM_OBJECT - assert shards[1]["class"] == class_name2 - assert shards[1]["objectCount"] == 0 if server_is_at_least_124 else NUM_OBJECT * 2 - - resp = client.cluster.get_nodes_status(class_name1, output="verbose") - assert len(resp) == 1 - assert "gitHash" in resp[0] - assert resp[0]["name"] == NODE_NAME - assert len(resp[0]["shards"]) == 1 - assert resp[0]["stats"]["shardCount"] == 1 - assert resp[0]["status"] == "HEALTHY" - assert "version" in resp[0] - - assert shards[0]["class"] == class_name1 - assert shards[0]["objectCount"] == 0 if server_is_at_least_124 else NUM_OBJECT - assert resp[0]["stats"]["objectCount"] == 0 if server_is_at_least_124 else NUM_OBJECT - - resp = client.cluster.get_nodes_status(uncap_class_name1, output="verbose") - assert len(resp) == 1 - assert "gitHash" in resp[0] - assert resp[0]["name"] == NODE_NAME - assert len(resp[0]["shards"]) == 1 - assert resp[0]["stats"]["shardCount"] == 1 - assert resp[0]["status"] == "HEALTHY" - assert "version" in resp[0] - - assert shards[0]["class"] == class_name1 - assert shards[0]["objectCount"] == 0 if server_is_at_least_124 else NUM_OBJECT - assert resp[0]["stats"]["objectCount"] == 0 if server_is_at_least_124 else NUM_OBJECT diff --git a/integration_v3/test_crud.py b/integration_v3/test_crud.py deleted file mode 100644 index 322d7e911..000000000 --- a/integration_v3/test_crud.py +++ /dev/null @@ -1,996 +0,0 @@ -import json -import os -import time -from datetime import datetime -from datetime import timezone -from typing import List, Optional, Dict, Union - -import pytest -import uuid - -import weaviate -from weaviate import Tenant -from weaviate.gql.get import LinkTo - - -def get_query_for_group(name): - return ( - """ - { - Get { - Group (where: { - path: ["name"] - operator: Equal - valueText: "%s" - }) { - name - _additional { - id - } - members { - ... on Person { - name - _additional { - id - } - } - } - } - } - } - """ - % name - ) - - -gql_get_sophie_scholl = """ -{ - Get { - Person (where: { - path: ["id"] - operator: Equal - valueString: "594b7827-f795-40d0-aabb-5e0553953dad" - }){ - name - _additional { - id - } - } - } -} -""" - -SHIP_SCHEMA = { - "classes": [ - { - "class": "Ship", - "description": "object", - "properties": [ - {"dataType": ["string"], "description": "name", "name": "name"}, - {"dataType": ["string"], "description": "description", "name": "description"}, - {"dataType": ["int"], "description": "size", "name": "size"}, - ], - } - ] -} - - -@pytest.fixture(scope="module") -def people_schema() -> str: - with open(os.path.join(os.path.dirname(__file__), "people_schema.json"), encoding="utf-8") as f: - return json.load(f) - - -def test_load_scheme(people_schema): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create(people_schema) - - assert client.schema.contains() - assert client.schema.contains(people_schema) - - for cls in people_schema["classes"]: - client.schema.delete_class(cls["class"]) - - -@pytest.fixture(scope="module") -def client(people_schema): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create(people_schema) - - yield client - client.schema.delete_all() - - -@pytest.mark.parametrize("timeout, error", [(None, TypeError), ((5,), ValueError)]) -def test_timeout_error(timeout, error): - with pytest.raises(error): - weaviate.Client("http://localhost:8080", timeout_config=timeout) - - -@pytest.mark.parametrize("timeout", [(5, 5), 5, 5.0, (5.0, 5.0), (5, 5.0)]) -def test_timeout(people_schema, timeout): - client = weaviate.Client("http://localhost:8080", timeout_config=timeout) - client.schema.delete_all() - client.schema.create(people_schema) - expected_name = "Sophie Scholl" - client.data_object.create( - {"name": expected_name}, "Person", "594b7827-f795-40d0-aabb-5e0553953dad" - ) - time.sleep(0.5) - result = client.query.raw(gql_get_sophie_scholl) - assert result["data"]["Get"]["Person"][0]["name"] == expected_name - client.schema.delete_all() - - -@pytest.mark.parametrize("limit", [None, 1, 5, 20, 50]) -def test_query_get_with_limit(people_schema, limit: Optional[int]): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create(people_schema) - - num_objects = 20 - for i in range(num_objects): - with client.batch as batch: - batch.add_data_object({"name": f"name{i}"}, "Person") - batch.flush() - result = client.data_object.get(class_name="Person", limit=limit) - if limit is None or limit >= num_objects: - assert len(result["objects"]) == num_objects - else: - assert len(result["objects"]) == limit - client.schema.delete_all() - - -def test_query_get_with_after(people_schema): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create(people_schema) - - num_objects = 20 - for i in range(num_objects): - with client.batch as batch: - batch.add_data_object({"name": f"name{i}"}, "Person") - batch.flush() - - full_results = client.data_object.get(class_name="Person") - for i, person in enumerate(full_results["objects"][:-1]): - results = client.data_object.get(class_name="Person", limit=1, after=person["id"]) - assert full_results["objects"][i + 1]["id"] == results["objects"][0]["id"] - - client.schema.delete_all() - - -@pytest.mark.parametrize("offset", [None, 0, 1, 5, 20, 50]) -def test_query_get_with_offset(people_schema, offset: Optional[int]): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create(people_schema) - - num_objects = 20 - for i in range(num_objects): - with client.batch as batch: - batch.add_data_object({"name": f"name{i}"}, "Person") - batch.flush() - result_without_offset = client.data_object.get(class_name="Person") - result_with_offset = client.data_object.get(class_name="Person", offset=offset) - - if offset is None: - assert result_with_offset["objects"] == result_without_offset["objects"] - elif offset >= num_objects: - assert len(result_with_offset["objects"]) == 0 - else: - assert result_with_offset["objects"][:] == result_without_offset["objects"][offset:] - client.schema.delete_all() - - -@pytest.mark.parametrize( - "sort,expected", - [ - ( - {"properties": "name", "order_asc": True}, - ["name" + "{:02d}".format(i) for i in range(0, 20)], - ), - ( - {"properties": "name", "order_asc": False}, - ["name" + "{:02d}".format(i) for i in range(19, -1, -1)], - ), - ( - {"properties": ["name"], "order_asc": [False]}, - ["name" + "{:02d}".format(i) for i in range(19, -1, -1)], - ), - ( - {"properties": ["description", "size", "name"], "order_asc": [False, True, False]}, - ["name05", "name00", "name06", "name01"], - ), - ( - {"properties": ["description", "size", "name"], "order_asc": False}, - ["name09", "name04", "name08", "name03"], - ), - ( - {"properties": ["description", "size", "name"], "order_asc": True}, - ["name10", "name15", "name11", "name16"], - ), - ({"properties": ["description", "size", "name"]}, ["name10", "name15", "name11", "name16"]), - ], -) -def test_query_get_with_sort( - sort: Optional[Dict[str, Union[str, bool, List[bool], List[str]]]], expected: List[str] -): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create(SHIP_SCHEMA) - - num_objects = 10 - for i in range(num_objects): - with client.batch as batch: - batch.add_data_object( - { - "name": "name" + "{:02d}".format(i), - "size": i % 5 + 5, - "description": "Super long description", - }, - "Ship", - ) - batch.add_data_object( - { - "name": "name" + "{:02d}".format(i + 10), - "size": i % 5 + 5, - "description": "Short description", - }, - "Ship", - ) - batch.flush() - result = client.data_object.get(class_name="Ship", sort=sort) - - for i, exp in enumerate(expected): - assert exp == result["objects"][i]["properties"]["name"] - client.schema.delete_all() - - -def test_query_data(client: weaviate.Client): - expected_name = "Sophie Scholl" - client.data_object.create( - {"name": expected_name}, "Person", "594b7827-f795-40d0-aabb-5e0553953dad" - ) - time.sleep(2.0) - result = client.query.raw(gql_get_sophie_scholl) - assert result["data"]["Get"]["Person"][0]["name"] == expected_name - - -def test_create_schema(): - client = weaviate.Client("http://localhost:8080") - single_class = { - "class": "Barbecue", - "description": "Barbecue or BBQ where meat and vegetables get grilled", - } - client.schema.create_class(single_class) - prop = { - "dataType": ["string"], - "description": "how hot is the BBQ in C", - "name": "heat", - } - client.schema.property.create("Barbecue", prop) - classes = client.schema.get()["classes"] - found = False - for class_ in classes: - if class_["class"] == "Barbecue": - found = len(class_["properties"]) == 1 - assert found - client.schema.delete_class("Barbecue") - - -def test_replace_and_update(client: weaviate.Client): - """Test updating an object with put (replace) and patch (update).""" - uuid = "28954264-0449-57a2-ade5-e9e08d11f51a" - client.data_object.create({"name": "Someone"}, "Person", uuid) - person = client.data_object.get_by_id(uuid, class_name="Person") - assert person["properties"]["name"] == "Someone" - client.data_object.replace({"name": "SomeoneElse"}, "Person", uuid) - person = client.data_object.get_by_id(uuid, class_name="Person") - assert person["properties"]["name"] == "SomeoneElse" - client.data_object.update({"name": "Anyone"}, "Person", uuid) - person = client.data_object.get_by_id(uuid, class_name="Person") - assert person["properties"]["name"] == "Anyone" - client.data_object.delete(uuid, class_name="Person") - - -def test_crud(client: weaviate.Client): - chemists: List[str] = [] - _create_objects_batch(client) - _create_objects(client, chemists) - time.sleep(2.0) - _create_references(client, chemists) - time.sleep(2.0) - _validate_data_loading(client) - _delete_objects(client, chemists) - - _delete_references(client) - _get_data(client) - - -def _create_objects_batch(local_client: weaviate.Client): - local_client.batch.add_data_object({"name": "John Rawls"}, "Person") - local_client.batch.add_data_object({"name": "Emanuel Kant"}, "Person") - local_client.batch.add_data_object({"name": "Plato"}, "Person") - local_client.batch.add_data_object({"name": "Pull-up"}, "Exercise") - local_client.batch.add_data_object({"name": "Squat"}, "Exercise") - local_client.batch.add_data_object({"name": "Star jump"}, "Exercise") - - local_client.batch.create_objects() - - -def _create_objects(local_client: weaviate.Client, chemists: List[str]): - local_client.data_object.create( - {"name": "Andrew S. Tanenbaum"}, "Person", "28954261-0449-57a2-ade5-e9e08d11f51a" - ) - local_client.data_object.create( - {"name": "Alan Turing"}, "Person", "1c9cd584-88fe-5010-83d0-017cb3fcb446" - ) - local_client.data_object.create( - {"name": "John von Neumann"}, "Person", "b36268d4-a6b5-5274-985f-45f13ce0c642" - ) - local_client.data_object.create( - {"name": "Tim Berners-Lee"}, "Person", "d1e90d26-d82e-5ef8-84f6-ca67119c7998" - ) - local_client.data_object.create( - {"name": "Legends"}, "Group", "2db436b5-0557-5016-9c5f-531412adf9c6" - ) - local_client.data_object.create( - {"name": "Chemists"}, "Group", "577887c1-4c6b-5594-aa62-f0c17883d9bf" - ) - - for name in ["Marie Curie", "Fritz Haber", "Walter White"]: - chemists.append(local_client.data_object.create({"name": name}, "Person")) - - local_time = datetime.now(timezone.utc).astimezone() - local_client.data_object.create( - {"start": local_time.isoformat()}, "Call", "3ab05e06-2bb2-41d1-b5c5-e044f3aa9623" - ) - - -def _create_references(local_client: weaviate.Client, chemists: List[str]): - local_client.data_object.reference.add( - "2db436b5-0557-5016-9c5f-531412adf9c6", - "members", - "b36268d4-a6b5-5274-985f-45f13ce0c642", - from_class_name="Group", - to_class_name="Person", - ) - local_client.data_object.reference.add( - "2db436b5-0557-5016-9c5f-531412adf9c6", - "members", - "1c9cd584-88fe-5010-83d0-017cb3fcb446", - from_class_name="Group", - to_class_name="Person", - ) - - for chemist in chemists: - local_client.batch.add_reference( - "577887c1-4c6b-5594-aa62-f0c17883d9bf", - "Group", - "members", - chemist, - to_object_class_name="Person", - ) - - local_client.batch.create_references() - - -def _validate_data_loading(local_client: weaviate.Client): - legends = local_client.query.raw(get_query_for_group("Legends"))["data"]["Get"] - for member in legends["Group"][0]["members"]: - assert member["name"] in ["John von Neumann", "Alan Turing"] - - group_chemists = local_client.query.raw(get_query_for_group("Chemists"))["data"]["Get"] - for member in group_chemists["Group"][0]["members"]: - assert member["name"] in ["Marie Curie", "Fritz Haber", "Walter White"] - assert len(group_chemists["Group"][0]["members"]) == 3 - - -def _delete_objects(local_client: weaviate.Client, chemists: List[str]): - local_client.data_object.delete( - chemists[2], class_name="Person" - ) # Delete Walter White not a real chemist just a legend - time.sleep(1.1) - assert not local_client.data_object.exists( - chemists[2], class_name="Person" - ), "Thing was not correctly deleted" - - -def _delete_references(local_client: weaviate.Client): - # Test delete reference - prime_ministers_group = local_client.data_object.create({"name": "Prime Ministers"}, "Group") - prime_ministers = [] - for name in ["Wim Kok", "Dries van Agt", "Piet de Jong"]: - prime_ministers.append(local_client.data_object.create({"name": name}, "Person")) - for prime_minister in prime_ministers: - local_client.data_object.reference.add( - prime_ministers_group, - "members", - prime_minister, - from_class_name="Group", - to_class_name="Person", - ) - time.sleep(2.0) - local_client.data_object.reference.delete( - prime_ministers_group, - "members", - prime_ministers[0], - from_class_name="Group", - to_class_name="Person", - ) - time.sleep(2.0) - prime_ministers_group_object = local_client.data_object.get_by_id( - prime_ministers_group, class_name="Group" - ) - assert ( - len(prime_ministers_group_object["properties"]["members"]) == 2 - ), "Reference not deleted correctly" - - -def _get_data(local_client: weaviate.Client): - local_client.data_object.create( - {"name": "George Floyd"}, "Person", "452e3031-bdaa-4468-9980-aed60d0258bf" - ) - time.sleep(2.0) - person = local_client.data_object.get_by_id( - "452e3031-bdaa-4468-9980-aed60d0258bf", - ["interpretation"], - with_vector=True, - class_name="Person", - ) - assert "vector" in person - assert ( - "interpretation" in person["additional"] - ), "additional property _interpretation not in person" - - persons = local_client.data_object.get(with_vector=True) - assert "vector" in persons["objects"][0], "additional property _vector not in persons" - - -def test_add_vector_and_vectorizer(client: weaviate.Client): - """Add objects with and without vector. - - The Vectorizer should create a vector for the object without vector and the given one should be used for the object - with vector. - """ - uuid_without_vector = uuid.uuid4() - uuid_with_vector = uuid.uuid4() - with client.batch(batch_size=2) as batch: - batch.add_data_object({"name": "Some Name"}, "Person", uuid=uuid_without_vector) - batch.add_data_object( - {"name": "Other Name"}, "Person", uuid=uuid_with_vector, vector=[1] * 300 - ) - batch.flush() - - object_with_vector = client.data_object.get_by_id( - uuid_with_vector, - with_vector=True, - class_name="Person", - ) - assert object_with_vector["vector"] == [1] * 300 - - object_without_vector = client.data_object.get_by_id( - uuid_without_vector, - with_vector=True, - class_name="Person", - ) - assert object_without_vector["vector"] != [1] * 300 - - -def test_beacon_refs(people_schema: dict): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create(people_schema) - - persons = [] - for i in range(10): - persons.append(uuid.uuid4()) - client.data_object.create({"name": "randomName" + str(i)}, "Person", persons[-1]) - - client.data_object.create({}, "Call", "3ab05e06-2bb2-41d1-b5c5-e044f3aa9623") - - # create refs - for i in range(5): - client.data_object.reference.add( - to_uuid=persons[i], - from_property_name="caller", - from_uuid="3ab05e06-2bb2-41d1-b5c5-e044f3aa9623", - from_class_name="Call", - to_class_name="Person", - ) - - result = client.query.get( - "Call", - [ - "start", - LinkTo(link_on="caller", linked_class="Person", properties=["name"]), - ], - ).do() - callers = result["data"]["Get"]["Call"][0]["caller"] - assert len(callers) == 5 - all_names = [caller["name"] for caller in callers] - assert all("randomName" + str(i) in all_names for i in range(5)) - - -def test_beacon_refs_multiple(people_schema: dict): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create_class( - { - "class": "Person", - "description": "A person such as humans or personality known through culture", - "properties": [ - {"name": "name", "dataType": ["string"]}, - {"name": "age", "dataType": ["int"]}, - {"name": "born_in", "dataType": ["text"]}, - ], - "vectorizer": "none", - } - ) - - client.schema.create_class( - { - "class": "Call", - "description": "A call between two Persons", - "properties": [ - {"name": "caller", "dataType": ["Person"]}, - {"name": "recipient", "dataType": ["Person"]}, - ], - "vectorizer": "none", - } - ) - - persons = [] - for i in range(10): - persons.append(uuid.uuid4()) - client.data_object.create( - {"name": "randomName" + str(i), "age": i, "born_in": "city" + str(i)}, - "Person", - persons[-1], - ) - - call_uuids = [uuid.uuid4(), uuid.uuid4()] - client.data_object.create({}, "Call", call_uuids[0]) - client.data_object.create({}, "Call", call_uuids[1]) - - # create refs - for i in range(4): - client.data_object.reference.add(call_uuids[i % 2], "caller", persons[i], "Call", "Person") - client.data_object.reference.add( - call_uuids[i % 2], "recipient", persons[i + 5], "Call", "Person" - ) - - result = client.query.get( - "Call", - [ - LinkTo(link_on="caller", linked_class="Person", properties=["name", "age"]), - LinkTo(link_on="recipient", linked_class="Person", properties=["born_in", "age"]), - ], - ).do() - call1 = result["data"]["Get"]["Call"][0] - call2 = result["data"]["Get"]["Call"][1] - - # each call has two callers and recipients and caller and recipient should contain different entries - for call in [call1, call2]: - assert len(call["caller"]) == 2 - assert len(call["recipient"]) == 2 - - assert "age" in call["caller"][0] and "name" in call["caller"][0] - assert "age" in call["recipient"][0] and "born_in" in call["recipient"][0] - - -def test_beacon_refs_nested(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create_class( - { - "class": "A", - "properties": [{"name": "nonRef", "dataType": ["string"]}], - "vectorizer": "none", - } - ) - client.schema.create_class( - { - "class": "B", - "properties": [ - {"name": "nonRef", "dataType": ["string"]}, - {"name": "refA", "dataType": ["A"]}, - ], - "vectorizer": "none", - } - ) - client.schema.create_class( - { - "class": "C", - "properties": [ - {"name": "nonRef", "dataType": ["string"]}, - {"name": "refB", "dataType": ["B"]}, - ], - "vectorizer": "none", - } - ) - client.schema.create_class( - { - "class": "D", - "properties": [ - {"name": "nonRef", "dataType": ["string"]}, - {"name": "refC", "dataType": ["C"]}, - {"name": "refB", "dataType": ["B"]}, - ], - "vectorizer": "none", - } - ) - - uuid_a = client.data_object.create({"nonRef": "A"}, "A") - uuid_b = client.data_object.create({"nonRef": "B"}, "B") - client.data_object.reference.add(uuid_b, "refA", uuid_a, "B", "A") - - uuid_c = client.data_object.create({"nonRef": "C"}, "C") - client.data_object.reference.add(uuid_c, "refB", uuid_b, "C", "B") - - uuid_d = client.data_object.create({"nonRef": "D"}, "D") - client.data_object.reference.add(uuid_d, "refC", uuid_c, "D", "C") - client.data_object.reference.add(uuid_d, "refB", uuid_b, "D", "B") - - result = client.query.get( - "D", - [ - "nonRef", - LinkTo( - link_on="refC", - linked_class="C", - properties=[ - "nonRef", - LinkTo( - link_on="refB", - linked_class="B", - properties=[ - "nonRef", - LinkTo(link_on="refA", linked_class="A", properties=["nonRef"]), - ], - ), - ], - ), - LinkTo( - link_on="refB", - linked_class="B", - properties=[ - "nonRef", - LinkTo(link_on="refA", linked_class="A", properties=["nonRef"]), - ], - ), - ], - ).do() - - assert result["data"]["Get"]["D"][0]["refC"][0]["refB"][0]["refA"][0]["nonRef"] == "A" - assert result["data"]["Get"]["D"][0]["refB"][0]["refA"][0]["nonRef"] == "A" - - -def test_tenants(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - tenants = [ - Tenant(name="tenantA"), - Tenant(name="tenantB"), - Tenant(name="tenantC"), - ] - - class_name_document = "Document" - client.schema.create_class( - { - "class": class_name_document, - "properties": [ - {"name": "tenant", "dataType": ["text"]}, - {"name": "title", "dataType": ["text"]}, - ], - "vectorizer": "none", - "multiTenancyConfig": {"enabled": True}, - } - ) - client.schema.add_class_tenants( - class_name=class_name_document, - tenants=tenants, - ) - document_uuids = [ - "00000000-0000-0000-0000-000000000011", - "00000000-0000-0000-0000-000000000022", - "00000000-0000-0000-0000-000000000033", - ] - document_titles = ["GAN", "OpenAI", "SpaceX"] - for i in range(0, len(document_uuids)): - client.data_object.create( - class_name=class_name_document, - uuid=document_uuids[i], - data_object={ - "tenant": tenants[i].name, - "title": document_titles[i], - }, - tenant=tenants[i].name, - ) - documents = client.data_object.get(class_name=class_name_document, tenant=tenants[0].name) - assert len(documents["objects"]) == 1 - - class_name_passage = "Passage" - client.schema.create_class( - { - "class": class_name_passage, - "properties": [ - {"name": "tenant", "dataType": ["text"]}, - {"name": "content", "dataType": ["text"]}, - {"name": "ofDocument", "dataType": ["Document"]}, - ], - "vectorizer": "none", - "multiTenancyConfig": {"enabled": True}, - } - ) - client.schema.add_class_tenants( - class_name=class_name_passage, - tenants=tenants, - ) - - passage_uuids = [ - "00000000-0000-0000-0000-000000000001", - "00000000-0000-0000-0000-000000000002", - "00000000-0000-0000-0000-000000000003", - ] - txts = ["Txt1", "Txt2", "Txt3"] - - for i in range(0, len(passage_uuids)): - client.data_object.create( - class_name=class_name_passage, - uuid=passage_uuids[i], - data_object={ - "content": txts[i], - "tenant": tenants[i].name, - }, - tenant=tenants[i].name, - ) - - for i in range(0, len(passage_uuids)): - passage = client.data_object.get_by_id( - class_name=class_name_passage, - uuid=passage_uuids[i], - tenant=tenants[i].name, - ) - assert passage["properties"]["tenant"] == tenants[i].name - assert passage["properties"]["content"] == txts[i] - passages = client.data_object.get(class_name=class_name_passage, tenant=tenants[0].name) - assert len(passages["objects"]) == 1 - - for i in range(0, len(passage_uuids)): - client.data_object.replace( - class_name=class_name_passage, - uuid=passage_uuids[i], - data_object={ - "content": txts[len(txts) - i - 1], - "tenant": tenants[i].name, - }, - tenant=tenants[i].name, - ) - - for i in range(0, len(passage_uuids)): - exists = client.data_object.exists( - class_name=class_name_passage, - uuid=passage_uuids[i], - tenant=tenants[i].name, - ) - assert exists - passage = client.data_object.get_by_id( - class_name=class_name_passage, - uuid=passage_uuids[i], - tenant=tenants[i].name, - ) - assert passage["properties"]["tenant"] == tenants[i].name - assert passage["properties"]["content"] == txts[len(txts) - i - 1] - - for i in range(0, len(passage_uuids)): - client.data_object.update( - class_name=class_name_passage, - uuid=passage_uuids[i], - data_object={ - "content": tenants[i].name, - }, - tenant=tenants[i].name, - ) - - for i in range(0, len(passage_uuids)): - passage = client.data_object.get_by_id( - class_name=class_name_passage, - uuid=passage_uuids[i], - tenant=tenants[i].name, - ) - assert passage["properties"]["tenant"] == tenants[i].name - assert passage["properties"]["content"] == tenants[i].name - - # references - for i in range(0, len(passage_uuids)): - client.data_object.reference.add( - passage_uuids[i], - "ofDocument", - document_uuids[i], - from_class_name="Passage", - to_class_name="Document", - tenant=tenants[i].name, - ) - - for i in range(0, len(passage_uuids)): - passage = client.data_object.get_by_id( - class_name=class_name_passage, - uuid=passage_uuids[i], - tenant=tenants[i].name, - ) - assert len(passage["properties"]["ofDocument"]) == 1 - - for i in range(0, len(passage_uuids)): - client.data_object.reference.update( - passage_uuids[i], - "ofDocument", - document_uuids[i], - from_class_name="Passage", - to_class_names="Document", - tenant=tenants[i].name, - ) - - for i in range(0, len(passage_uuids)): - client.data_object.reference.delete( - passage_uuids[i], - "ofDocument", - document_uuids[i], - from_class_name="Passage", - to_class_name="Document", - tenant=tenants[i].name, - ) - - for i in range(0, len(passage_uuids)): - passage = client.data_object.get_by_id( - class_name=class_name_passage, - uuid=passage_uuids[i], - tenant=tenants[i].name, - ) - assert len(passage["properties"]["ofDocument"]) == 0 - - for i in range(0, len(passage_uuids)): - client.data_object.delete( - class_name=class_name_passage, - uuid=passage_uuids[i], - tenant=tenants[i].name, - ) - - for i in range(0, len(passage_uuids)): - exists = client.data_object.exists( - class_name=class_name_passage, - uuid=passage_uuids[i], - tenant=tenants[i].name, - ) - assert not exists - - for i in range(0, len(document_uuids)): - client.data_object.delete( - class_name=class_name_document, - uuid=document_uuids[i], - tenant=tenants[i].name, - ) - - for i in range(0, len(document_uuids)): - exists = client.data_object.exists( - class_name=class_name_document, - uuid=document_uuids[i], - tenant=tenants[i].name, - ) - assert not exists - - -@pytest.mark.parametrize( - "prop_defs,props", - [ - ( - { - "dataType": ["text"], - "name": "name", - }, - { - "name": "test", - }, - ), - ( - { - "dataType": ["text[]"], - "name": "names", - }, - { - "names": ["test1", "test2"], - }, - ), - ( - { - "dataType": ["int"], - "name": "age", - }, - { - "age": 42, - }, - ), - ( - { - "dataType": ["int[]"], - "name": "ages", - }, - { - "ages": [42, 43], - }, - ), - ( - { - "dataType": ["number"], - "name": "height", - }, - { - "height": 1.80, - }, - ), - ( - { - "dataType": ["number[]"], - "name": "heights", - }, - { - "heights": [1.00, 1.80], - }, - ), - ( - { - "dataType": ["boolean"], - "name": "isTall", - }, - { - "isTall": True, - }, - ), - ( - { - "dataType": ["boolean[]"], - "name": "areTall", - }, - { - "areTall": [False, True], - }, - ), - ( - { - "dataType": ["date"], - "name": "birthday", - }, - { - "birthday": "2021-01-01T00:00:00Z", - }, - ), - ( - { - "dataType": ["date[]"], - "name": "birthdays", - }, - { - "birthdays": ["2021-01-01T00:00:00Z", "2021-01-02T00:00:00Z"], - }, - ), - ], -) -def test_nested_object_datatype(prop_defs: dict, props: dict): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create_class( - { - "class": "A", - "properties": [ - {"name": "nested", "dataType": ["object"], "nestedProperties": [prop_defs]}, - ], - "vectorizer": "none", - } - ) - - uuid_ = client.data_object.create({"nested": props}, "A") - obj = client.data_object.get_by_id(uuid_, class_name="A") - assert obj["properties"]["nested"] == props diff --git a/integration_v3/test_graphql.py b/integration_v3/test_graphql.py deleted file mode 100644 index 74c85ef12..000000000 --- a/integration_v3/test_graphql.py +++ /dev/null @@ -1,719 +0,0 @@ -import json -import os -import uuid -from typing import Optional, List, Union - -import pytest -from pytest import FixtureRequest - -import weaviate -from weaviate import Tenant -from weaviate.data.replication import ConsistencyLevel -from weaviate.gql.get import HybridFusion - -schema = { - "classes": [ - { - "class": "Ship", - "description": "object", - "properties": [ - {"dataType": ["string"], "description": "name", "name": "name"}, - {"dataType": ["string"], "description": "description", "name": "description"}, - {"dataType": ["int"], "description": "size", "name": "size"}, - {"dataType": ["number"], "description": "rating", "name": "rating"}, - ], - "vectorizer": "text2vec-contextionary", - } - ] -} - -SHIPS = [ - { - "props": { - "name": "HMS British Name", - "size": 5, - "rating": 0.0, - "description": "Super long description", - }, - "id": uuid.uuid4(), - }, - { - "props": { - "name": "The dragon ship", - "rating": 6.66, - "size": 20, - "description": "Interesting things about dragons", - }, - "id": uuid.uuid4(), - }, - { - "props": { - "name": "Blackbeard", - "size": 43, - "rating": 7.2, - "description": "Background info about movies", - }, - "id": uuid.uuid4(), - }, - { - "props": {"name": "Titanic", "size": 1, "rating": 4.5, "description": "Everyone knows"}, - "id": uuid.uuid4(), - }, - { - "props": { - "name": "Artemis", - "size": 34, - "rating": 9.1, - "description": "Name from some story", - }, - "id": uuid.uuid4(), - }, - { - "props": { - "name": "The Crusty Crab", - "size": 303, - "rating": 10.0, - "description": "sponges, sponges, sponges", - }, - "id": uuid.uuid4(), - }, -] - - -@pytest.fixture(scope="function") -def people_schema() -> str: - with open(os.path.join(os.path.dirname(__file__), "people_schema.json"), encoding="utf-8") as f: - return json.load(f) - - -@pytest.fixture(scope="module") -def client(request): - port = 8080 - opts = parse_client_options(request) - if opts: - if opts.get("cluster"): - port = 8087 - for _, c in enumerate(schema["classes"]): - c["replicationConfig"] = {"factor": 2} - - client = weaviate.Client(f"http://localhost:{port}") - client.schema.delete_all() - client.schema.create(schema) - with client.batch as batch: - for ship in SHIPS: - batch.add_data_object(ship["props"], "Ship", ship["id"]) - - batch.flush() - - yield client - client.schema.delete_all() - - -def parse_client_options(request: FixtureRequest) -> dict: - try: - if isinstance(request.param, dict): - return request.param - except AttributeError: - return - - -def test_get_data(client: weaviate.Client): - """Test GraphQL's Get clause.""" - where_filter = {"path": ["size"], "operator": "LessThan", "valueInt": 10} - result = client.query.get("Ship", ["name", "size"]).with_limit(2).with_where(where_filter).do() - objects = get_objects_from_result(result) - a_found = False - d_found = False - for obj in objects: - if obj["name"] == "HMS British Name": - a_found = True - if obj["name"] == "Titanic": - d_found = True - assert a_found and d_found and len(objects) == 2 - - -def test_get_data_with_where_contains_any(client: weaviate.Client): - """Test GraphQL's Get clause with where filter.""" - where_filter = {"path": ["size"], "operator": "ContainsAny", "valueInt": [5]} - result = client.query.get("Ship", ["name", "size"]).with_where(where_filter).do() - objects = get_objects_from_result(result) - assert len(objects) == 1 and objects[0]["name"] == "HMS British Name" - - -@pytest.mark.parametrize( - "path,operator,value_type_key,value_type_value,name,expected_objects_len", - [ - ( - ["description"], - "ContainsAll", - "valueString", - ["sponges, sponges, sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAll", - "valueText", - ["sponges", "sponges", "sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAll", - "valueStringArray", - ["sponges", "sponges", "sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAll", - "valueTextArray", - ["sponges, sponges, sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAll", - "valueStringList", - ["sponges", "sponges", "sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAll", - "valueTextList", - ["sponges", "sponges", "sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAny", - "valueString", - ["sponges, sponges, sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAny", - "valueText", - ["sponges", "sponges", "sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAny", - "valueStringArray", - ["sponges", "sponges", "sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAny", - "valueTextArray", - ["sponges, sponges, sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAny", - "valueStringList", - ["sponges", "sponges", "sponges"], - "The Crusty Crab", - 1, - ), - ( - ["description"], - "ContainsAny", - "valueTextList", - ["sponges", "sponges", "sponges"], - "The Crusty Crab", - 1, - ), - (["size"], "ContainsAll", "valueInt", [5], "HMS British Name", 1), - (["size"], "ContainsAll", "valueIntList", [5], "HMS British Name", 1), - (["size"], "ContainsAll", "valueIntArray", [5], "HMS British Name", 1), - (["size"], "ContainsAny", "valueInt", [5], "HMS British Name", 1), - (["size"], "ContainsAny", "valueIntList", [5], "HMS British Name", 1), - (["size"], "ContainsAny", "valueIntArray", [5], "HMS British Name", 1), - (["rating"], "ContainsAll", "valueNumber", [6.66], "The dragon ship", 1), - (["rating"], "ContainsAll", "valueNumberList", [6.66], "The dragon ship", 1), - (["rating"], "ContainsAll", "valueNumberArray", [6.66], "The dragon ship", 1), - (["rating"], "ContainsAny", "valueNumber", [6.66], "The dragon ship", 1), - (["rating"], "ContainsAny", "valueNumberList", [6.66], "The dragon ship", 1), - (["rating"], "ContainsAny", "valueNumberArray", [6.66], "The dragon ship", 1), - (["size"], "Equal", "valueInt", 5, "HMS British Name", 1), - (["size"], "LessThan", "valueInt", 5, "Titanic", 1), - (["size"], "LessThanEqual", "valueInt", 1, "Titanic", 1), - (["size"], "GreaterThan", "valueInt", 300, "The Crusty Crab", 1), - (["size"], "GreaterThanEqual", "valueInt", 303, "The Crusty Crab", 1), - (["description"], "Like", "valueString", "sponges", "The Crusty Crab", 1), - (["description"], "Like", "valueText", "sponges", "The Crusty Crab", 1), - (["rating"], "IsNull", "valueBoolean", True, "irrelevant", 0), - (["rating"], "NotEqual", "valueNumber", 6.66, "irrelevant", 5), - ], -) -def test_get_data_with_where( - client: weaviate.Client, - path: List[str], - operator: str, - value_type_key: str, - value_type_value: Union[List[int], List[str]], - name, - expected_objects_len: int, -): - """Test GraphQL's Get clause with where filter.""" - where_filter = { - "path": path, - "operator": operator, - value_type_key: value_type_value, - } - result = client.query.get("Ship", ["name"]).with_where(where_filter).do() - objects = get_objects_from_result(result) - if expected_objects_len == 0: - assert objects is None - else: - assert len(objects) == expected_objects_len - if expected_objects_len == 1: - assert objects[0]["name"] == name - - -def test_get_data_after(client: weaviate.Client): - full_results = client.query.get("Ship", ["name"]).with_additional(["id"]).do() - for i, ship in enumerate(full_results["data"]["Get"]["Ship"][:-1]): - result = ( - client.query.get("Ship", ["name"]) - .with_additional(["id"]) - .with_limit(1) - .with_after(ship["_additional"]["id"]) - .do() - ) - assert ( - result["data"]["Get"]["Ship"][0]["_additional"]["id"] - == full_results["data"]["Get"]["Ship"][i + 1]["_additional"]["id"] - ) - - -def test_get_data_after_wrong_types(client: weaviate.Client): - with pytest.raises(TypeError): - client.query.get("Ship", ["name"]).with_additional(["id"]).with_limit(1).with_after( - 1234 - ).do() - - -def test_multi_get_data(client: weaviate.Client, people_schema): - """Test GraphQL's MultiGet clause.""" - client.schema.create(people_schema) - client.data_object.create( - { - "name": "John", - }, - "Person", - ) - result = client.query.multi_get( - [ - client.query.get("Ship", ["name"]) - .with_alias("one") - .with_sort({"path": ["name"], "order": "asc"}), - client.query.get("Ship", ["size"]) - .with_alias("two") - .with_sort({"path": ["size"], "order": "asc"}), - client.query.get("Person", ["name"]), - ] - ).do()["data"]["Get"] - assert result["one"][0]["name"] == "Artemis" - assert result["two"][0]["size"] == 1 - assert result["Person"][0]["name"] == "John" - - -def test_aggregate_data(client: weaviate.Client): - """Test GraphQL's Aggregate clause.""" - where_filter = {"path": ["name"], "operator": "Equal", "valueString": "The dragon ship"} - - result = ( - client.query.aggregate("Ship") - .with_where(where_filter) - .with_group_by_filter(["name"]) - .with_fields("groupedBy {value}") - .with_fields("name{count}") - .do() - ) - - aggregation = get_aggregation_from_aggregate_result(result) - assert "groupedBy" in aggregation, "Missing groupedBy" - assert "name" in aggregation, "Missing name property" - - -def test_aggregate_data_with_group_by_and_limit(client: weaviate.Client): - """Test GraphQL's Aggregate clause with group_by and limit.""" - result = ( - client.query.aggregate("Ship") - .with_fields("name{count}") - .with_limit(2) - .with_group_by_filter(["name"]) - .do() - ) - - objects = get_objects_from_aggregate_result(result) - assert len(objects) == 2, "Expected 2 results" - - -def test_aggregate_data_with_just_limit(client: weaviate.Client): - """Test GraphQL's Aggregate clause with only limit. It's idempotent.""" - result = client.query.aggregate("Ship").with_fields("name{count}").with_limit(2).do() - - objects = get_objects_from_aggregate_result(result) - assert objects == [ - {"name": {"count": len(SHIPS)}} - ], f"Expected only meta count for {len(SHIPS)} results" - - -def get_objects_from_result(result): - return result["data"]["Get"]["Ship"] - - -def get_aggregation_from_aggregate_result(result): - return result["data"]["Aggregate"]["Ship"][0] - - -def get_objects_from_aggregate_result(result): - return result["data"]["Aggregate"]["Ship"] - - -@pytest.mark.parametrize("query", ["sponges", "sponges\n"]) -def test_bm25(client: weaviate.Client, query): - result = client.query.get("Ship", ["name"]).with_bm25(query, ["name", "description"]).do() - assert len(result["data"]["Get"]["Ship"]) == 1 - assert result["data"]["Get"]["Ship"][0]["name"] == "The Crusty Crab" - - -def test_bm25_no_result(client: weaviate.Client): - result = client.query.get("Ship", ["name"]).with_bm25("sponges\n", ["name"]).do() - assert len(result["data"]["Get"]["Ship"]) == 0 - - -@pytest.mark.parametrize("query", ["sponges", "sponges\n"]) -@pytest.mark.parametrize("fusion_type", [HybridFusion.RANKED, HybridFusion.RELATIVE_SCORE, None]) -def test_hybrid(client: weaviate.Client, query: str, fusion_type: Optional[HybridFusion]): - """Test hybrid search with alpha=0.5 to have a combination of BM25 and vector search.""" - result = ( - client.query.get("Ship", ["name", "description"]) - .with_hybrid(query, alpha=0.5, fusion_type=fusion_type) - .do() - ) - - # will find more results. "The Crusty Crab" is still first, because it matches with the BM25 search - assert len(result["data"]["Get"]["Ship"]) >= 1 - assert result["data"]["Get"]["Ship"][0]["name"] == "The Crusty Crab" - - -@pytest.mark.parametrize( - "properties,num_results", - [(None, 1), ([], 1), (["description"], 1), (["description", "name"], 1), (["name"], 0)], -) -def test_hybrid_properties( - client: weaviate.Client, properties: Optional[List[str]], num_results: int -): - """Test hybrid search with alpha=0.5 to have a combination of BM25 and vector search.""" - result = ( - client.query.get("Ship", ["name"]) - .with_hybrid("sponges", alpha=0.0, properties=properties) - .do() - ) - - # "The Crusty Crab" has "sponges" in its description, it cannot be found in other properties - if num_results > 0: - assert len(result["data"]["Get"]["Ship"]) >= 1 - - assert result["data"]["Get"]["Ship"][0]["name"] == "The Crusty Crab" - else: - assert len(result["data"]["Get"]["Ship"]) == 0 - - -@pytest.mark.parametrize("autocut,num_results", [(1, 1), (2, 6), (-1, len(SHIPS))]) -def test_autocut(client: weaviate.Client, autocut, num_results): - result = ( - client.query.get("Ship", ["name"]) - .with_hybrid(query="sponges", properties=["name", "description"], alpha=0.5) - .with_autocut(autocut) - .with_limit(len(SHIPS)) - .do() - ) - assert len(result["data"]["Get"]["Ship"]) == num_results - assert result["data"]["Get"]["Ship"][0]["name"] == "The Crusty Crab" - - -def test_group_by(client: weaviate.Client, people_schema): - """Test hybrid search with alpha=0.5 to have a combination of BM25 and vector search.""" - client.schema.delete_all() - client.schema.create(people_schema) - - persons = [] - for i in range(10): - persons.append(uuid.uuid4()) - client.data_object.create({"name": "randomName" + str(i)}, "Person", persons[-1]) - - client.data_object.create({}, "Call", "3ab05e06-2bb2-41d1-b5c5-e044f3aa9623") - client.data_object.create({}, "Call", "3ab05e06-2bb2-41d1-b5c5-e044f3aa9622") - - # create refs - for i in range(5): - client.data_object.reference.add( - to_uuid=persons[i], - from_property_name="caller", - from_uuid="3ab05e06-2bb2-41d1-b5c5-e044f3aa9623" - if i % 2 == 0 - else "3ab05e06-2bb2-41d1-b5c5-e044f3aa9622", - from_class_name="Call", - to_class_name="Person", - ) - - result = ( - client.query.get("Call", ["caller{... on Person{name}}"]) - .with_near_object({"id": "3ab05e06-2bb2-41d1-b5c5-e044f3aa9622"}) - .with_group_by(properties=["caller"], groups=2, objects_per_group=3) - .with_additional("group{hits {_additional{vector}caller{... on Person{name}}}}") - .do() - ) - - # will find more results. "The Crusty Crab" is still first, because it matches with the BM25 search - assert len(result["data"]["Get"]["Call"]) >= 1 - # assert result["data"]["Get"]["Call"][0]["caller"][0]["name"] == "randomName0" - - -@pytest.mark.parametrize( - "client,level", - [ - ({"cluster": True}, ConsistencyLevel.ONE), - ({"cluster": True}, ConsistencyLevel.QUORUM), - ({"cluster": True}, ConsistencyLevel.ALL), - ], - indirect=["client"], -) -def test_consistency_level(client: weaviate.Client, level): - result = ( - client.query.get("Ship", ["name"]) - .with_consistency_level(level) - .with_additional("isConsistent") - .do() - ) - for _, res in enumerate(get_objects_from_result(result)): - assert res["_additional"]["isConsistent"] - - -@pytest.mark.parametrize( - "single,grouped", - [ - ("Describe the following as a Facebook Ad: {review}", None), - (None, "Describe the following as a LinkedIn Ad: {review}"), - ( - "Describe the following as a Twitter Ad: {review}", - "Describe the following as a Mastodon Ad: {review}", - ), - ( - "Describe the following as a Twitter Ad: \n Review: {review} \n Name: {name}", - "Describe the following as a Mastodon Ad: \n Review: {review} \n Name: {name}", - ), - ], -) -def test_generative_openai(single: str, grouped: str): - """Test client credential flow with various providers.""" - api_key = os.environ.get("OPENAI_APIKEY") - if api_key is None: - pytest.skip("No OpenAI API key found.") - - client = weaviate.Client( - "http://localhost:8086", additional_headers={"X-OpenAI-Api-Key": api_key} - ) - client.schema.delete_all() - wine_class = { - "class": "Wine", - "properties": [ - {"name": "name", "dataType": ["string"]}, - {"name": "review", "dataType": ["string"]}, - ], - "moduleConfig": {"generative-openai": {}}, - } - client.schema.create_class(wine_class) - client.data_object.create( - data_object={"name": "Super expensive wine", "review": "Tastes like a fresh ocean breeze"}, - class_name="Wine", - ) - client.data_object.create( - data_object={"name": "cheap wine", "review": "Tastes like forest"}, class_name="Wine" - ) - - result = ( - client.query.get("Wine", ["name", "review"]) - .with_generate(single_prompt=single, grouped_task=grouped) - .do() - ) - assert result["data"]["Get"]["Wine"][0]["_additional"]["generate"]["error"] is None - - grouped_properties = ["review"] - result = ( - client.query.get("Wine", ["name", "review"]) - .with_generate( - single_prompt=single, grouped_task=grouped, grouped_properties=grouped_properties - ) - .do() - ) - assert result["data"]["Get"]["Wine"][0]["_additional"]["generate"]["error"] is None - - -def test_graphql_with_tenant(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - schema_class = { - "class": "GraphQlTenantClass", - "vectorizer": "none", - "multiTenancyConfig": {"enabled": True}, - } - - tenants = ["tenant1", "tenant2"] - client.schema.create_class(schema_class) - client.schema.add_class_tenants(schema_class["class"], [Tenant(tenant) for tenant in tenants]) - - nr_objects = 101 - with client.batch() as batch: - for i in range(nr_objects): - batch.add_data_object( - class_name=schema_class["class"], tenant=tenants[i % 2], data_object={} - ) - - # no results without tenant - results = client.query.get(schema_class["class"]).with_additional("id").do() - assert results["data"]["Get"][schema_class["class"]] is None - assert results["errors"] is not None - - # get call with tenant only returns the objects for a given tenant - results = ( - client.query.get(schema_class["class"]).with_additional("id").with_tenant(tenants[0]).do() - ) - assert len(results["data"]["Get"][schema_class["class"]]) == nr_objects // 2 + 1 - - results = ( - client.query.get(schema_class["class"]).with_additional("id").with_tenant(tenants[1]).do() - ) - assert len(results["data"]["Get"][schema_class["class"]]) == nr_objects // 2 - - results = client.query.aggregate(schema_class["class"]).with_meta_count().do() - assert results["data"]["Aggregate"][schema_class["class"]] is None - assert results["errors"] is not None - - results = ( - client.query.aggregate(schema_class["class"]).with_meta_count().with_tenant(tenants[0]).do() - ) - assert ( - int(results["data"]["Aggregate"][schema_class["class"]][0]["meta"]["count"]) - == nr_objects // 2 + 1 - ) - - results = ( - client.query.aggregate(schema_class["class"]).with_meta_count().with_tenant(tenants[1]).do() - ) - assert ( - int(results["data"]["Aggregate"][schema_class["class"]][0]["meta"]["count"]) - == nr_objects // 2 - ) - - -def test_graphql_with_nested_object(): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create_class( - { - "class": "NestedObjectClass", - "vectorizer": "none", - "properties": [ - { - "name": "nested", - "dataType": ["object"], - "nestedProperties": [ - { - "name": "name", - "dataType": ["text"], - }, - { - "name": "names", - "dataType": ["text[]"], - }, - { - "name": "age", - "dataType": ["int"], - }, - { - "name": "ages", - "dataType": ["int[]"], - }, - { - "name": "weight", - "dataType": ["number"], - }, - { - "name": "weights", - "dataType": ["number[]"], - }, - { - "name": "isAlive", - "dataType": ["boolean"], - }, - { - "name": "areAlive", - "dataType": ["boolean[]"], - }, - { - "name": "date", - "dataType": ["date"], - }, - { - "name": "dates", - "dataType": ["date[]"], - }, - { - "name": "uuid", - "dataType": ["uuid"], - }, - { - "name": "uuids", - "dataType": ["uuid[]"], - }, - ], - } - ], - } - ) - data = { - "name": "nested object", - "names": ["nested", "object"], - "age": 42, - "ages": [42, 43], - "weight": 42.42, - "weights": [42.42, 43.43], - "isAlive": True, - "areAlive": [True, False], - "date": "2021-01-01T00:00:00Z", - "dates": ["2021-01-01T00:00:00Z", "2021-01-02T00:00:00Z"], - "uuid": "00000000-0000-0000-0000-000000000000", - "uuids": ["00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000001"], - } - uuid_ = client.data_object.create({"nested": data}, "NestedObjectClass") - - results = client.query.get( - "NestedObjectClass", - [ - "nested { name names age ages weight weights isAlive areAlive date dates uuid uuids } _additional { id }" - ], - ).do() - print(results) - assert results["data"]["Get"]["NestedObjectClass"][0]["nested"] == data - assert results["data"]["Get"]["NestedObjectClass"][0]["_additional"]["id"] == uuid_ diff --git a/integration_v3/test_grcp.py b/integration_v3/test_grcp.py deleted file mode 100644 index 63de48107..000000000 --- a/integration_v3/test_grcp.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Any, Dict, Optional - -import pytest as pytest - -import weaviate - -CLASS1 = { - "class": "Test", - "properties": [ - {"name": "test", "dataType": ["string"]}, - {"name": "abc", "dataType": ["int"]}, - ], -} - -CLASS2 = { - "class": "Test2", - "properties": [ - {"name": "test", "dataType": ["string"]}, - {"name": "abc", "dataType": ["int"]}, - {"name": "ref", "dataType": ["Test"]}, - ], -} -VECTOR = [1.5, 2.5, 3.5] * 100 # match with vectorizer vector length - - -UUID1 = "577887c1-4c6b-5594-aa62-f0c17883d9bf" -UUID2 = "577887c1-4c6b-5594-aa62-f0c17883d9cf" - - -@pytest.mark.parametrize("grpc_port", [50051, None]) -@pytest.mark.parametrize("with_limit", [True, False]) -@pytest.mark.parametrize("additional_props", [None, "id", ["id"], ["id", "vector"]]) -@pytest.mark.parametrize( - "search", - [ - {"vector": VECTOR, "certainty": 0.5}, - {"vector": VECTOR, "distance": 0.5}, - {"vector": VECTOR}, - {"id": UUID2}, - {"id": UUID2, "certainty": 0.5}, - {"id": UUID2, "distance": 0.5}, - {"bm25": ""}, - {"hybrid": ""}, - ], -) -@pytest.mark.parametrize( - "properties", - [ - "test", - ["test", "abc"], - ["test", "ref {... on Test {test abc _additional{id vector}}}"], - ], -) -def test_grcp( - with_limit: bool, additional_props, search: Dict[str, Any], properties, grpc_port: Optional[int] -): - client = weaviate.Client( - "http://localhost:8080", additional_config=weaviate.Config(grpc_port_experimental=grpc_port) - ) - client.schema.delete_all() - - client.schema.create_class(CLASS1) - client.schema.create_class(CLASS2) - - # add objects and references - client.data_object.create({"test": "test"}, "Test", vector=VECTOR) - client.data_object.create({"test": "test", "abc": 5}, "Test", vector=VECTOR, uuid=UUID1) - client.data_object.create({"test": "test", "abc": 5}, "Test2", vector=VECTOR, uuid=UUID2) - client.data_object.reference.add( - from_uuid=UUID2, - to_uuid=UUID1, - from_class_name="Test2", - to_class_name="Test", - from_property_name="ref", - ) - - query = client.query.get("Test2", properties) - - if with_limit: - query.with_limit(10) - - if additional_props is not None: - query.with_additional(additional_props) - - if "vector" in search: - query.with_near_vector(search) - elif "id" in search: - query.with_near_object(search) - elif "concepts" in search: - query.with_near_text(search) - elif "bm25" in search: - query.with_bm25(query="test", properties=["test"]) - elif "hybrid" in search: - query.with_hybrid(query="test", properties=["test"], alpha=0.5, vector=VECTOR) - - result = query.do() - assert "Test2" in result["data"]["Get"] - assert "test" in result["data"]["Get"]["Test2"][0] - - -def test_additional(): - client_grpc = weaviate.Client( - "http://localhost:8080", additional_config=weaviate.Config(grpc_port_experimental=50051) - ) - client_grpc.schema.delete_all() - - client_grpc.schema.create_class(CLASS1) - client_grpc.data_object.create({"test": "test"}, "Test", vector=VECTOR) - client_gql = weaviate.Client( - "http://localhost:8080", additional_config=weaviate.Config(grpc_port_experimental=50052) - ) - - results = [] - for client in [client_gql, client_grpc]: - query = client.query.get("Test").with_additional( - weaviate.AdditionalProperties( - uuid=True, - vector=True, - creationTimeUnix=True, - lastUpdateTimeUnix=True, - distance=True, - ) - ) - result = query.do() - assert "Test" in result["data"]["Get"] - - results.append(result) - - result_gql = results[0]["data"]["Get"]["Test"][0]["_additional"] - result_grpc = results[1]["data"]["Get"]["Test"][0]["_additional"] - - assert sorted(result_gql.keys()) == sorted(result_grpc.keys()) - for key in result_gql.keys(): - assert result_gql[key] == result_grpc[key] - - -def test_grpc_errors(): - client = weaviate.Client( - "http://localhost:8080", additional_config=weaviate.Config(grpc_port_experimental=50051) - ) - classname = CLASS1["class"] - if client.schema.exists(classname): - client.schema.delete_class(classname) - client.schema.create_class(CLASS1) - - client.data_object.create({"test": "test"}, classname) - - # class errors - res = client.query.get(classname + "does_not_exist", ["test"]).do() - assert "errors" in res - assert "data" not in res - - # prop errors - res = client.query.get(classname, ["test", "made_up_prop"]).do() - assert "errors" in res - assert "data" not in res diff --git a/integration_v3/test_injection.py b/integration_v3/test_injection.py deleted file mode 100644 index 21eacf5d5..000000000 --- a/integration_v3/test_injection.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import weaviate -import requests -import json - - -def injection_template(n: int) -> str: - return "Liver" + ("\\" * n) + '"}}){{answer}}}}{payload}#' - - -@pytest.mark.parametrize("n_backslashes", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) -def test_gql_injection(n_backslashes: int) -> None: - client = weaviate.Client(url="http://localhost:8080") - client.schema.delete_class("Question") - client.schema.delete_class("Hacked") - class_obj = { - "class": "Question", - "vectorizer": "text2vec-contextionary", - "properties": [ - {"name": "answer", "dataType": ["string"], "tokenization": "field"}, - {"name": "question", "dataType": ["string"]}, - {"name": "category", "dataType": ["string"]}, - ], - } - - class_obj2 = { - "class": "Hacked", - "vectorizer": "text2vec-contextionary", - "properties": [ - {"name": "answer", "dataType": ["string"]}, - {"name": "question", "dataType": ["string"]}, - {"name": "category", "dataType": ["string"]}, - ], - } - client.schema.create_class(class_obj) - client.schema.create_class(class_obj2) - - resp = requests.get( - "https://raw.githubusercontent.com/weaviate-tutorials/quickstart/main/data/jeopardy_tiny.json" - ) - data = json.loads(resp.text) - - client.batch.configure(batch_size=100) - with client.batch as batch: - for _, d in enumerate(data): - properties = { - "answer": d["Answer"], - "question": d["Question"], - "category": d["Category"], - } - batch.add_data_object(data_object=properties, class_name="Question") - batch.add_data_object(data_object=properties, class_name="Hacked") - - injection_payload = client.query.get("Hacked", ["answer"]).build() - query = client.query.get("Question", ["question", "answer", "category"]).with_where( - { - "path": ["answer"], - "operator": "NotEqual", - "valueText": injection_template(n_backslashes).format(payload=injection_payload[1:]), - } - ) - res = query.do() - assert "Hacked" not in res["data"]["Get"] diff --git a/integration_v3/test_schema.py b/integration_v3/test_schema.py deleted file mode 100644 index 5b89ab627..000000000 --- a/integration_v3/test_schema.py +++ /dev/null @@ -1,242 +0,0 @@ -from typing import Optional - -import pytest -import requests - -import weaviate -from weaviate import Tenant, TenantActivityStatus - - -@pytest.fixture(scope="module") -def client(): - client = weaviate.Client("http://localhost:8080") - yield client - client.schema.delete_all() - - -@pytest.mark.parametrize("replicationFactor", [None, 1]) -def test_create_class_with_implicit_and_explicit_replication_config( - client: weaviate.Client, replicationFactor: Optional[int] -): - single_class = { - "class": "Barbecue", - "description": "Barbecue or BBQ where meat and vegetables get grilled", - "properties": [ - { - "dataType": ["string"], - "description": "how hot is the BBQ in C", - "name": "heat", - }, - ], - } - if replicationFactor is None: - expected_factor = 1 - else: - expected_factor = replicationFactor - single_class["replicationConfig"] = { - "factor": replicationFactor, - } - - client.schema.create_class(single_class) - created_class = client.schema.get("Barbecue") - assert created_class["class"] == "Barbecue" - assert created_class["replicationConfig"]["factor"] == expected_factor - - client.schema.delete_class("Barbecue") - - -@pytest.mark.parametrize("data_type", ["uuid", "uuid[]"]) -def test_uuid_datatype(client: weaviate.Client, data_type: str): - single_class = {"class": "UuidTest", "properties": [{"dataType": [data_type], "name": "heat"}]} - - client.schema.create_class(single_class) - created_class = client.schema.get("uuidTest") - assert created_class["class"] == "UuidTest" - - client.schema.delete_class("UuidTest") - - -@pytest.mark.parametrize("object_", ["object", "object[]"]) -@pytest.mark.parametrize( - "nested", - [ - { - "dataType": ["text"], - "name": "name", - }, - {"dataType": ["text[]"], "name": "names"}, - {"dataType": ["int"], "name": "age"}, - {"dataType": ["int[]"], "name": "ages"}, - {"dataType": ["number"], "name": "weight"}, - {"dataType": ["number[]"], "name": "weights"}, - {"dataType": ["boolean"], "name": "isAlive"}, - {"dataType": ["boolean[]"], "name": "areAlive"}, - {"dataType": ["date"], "name": "birthDate"}, - {"dataType": ["date[]"], "name": "birthDates"}, - {"dataType": ["uuid"], "name": "uuid"}, - {"dataType": ["uuid[]"], "name": "uuids"}, - {"dataType": ["blob"], "name": "blob"}, - { - "dataType": ["object"], - "name": "object", - "nestedProperties": [{"dataType": ["text"], "name": "name"}], - }, - { - "dataType": ["object[]"], - "name": "objects", - "nestedProperties": [{"dataType": ["text"], "name": "name"}], - }, - ], -) -def test_object_datatype(client: weaviate.Client, object_: str, nested: dict): - single_class = { - "class": "ObjectTest", - "properties": [{"dataType": [object_], "name": "heat", "nestedProperties": [nested]}], - } - - client.schema.create_class(single_class) - created_class = client.schema.get("ObjectTest") - assert created_class["class"] == "ObjectTest" - - client.schema.delete_class("ObjectTest") - - -@pytest.mark.parametrize("tokenization", ["word", "whitespace", "lowercase", "field"]) -def test_tokenization(client: weaviate.Client, tokenization): - single_class = { - "class": "TokenTest", - "properties": [{"dataType": ["text"], "name": "heat", "tokenization": tokenization}], - } - client.schema.create_class(single_class) - created_class = client.schema.get("TokenTest") - assert created_class["class"] == "TokenTest" - - client.schema.delete_class("TokenTest") - - -def test_class_exists(client: weaviate.Client): - single_class = {"class": "Exists"} - - client.schema.create_class(single_class) - assert client.schema.exists("Exists") is True - assert client.schema.exists("DoesNotExists") is False - - client.schema.delete_class("Exists") - assert client.schema.exists("Exists") is False - - -def test_schema_keys(client: weaviate.Client): - single_class = { - "class": "Author", - "properties": [ - { - "indexFilterable": False, - "indexSearchable": False, - "dataType": ["text"], - "name": "name", - } - ], - } - client.schema.create_class(single_class) - assert client.schema.exists("Author") - - -def test_class_tenants(client: weaviate.Client): - class_name = "MultiTenancySchemaTest" - uncap_class_name = "multiTenancySchemaTest" - single_class = {"class": class_name, "multiTenancyConfig": {"enabled": True}} - client.schema.delete_all() - client.schema.create_class(single_class) - assert client.schema.exists(class_name) - - tenants = [ - Tenant(name="Tenant1"), - Tenant(name="Tenant2"), - Tenant(name="Tenant3"), - Tenant(name="Tenant4"), - ] - client.schema.add_class_tenants(class_name, tenants[:2]) - client.schema.add_class_tenants(uncap_class_name, tenants[2:]) - tenants_get = client.schema.get_class_tenants(class_name) - assert len(tenants_get) == len(tenants) - - client.schema.remove_class_tenants(class_name, ["Tenant2", "Tenant4"]) - client.schema.remove_class_tenants(uncap_class_name, ["Tenant1"]) - tenants_get = client.schema.get_class_tenants(uncap_class_name) - assert len(tenants_get) == 1 - - -def test_update_schema_with_no_properties(client: weaviate.Client): - single_class = {"class": "NoProperties"} - - requests.post("http://localhost:8080/v1/schema", json=single_class) - assert client.schema.exists("NoProperties") - - client.schema.update_config("NoProperties", {"vectorIndexConfig": {"ef": 64}}) - assert client.schema.exists("NoProperties") - - client.schema.delete_class("NoProperties") - assert client.schema.exists("NoProperties") is False - - -def test_class_tenants_activate_deactivate(client: weaviate.Client): - class_name = "MultiTenancyActivateDeactivateSchemaTest" - uncap_class_name = "multiTenancyActivateDeactivateSchemaTest" - single_class = {"class": class_name, "multiTenancyConfig": {"enabled": True}} - client.schema.delete_all() - client.schema.create_class(single_class) - assert client.schema.exists(class_name) - - tenants = [ - Tenant(name="Tenant1"), - Tenant(activity_status=TenantActivityStatus.COLD, name="Tenant2"), - Tenant(name="Tenant3"), - ] - client.schema.add_class_tenants(class_name, tenants) - tenants_get = client.schema.get_class_tenants(class_name) - assert len(tenants_get) == len(tenants) - # below required because tenants are returned in random order by the server - for tenant in tenants_get: - if tenant.name == "Tenant1": - assert tenant.activity_status == TenantActivityStatus.HOT - elif tenant.name == "Tenant2": - assert tenant.activity_status == TenantActivityStatus.COLD - elif tenant.name == "Tenant3": - assert tenant.activity_status == TenantActivityStatus.HOT - else: - raise AssertionError(f"Unexpected tenant: {tenant.name}") - - updated_tenants = [ - Tenant(activity_status=TenantActivityStatus.COLD, name="Tenant1"), - Tenant(activity_status=TenantActivityStatus.HOT, name="Tenant2"), - ] - client.schema.update_class_tenants(class_name, updated_tenants) - tenants_get = client.schema.get_class_tenants(class_name) - assert len(tenants_get) == len(tenants) - # below required because tenants are returned in random order by the server - for tenant in tenants_get: - if tenant.name == "Tenant1": - assert tenant.activity_status == TenantActivityStatus.COLD - elif tenant.name == "Tenant2": - assert tenant.activity_status == TenantActivityStatus.HOT - elif tenant.name == "Tenant3": - assert tenant.activity_status == TenantActivityStatus.HOT - else: - raise AssertionError(f"Unexpected tenant: {tenant.name}") - - updated_tenants = [ - Tenant(activity_status=TenantActivityStatus.COLD, name="Tenant3"), - ] - client.schema.update_class_tenants(uncap_class_name, updated_tenants) - tenants_get = client.schema.get_class_tenants(uncap_class_name) - assert len(tenants_get) == len(tenants) - # below required because tenants are returned in random order by the server - for tenant in tenants_get: - if tenant.name == "Tenant1": - assert tenant.activity_status == TenantActivityStatus.COLD - elif tenant.name == "Tenant2": - assert tenant.activity_status == TenantActivityStatus.HOT - elif tenant.name == "Tenant3": - assert tenant.activity_status == TenantActivityStatus.COLD - else: - raise AssertionError(f"Unexpected tenant: {tenant.name}") diff --git a/integration_v3/test_stress.py b/integration_v3/test_stress.py deleted file mode 100644 index a8a9e8287..000000000 --- a/integration_v3/test_stress.py +++ /dev/null @@ -1,300 +0,0 @@ -import datetime -from dataclasses import dataclass, field -from typing import List, Dict, Optional, Any - -import pytest -import uuid - -import weaviate - -schema = { - "classes": [ - { - "class": "Paragraph", - "properties": [ - {"dataType": ["text"], "name": "contents"}, - {"dataType": ["Paragraph"], "name": "hasParagraphs"}, - {"dataType": ["Author"], "name": "author"}, - ], - "vectorizer": "none", - }, - { - "class": "Article", - "properties": [ - {"dataType": ["string"], "name": "title"}, - {"dataType": ["Paragraph"], "name": "hasParagraphs"}, - {"dataType": ["date"], "name": "datePublished"}, - {"dataType": ["Author"], "name": "author"}, - {"dataType": ["string"], "name": "somestring"}, - {"dataType": ["int"], "name": "counter"}, - ], - "vectorizer": "none", - }, - { - "class": "Author", - "properties": [{"dataType": ["string"], "name": "name"}], - "vectorizer": "none", - }, - ] -} - - -@dataclass(frozen=True) -class Reference: - to_class: str - to_uuid: uuid.UUID - - -@dataclass -class DataObject: - properties: Dict[str, Any] - class_name: str - uuid: uuid - - -@dataclass -class Author: - name: str - uuid: uuid = field(init=False) - class_name: str = field(init=False) - - def to_data_object(self) -> DataObject: - return DataObject({"name": self.name}, self.class_name, self.uuid) - - def __post_init__(self) -> None: - self.uuid = uuid.uuid4() - self.class_name = "Author" - - -@dataclass -class Paragraph: - contents: str - author: Reference - hasParagraphs: Optional[Reference] - uuid: uuid = field(init=False) - class_name: str = field(init=False) - - def to_data_object(self) -> DataObject: - return DataObject({"contents": self.contents}, self.class_name, self.uuid) - - def __post_init__(self) -> None: - self.uuid = uuid.uuid4() - self.class_name = "Paragraph" - - -@dataclass -class Article: - title: str - datePublished: str - somestring: str - counter: int - author: Reference - hasParagraphs: Reference - uuid: uuid = field(init=False) - class_name: str = field(init=False) - - def to_data_object(self) -> DataObject: - return DataObject( - {"title": self.title, "datePublished": self.datePublished}, self.class_name, self.uuid - ) - - def __post_init__(self) -> None: - self.uuid = uuid.uuid4() - self.class_name = "Article" - - -@pytest.mark.parametrize("dynamic", [False]) -@pytest.mark.parametrize("batch_size", [50]) -def test_stress(batch_size, dynamic): - client = weaviate.Client("http://localhost:8080") - client.schema.delete_all() - client.schema.create(schema) - client.batch.configure(batch_size=batch_size, dynamic=dynamic, num_workers=4) - authors = create_authors(1000) - paragraphs = create_paragraphs(1000, authors) - articles = create_articles(1000, authors, paragraphs) - - add_authors(client, authors) - add_paragraphs(client, paragraphs) - add_articles(client, articles) - - client.batch.flush() - __assert_add(client, authors, authors[0].class_name) - __assert_add(client, paragraphs, paragraphs[0].class_name) - __assert_add(client, articles, articles[0].class_name) - - # verify references - for article in articles: - article_weav = client.data_object.get_by_id(article.uuid, class_name=article.class_name) - beacon_article = str(article_weav["properties"]["author"][0]["beacon"]) - assert beacon_article.split("/")[-1] == str(article.author.to_uuid) - beacon_paragraph = str(article_weav["properties"]["hasParagraphs"][0]["beacon"]) - assert beacon_paragraph.split("/")[-1] == str(article.hasParagraphs.to_uuid) - - for paragraph in paragraphs: - article_weav = client.data_object.get_by_id(paragraph.uuid, class_name=paragraph.class_name) - beacon_article = str(article_weav["properties"]["author"][0]["beacon"]) - assert beacon_article.split("/")[-1] == str(paragraph.author.to_uuid) - if paragraph.hasParagraphs is not None: - beacon_paragraph = str(article_weav["properties"]["hasParagraphs"][0]["beacon"]) - assert beacon_paragraph.split("/")[-1] == str(paragraph.hasParagraphs.to_uuid) - else: - assert "hasParagraphs" not in article_weav["properties"] - - client.schema.delete_all() - - -@pytest.fixture( - params=[(batch_size, workers) for workers in [1, 4, 10] for batch_size in [-1, 50, 100]], - ids=[ - f"batch_size{batch_size}, workers {workers})" - for workers in [1, 4, 10] - for batch_size in [-1, 50, 100] - ], -) -def client(request): - local_client = weaviate.Client("http://localhost:8080") - if request.param[0] > 0: - local_client.batch.configure( - batch_size=request.param[0], dynamic=False, num_workers=request.param[1] - ) - else: - local_client.batch.configure(batch_size=10, dynamic=True, num_workers=request.param[1]) - return local_client - - -def run_stress_test(client): - client.schema.delete_all() - client.schema.create(schema) - - authors = create_authors(20000) - paragraphs = create_paragraphs(20000, authors) - articles = create_articles(10000, authors, paragraphs) - - add_authors(client, authors) - add_paragraphs(client, paragraphs) - add_articles(client, articles) - - client.batch.flush() - __assert_add(client, authors, authors[0].class_name) - __assert_add(client, paragraphs, paragraphs[0].class_name) - __assert_add(client, articles, articles[0].class_name) - - client.schema.delete_all() - - -@pytest.mark.profiling -def test_profile_stress(client): - run_stress_test(client) - - -def test_benchmark_stress_test(benchmark, client): - benchmark(test_profile_stress, client) - - -def add_authors(client: weaviate.Client, authors: List[Author]): - for author in authors: - data_object = author.to_data_object() - client.batch.add_data_object( - data_object.properties, data_object.class_name, data_object.uuid - ) - - -def add_paragraphs(client: weaviate.Client, paragraphs: List[Paragraph]): - for paragraph in paragraphs: - data_object = paragraph.to_data_object() - client.batch.add_data_object( - data_object.properties, data_object.class_name, data_object.uuid - ) - client.batch.add_reference( - str(paragraph.uuid), - from_property_name="author", - to_object_uuid=str(paragraph.author.to_uuid), - from_object_class_name="Paragraph", - to_object_class_name="Author", - ) - if paragraph.hasParagraphs is not None: - client.batch.add_reference( - str(paragraph.uuid), - from_property_name="hasParagraphs", - to_object_uuid=str(paragraph.hasParagraphs.to_uuid), - from_object_class_name="Paragraph", - to_object_class_name="Paragraph", - ) - - -def add_articles(client: weaviate.Client, articles: List[Article]): - for article in articles: - data_object = article.to_data_object() - client.batch.add_data_object( - data_object.properties, data_object.class_name, data_object.uuid - ) - client.batch.add_reference( - str(article.uuid), - from_property_name="author", - to_object_uuid=str(article.author.to_uuid), - from_object_class_name="Article", - to_object_class_name="Author", - ) - client.batch.add_reference( - str(article.uuid), - from_property_name="hasParagraphs", - to_object_uuid=str(article.hasParagraphs.to_uuid), - from_object_class_name="Article", - to_object_class_name="Paragraph", - ) - - -def create_authors(num_authors: int) -> List[Author]: - authors: List[Author] = [Author(f"{i}") for i in range(num_authors)] - return authors - - -def create_paragraphs(num_paragraphs: int, authors: List[Author]) -> List[Paragraph]: - paragraphs: List[Paragraph] = [] - for i in range(num_paragraphs): - content: str = f"{i} {i} {i} {i}" - - paragraph_to_reference: Optional[Paragraph] = None - if len(paragraphs) > 0 and i % 2 == 0: - paragraph_to_reference: Paragraph = paragraphs[i % len(paragraphs)] - author_to_reference: Author = authors[i % len(authors)] - paragraphs.append( - Paragraph( - content, - Reference("Author", author_to_reference.uuid), - Reference("Paragraph", paragraph_to_reference.uuid) - if paragraph_to_reference is not None - else None, - ) - ) - return paragraphs - - -def create_articles( - num_articles: int, authors: List[Author], paragraphs: List[Paragraph] -) -> List[Article]: - articles: List[Article] = [] - base_date: datetime.date = datetime.datetime(2023, 12, 9, 7, 1, 34) - for i in range(num_articles): - title: str = f"{i} {i} {i}" - paragraph_to_reference: Paragraph = paragraphs[i % len(paragraphs)] - author_to_reference: Author = authors[i % len(authors)] - date_published: str = (base_date + datetime.timedelta(hours=i)).isoformat() + "Z" - articles.append( - Article( - title, - date_published, - str(i), - i, - Reference("Author", author_to_reference.uuid), - Reference("Paragraph", paragraph_to_reference.uuid), - ) - ) - - return articles - - -def __assert_add(client: weaviate.Client, objects: List[Any], class_name: str) -> None: - result = client.query.aggregate(class_name).with_meta_count().do() - assert len(objects) == result["data"]["Aggregate"][class_name][0]["meta"]["count"] diff --git a/integration_v3/test_timeout.py b/integration_v3/test_timeout.py deleted file mode 100644 index 42c327b17..000000000 --- a/integration_v3/test_timeout.py +++ /dev/null @@ -1,59 +0,0 @@ -import uuid - -import weaviate - -schema = { - "classes": [ - { - "class": "ClassA", - "properties": [ - {"dataType": ["string"], "name": "stringProp"}, - {"dataType": ["int"], "name": "intProp"}, - ], - } - ] -} - - -def test_low_timeout(): - client = weaviate.Client("http://localhost:8080", timeout_config=(1, 1)) - client.schema.delete_all() - client.schema.create(schema) - client.batch.configure(dynamic=True, batch_size=10, num_workers=4) - - num_objects = ( - 5000 # cannot be increased too high, because weaviate can't return that many results - ) - uuids = [] - for i in range(num_objects): - uuids.append(uuid.uuid4()) - client.batch.add_data_object( - {"stringProp": f"object-{i}", "intProp": i}, "ClassA", uuid=uuids[-1] - ) - client.batch.flush() - result = client.query.aggregate("ClassA").with_meta_count().do() - assert num_objects == result["data"]["Aggregate"]["ClassA"][0]["meta"]["count"] - - # update all objects to make sure that updates are processed even when timeouts occur - for i in range(num_objects): - client.batch.add_data_object( - {"stringProp": f"object-{i*2}", "intProp": i * 2}, "ClassA", uuid=uuids[i] - ) - client.batch.flush() - - result = client.query.aggregate("ClassA").with_meta_count().do() - assert num_objects == result["data"]["Aggregate"]["ClassA"][0]["meta"]["count"] - - # check that no additional objects where created, but everything was updated - result = ( - client.query.get("ClassA", ["intProp"]) - .with_additional("id") - .with_limit(num_objects + 10) - .do() - ) - assert num_objects == len(result["data"]["Get"]["ClassA"]) - for obj in result["data"]["Get"]["ClassA"]: - uuid_ind = uuids.index(uuid.UUID(obj["_additional"]["id"])) - assert int(obj["intProp"]) == uuid_ind * 2 - - client.schema.delete_all() diff --git a/mock_tests/test_auth.py b/mock_tests/test_auth.py deleted file mode 100644 index e163d1cf0..000000000 --- a/mock_tests/test_auth.py +++ /dev/null @@ -1,239 +0,0 @@ -import json -import time -import warnings - -import pytest -from werkzeug import Request, Response - -import weaviate -from mock_tests.conftest import MOCK_SERVER_URL, CLIENT_ID -from weaviate.exceptions import MissingScopeException - -ACCESS_TOKEN = "HELLO!IamAnAccessToken" -CLIENT_SECRET = "SomeSecret.DontTell" -SCOPE = "IcanBeAnything" -REFRESH_TOKEN = "UseMeToRefreshYourAccessToken" - - -def test_user_password(weaviate_auth_mock): - """Test that client sends username and pw with the correct body to the token endpoint and uses the correct token.""" - - user = "AUsername" - pw = "SomePassWord" - - # note: order matters. If this handler is not called, check of the order of arguments changed - weaviate_auth_mock.expect_request( - "/auth", - data=f"grant_type=password&username={user}&password={pw}&client_id={CLIENT_ID}", - ).respond_with_json( - {"access_token": ACCESS_TOKEN, "expires_in": 500, "refresh_token": REFRESH_TOKEN} - ) - weaviate_auth_mock.expect_request( - "/v1/schema", headers={"Authorization": "Bearer " + ACCESS_TOKEN} - ).respond_with_json({}) - - client = weaviate.Client( - MOCK_SERVER_URL, auth_client_secret=weaviate.AuthClientPassword(user, pw) - ) - client.schema.delete_all() # some call that includes authorization - - -def test_bearer_token(weaviate_auth_mock): - """Test that client sends the given bearer token.""" - weaviate_auth_mock.expect_request( - "/v1/schema", headers={"Authorization": "Bearer " + ACCESS_TOKEN} - ).respond_with_json({}) - - client = weaviate.Client( - MOCK_SERVER_URL, - auth_client_secret=weaviate.AuthBearerToken(ACCESS_TOKEN, refresh_token=REFRESH_TOKEN), - ) - client.schema.delete_all() # some call that includes authorization - - -def test_client_credentials(weaviate_auth_mock): - """Test that client sends the client credentials to the token endpoint and uses the correct token.""" - weaviate_auth_mock.expect_request("/auth").respond_with_json( - {"access_token": ACCESS_TOKEN, "expires_in": 500} - ) - weaviate_auth_mock.expect_request( - "/v1/schema", headers={"Authorization": "Bearer " + ACCESS_TOKEN} - ).respond_with_json({}) - - client = weaviate.Client( - MOCK_SERVER_URL, - auth_client_secret=weaviate.AuthClientCredentials(client_secret=CLIENT_SECRET, scope=SCOPE), - ) - client.schema.delete_all() # some call that includes authorization - - -@pytest.mark.parametrize("header_name", ["Authorization", "authorization"]) -def test_auth_header_priority(recwarn, weaviate_auth_mock, header_name: str): - """Test that auth_client_secret has priority over the auth header.""" - - # testing for warnings can be flaky without this as there are open SSL conections - warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning) - - bearer_token = "OTHER TOKEN" - - weaviate_auth_mock.expect_request("/auth").respond_with_json( - {"access_token": ACCESS_TOKEN, "expires_in": 500, "refresh_token": REFRESH_TOKEN} - ) - - def handler(request: Request): - assert request.headers["Authorization"] == "Bearer " + ACCESS_TOKEN - return Response(json.dumps({})) - - weaviate_auth_mock.expect_request("/v1/schema").respond_with_handler(handler) - - client = weaviate.Client( - MOCK_SERVER_URL, - auth_client_secret=weaviate.AuthBearerToken( - access_token=ACCESS_TOKEN, refresh_token="SOMETHING" - ), - additional_headers={header_name: "Bearer " + bearer_token}, - ) - client.schema.delete_all() # some call that includes authorization - - w = [w for w in recwarn if str(w.message).startswith("Auth004")] - assert len(w) == 1 - assert issubclass(w[0].category, UserWarning) - - -def test_refresh(weaviate_auth_mock): - """Test that refresh tokens are used to get a new access token.""" - weaviate_auth_mock.expect_request( - "/v1/schema", headers={"Authorization": "Bearer " + ACCESS_TOKEN} - ).respond_with_json({}) - - weaviate_auth_mock.expect_request( - "/auth", - data=f"grant_type=refresh_token&refresh_token={REFRESH_TOKEN}&client_id={CLIENT_ID}", - ).respond_with_json( - {"access_token": ACCESS_TOKEN, "expires_in": 1, "refresh_token": REFRESH_TOKEN} - ) - client = weaviate.Client( - MOCK_SERVER_URL, - auth_client_secret=weaviate.AuthBearerToken( - ACCESS_TOKEN, refresh_token=REFRESH_TOKEN, expires_in=1 - ), - ) - # client gets a new token 5s before expiration - client.schema.delete_all() # some call that includes authorization - - -def test_auth_header_without_weaviate_auth(weaviate_mock): - """Test that setups that use the Authorization header to authorize to non-weaviate servers.""" - bearer_token = "OTHER TOKEN" - weaviate_mock.expect_request( - "/v1/schema", headers={"Authorization": "Bearer " + bearer_token} - ).respond_with_json({}) - - client = weaviate.Client( - MOCK_SERVER_URL, - additional_headers={"Authorization": "Bearer " + bearer_token}, - ) - client.schema.delete_all() # some call that includes authorization - - -def test_auth_header_with_catchall_proxy(weaviate_mock, recwarn): - """Test that the client can handle situations in which a proxy returns a catchall page for all requests.""" - weaviate_mock.expect_request("/v1/schema").respond_with_json({}) - weaviate_mock.expect_request("/v1/.well-known/openid-configuration").respond_with_data( - "JsonCannotParseThis" - ) - - client = weaviate.Client( - MOCK_SERVER_URL, - auth_client_secret=weaviate.AuthClientPassword( - username="test-username", password="test-password" - ), - ) - client.schema.delete_all() # some call that includes authorization - - w = [w for w in recwarn if str(w.message).startswith("Auth005")] - assert len(w) == 1 - assert issubclass(w[0].category, UserWarning) - - -def test_missing_scope(weaviate_auth_mock): - with pytest.raises(MissingScopeException): - weaviate.Client( - MOCK_SERVER_URL, - auth_client_secret=weaviate.AuthClientCredentials( - client_secret=CLIENT_SECRET, scope=None - ), - ) - - -def test_token_refresh_timeout(weaviate_auth_mock, recwarn): - """Test that the token refresh background thread can handle timeouts of the auth server.""" - first_request = True - - # This handler lets the refresh request timeout for the first time. Then, the client retries the refresh which - # should succeed. - def handler(request: Request): - nonlocal first_request - if first_request: - time.sleep(6) # Timeout for auth connections is 5s. We need to wait longer - first_request = False - return Response(json.dumps({"access_token": ACCESS_TOKEN + "_1", "expires_in": 31})) - - weaviate_auth_mock.expect_request("/auth").respond_with_handler(handler) - - # This handler only accepts the refreshed token, to make sure that the refresh happened - weaviate_auth_mock.expect_request( - "/v1/schema", headers={"Authorization": "Bearer " + ACCESS_TOKEN + "_1"} - ).respond_with_json({}) - - client = weaviate.Client( - MOCK_SERVER_URL, - auth_client_secret=weaviate.AuthBearerToken( - ACCESS_TOKEN, refresh_token=REFRESH_TOKEN, expires_in=1 # force immediate refresh - ), - ) - - time.sleep(9) # sleep longer than the timeout, to give client time to retry - client.schema.delete_all() - - w = [w for w in recwarn if str(w.message).startswith("Con001")] - assert len(w) == 1 - assert issubclass(w[0].category, UserWarning) - - -def test_with_simple_auth_no_oidc_via_api_key(weaviate_mock, recwarn): - weaviate_mock.expect_request( - "/v1/schema", headers={"Authorization": "Bearer " + "Super-secret-key"} - ).respond_with_json({}) - - client = weaviate.Client( - MOCK_SERVER_URL, - auth_client_secret=weaviate.AuthApiKey(api_key="Super-secret-key"), - ) - client.schema.delete_all() - - weaviate_mock.check_assertions() - - w = [ - w for w in recwarn if str(w.message).startswith("Auth") or str(w.message).startswith("Con") - ] - assert len(w) == 0 - - -def test_with_simple_auth_no_oidc_via_additional_headers(weaviate_mock, recwarn): - weaviate_mock.expect_request( - "/v1/schema", headers={"Authorization": "Bearer " + "Super-secret-key"} - ).respond_with_json({}) - - client = weaviate.Client( - MOCK_SERVER_URL, - additional_headers={"Authorization": "Bearer " + "Super-secret-key"}, - ) - client.schema.delete_all() - - weaviate_mock.check_assertions() - - w = [ - w for w in recwarn if str(w.message).startswith("Auth") or str(w.message).startswith("Con") - ] - assert len(w) == 0 diff --git a/mock_tests/test_automatic_retries.py b/mock_tests/test_automatic_retries.py deleted file mode 100644 index a5bc93e0f..000000000 --- a/mock_tests/test_automatic_retries.py +++ /dev/null @@ -1,315 +0,0 @@ -import json -import uuid -from typing import Optional - -import pytest -from werkzeug.wrappers import Request, Response - -import weaviate -from mock_tests.conftest import MOCK_SERVER_URL -from weaviate.batch.crud_batch import WeaviateErrorRetryConf, BatchResponse -from weaviate.util import check_batch_result - - -@pytest.mark.parametrize( - "error", - [ - {"errors": {"error": [{"message": "I'm an error message"}]}}, - { - "errors": { - "error": [{"message": "I'm an error message"}, {"message": "Another message"}] - } - }, - ], -) -def test_automatic_retry_obs(weaviate_mock, error): - """Tests that all objects are successfully added even if half of them fail.""" - successfully_added = [] - num_failed_requests = 0 - - # Mockserver returns error for half of all objects - def handler(request: Request): - nonlocal num_failed_requests - objects = request.json["objects"] - for j, obj in enumerate(objects): - obj["deprecations"] = None - if j % 2 == 0: - obj["result"] = {} - successfully_added.append(uuid.UUID(obj["id"])) - else: - obj["result"] = error - num_failed_requests += 1 - return Response(json.dumps(objects)) - - weaviate_mock.expect_request("/v1/batch/objects").respond_with_handler(handler) - - client = weaviate.Client(MOCK_SERVER_URL) - added_uuids = [] - batch_size = 4 # Do not change, affects how many failed requests there are - n = ( - 50 * batch_size - ) # multiple of the batch size, otherwise it is difficult to calculate the number of expected errors - client.batch.configure( - batch_size=batch_size, - num_workers=2, - weaviate_error_retries=WeaviateErrorRetryConf(number_retries=3), - dynamic=False, - ) - - with client.batch as batch: - for i in range(n): - added_uuids.append(uuid.uuid4()) - batch.add_data_object({"name": "test" + str(i)}, "test", added_uuids[-1]) - assert len(successfully_added) == n - assert sorted(successfully_added) == sorted(added_uuids) - - # with a batch size of 4, we have 3 failures per batch - assert num_failed_requests == 3 * n / batch_size - - -def test_automatic_retry_refs(weaviate_mock): - """Tests that all references are successfully added even if half of them fail.""" - num_success_requests = 0 - num_failed_requests = 0 - - # Mockserver returns error for half of all objects - def handler(request: Request): - nonlocal num_failed_requests, num_success_requests - objects = request.json - for j, obj in enumerate(objects): - obj["deprecations"] = None - if j % 2 == 0: - obj["result"] = {} - num_success_requests += 1 - else: - obj["result"] = {"errors": {"error": [{"message": "I'm an error message"}]}} - num_failed_requests += 1 - return Response(json.dumps(objects)) - - weaviate_mock.expect_request("/v1/batch/references").respond_with_handler(handler) - - client = weaviate.Client(MOCK_SERVER_URL) - batch_size = 4 # Do not change, affects how many failed requests there are - n = ( - 50 * batch_size - ) # multiple of the batch size, otherwise it is difficult to calculate the number of expected errors - with client.batch( - batch_size=batch_size, - weaviate_error_retries=WeaviateErrorRetryConf(number_retries=3), - num_workers=2, - dynamic=False, - ) as batch: - for _ in range(n): - batch.add_reference( - from_property_name="Property", - from_object_class_name="SomeClass", - from_object_uuid=str(uuid.uuid4()), - to_object_class_name="otherClass", - to_object_uuid=str(uuid.uuid4()), - ) - assert num_success_requests == n - - # with a batch size of 4, we have 3 failures per batch - assert num_failed_requests == 3 * n / batch_size - - -def test_automatic_retry_unsuccessful(weaviate_mock): - """Test automatic retry that cannot add all objects.""" - num_success_requests = 0 - - # Mockserver returns error for half of all objects - def handler(request: Request): - nonlocal num_success_requests - objects = request.json["objects"] - for j, obj in enumerate(objects): - obj["deprecations"] = None - if j % 2 == 0: - obj["result"] = {} - num_success_requests += 1 - else: - obj["result"] = {"errors": {"error": [{"message": "I'm an error message"}]}} - return Response(json.dumps(objects)) - - weaviate_mock.expect_request("/v1/batch/objects").respond_with_handler(handler) - - client = weaviate.Client(MOCK_SERVER_URL) - batch_size = 20 - n = batch_size * 2 - with client.batch( - batch_size=batch_size, - weaviate_error_retries=WeaviateErrorRetryConf(number_retries=1), - num_workers=2, - callback=None, - ) as batch: - for i in range(n): - batch.add_data_object({"name": "test" + str(i)}, "test", uuid.uuid4()) - batch.flush() - # retried 3 times, starting with 200 objects and half off all objects succeed each time - assert num_success_requests == 30 - - -@pytest.mark.parametrize( - "retry_config", - [None, WeaviateErrorRetryConf(number_retries=1), WeaviateErrorRetryConf(number_retries=2)], -) -def test_print_threadsafety(weaviate_mock, capfd, retry_config): - """Test retry with callback and callback threadsafety.""" - num_success_requests = 0 - - # Mockserver returns error for half of all objects - def handler(request: Request): - nonlocal num_success_requests - objects = request.json["objects"] - for j, obj in enumerate(objects): - obj["deprecations"] = None - if j % 2 == 0: - obj["result"] = {} - num_success_requests += 1 - else: - obj["result"] = {"errors": {"error": [{"message": "I'm an error message"}]}} - return Response(json.dumps(objects)) - - weaviate_mock.expect_request("/v1/batch/objects").respond_with_handler(handler) - - client = weaviate.Client(MOCK_SERVER_URL) - - added_uuids = [] - n = 200 * 4 - with client.batch( - batch_size=200, - callback=check_batch_result, - num_workers=4, - weaviate_error_retries=retry_config, - ) as batch: - for i in range(n): - added_uuids.append(uuid.uuid4()) - batch.add_data_object({"name": "test" + str(i)}, "test", added_uuids[-1]) - - retry_factor: float = 1.0 - if retry_config is not None: - retry_factor = 1 / (2 * retry_config.number_retries) - assert num_success_requests == n - n / 2 * retry_factor - - # half of all objects fail => N/2 print statements that end with a newline - print_output, err = capfd.readouterr() - assert print_output.count("\n") == n - num_success_requests - - -@pytest.mark.parametrize( - "retry_config, expected", - [ - (WeaviateErrorRetryConf(number_retries=1, errors_to_include=["include", "maybe"]), 300), - (WeaviateErrorRetryConf(number_retries=1, errors_to_exclude=["reject", "maybe"]), 250), - ], -) -def test_include_error(weaviate_mock, retry_config, expected): - """Test that objects are included/excluded based on their error message""" - num_success_requests = 0 - - # Mockserver returns error for 3/4 of all objects, with different messages for each quarter - def handler(request: Request): - nonlocal num_success_requests - objects = request.json["objects"] - for j, obj in enumerate(objects): - obj["deprecations"] = None - if j % 4 == 0: - obj["result"] = {} - num_success_requests += 1 - elif j % 4 == 1: - obj["result"] = {"errors": {"error": [{"message": "include me"}]}} - elif j % 4 == 2: - obj["result"] = {"errors": {"error": [{"message": "maybe retry maybe not"}]}} - else: - obj["result"] = { - "errors": {"error": [{"message": "reject me"}, {"message": "other error"}]} - } - return Response(json.dumps(objects)) - - weaviate_mock.expect_request("/v1/batch/objects").respond_with_handler(handler) - - client = weaviate.Client(MOCK_SERVER_URL) - - added_uuids = [] - n = 400 * 2 - with client.batch( - batch_size=400, - callback=None, - num_workers=2, - weaviate_error_retries=retry_config, - ) as batch: - for i in range(n): - added_uuids.append(uuid.uuid4()) - batch.add_data_object({"name": "test" + str(i)}, "test", added_uuids[-1]) - - assert num_success_requests == expected - - -def test_callback_for_successful_responses(weaviate_mock, capfd): - """Test that all objects reach teh callback, even when a part of a batch is retried.""" - - # have some objects fail - def handler(request: Request): - objects = request.json["objects"] - for j, obj in enumerate(objects): - obj["deprecations"] = None - if j % 2 == 0: - obj["result"] = {} - else: - obj["result"] = {"errors": {"error": [{"message": "I'm an error message"}]}} - return Response(json.dumps(objects)) - - weaviate_mock.expect_request("/v1/batch/objects").respond_with_handler(handler) - - client = weaviate.Client(MOCK_SERVER_URL) - - def callback_print_all(results: Optional[BatchResponse]): - if results is None: - return - for _ in results: - print("I saw that object") - - added_uuids = [] - n = 200 * 4 - with client.batch( - batch_size=200, - callback=callback_print_all, - num_workers=4, - weaviate_error_retries=WeaviateErrorRetryConf(number_retries=1), - ) as batch: - for i in range(n): - added_uuids.append(uuid.uuid4()) - batch.add_data_object({"name": "test" + str(i)}, "test", added_uuids[-1]) - - # callback output for each object - print_output, err = capfd.readouterr() - assert print_output.count("\n") == n - - -def test_retries_with_tenant(weaviate_no_auth_mock): - tenant = "tenant" - first_try = True - - def handler(request: Request): - nonlocal first_try - objects = request.json["objects"] - for obj in objects: - assert obj["tenant"] == tenant - obj["deprecations"] = None - if first_try == 0: - obj["result"] = {"errors": {"error": [{"message": "I'm an error message"}]}} - first_try = False - else: - obj["result"] = {} - return Response(json.dumps(objects)) - - weaviate_no_auth_mock.expect_request("/v1/batch/objects").respond_with_handler(handler) - - client = weaviate.Client(url=MOCK_SERVER_URL) - - n = 10 - with client.batch( - weaviate_error_retries=WeaviateErrorRetryConf(number_retries=1), - ) as batch: - for i in range(n): - batch.add_data_object({"name": "test" + str(i)}, "test", uuid.uuid4(), tenant=tenant) - weaviate_no_auth_mock.check_assertions() diff --git a/mock_tests/test_batching_manual.py b/mock_tests/test_batching_manual.py deleted file mode 100644 index 010d72a98..000000000 --- a/mock_tests/test_batching_manual.py +++ /dev/null @@ -1,30 +0,0 @@ -import uuid - -import weaviate -from mock_tests.conftest import MOCK_SERVER_URL - - -def test_manual_batching_warning_object(recwarn, weaviate_mock): - weaviate_mock.expect_request("/v1/batch/objects").respond_with_json([]) - - client = weaviate.Client(MOCK_SERVER_URL) - - client.batch.configure(batch_size=None, dynamic=False) - client.batch.add_data_object({}, "ExistingClass") - client.batch.create_objects() - - assert any(str(w.message).startswith("Dep002") for w in recwarn) - - -def test_manual_batching_warning_ref(recwarn, weaviate_mock): - weaviate_mock.expect_request("/v1/batch/references").respond_with_json([]) - - client = weaviate.Client(MOCK_SERVER_URL) - client.batch.configure(batch_size=None, dynamic=False) - - client.batch.add_reference( - str(uuid.uuid4()), "NonExistingClass", "existsWith", str(uuid.uuid4()), "OtherClass" - ) - client.batch.create_references() - - assert any(str(w.message).startswith("Dep002") for w in recwarn) diff --git a/mock_tests/test_connection.py b/mock_tests/test_connection.py deleted file mode 100644 index 0e993516f..000000000 --- a/mock_tests/test_connection.py +++ /dev/null @@ -1,66 +0,0 @@ -import json -import time -from typing import Dict - -import pytest -from pytest_httpserver import HTTPServer -from werkzeug import Request, Response - -import weaviate -from mock_tests.conftest import MOCK_SERVER_URL, MOCK_IP, MOCK_PORT - - -@pytest.mark.parametrize( - "header", - [ - {}, - {"Authorization": "Bearer test"}, - {"Authorization": "Bearer test", "SomethingElse": "Value"}, - ], -) -def test_additional_headers(weaviate_mock, header: Dict[str, str]): - """Test that client sends given headers.""" - - def handler(request: Request): - assert request.headers["content-type"] == "application/json" - for key, val in header.items(): - assert request.headers[key] == val - return Response(json.dumps({})) - - weaviate_mock.expect_request("/v1/schema").respond_with_handler(handler) - - client = weaviate.Client(MOCK_SERVER_URL, additional_headers=header) - client.schema.delete_all() # some call that includes headers - - -@pytest.mark.parametrize("version,warning", [("1.13", True), ("1.14", False)]) -def test_warning_old_weaviate(recwarn, ready_mock: HTTPServer, version: str, warning: bool): - """Test that we warn if a new client version is using an old weaviate server.""" - ready_mock.expect_request("/v1/meta").respond_with_json({"version": version}) - weaviate.Client(MOCK_SERVER_URL) - - if warning: - assert any(str(w.message).startswith("Dep001") for w in recwarn) - assert any(str(w.message).startswith("Dep004") for w in recwarn) - else: - assert any(str(w.message).startswith("Dep004") for w in recwarn) - - -def test_wait_for_weaviate(httpserver: HTTPServer): - def handler(request: Request): - time.sleep(0.01) - return Response(json.dumps({})) - - def handler_meta(request: Request): - assert time.time() > start_time - 1 - return Response(json.dumps({"version": "1.16"})) - - httpserver.expect_request("/v1/meta").respond_with_handler(handler_meta) - httpserver.expect_request("/v1/.well-known/ready").respond_with_handler(handler) - start_time = time.time() - weaviate.Client(MOCK_SERVER_URL, startup_period=30) - - -def test_user_pw_in_url(weaviate_mock): - """Test that user and pw can be in the url.""" - weaviate.Client("http://user:pw@" + MOCK_IP + ":" + str(MOCK_PORT)) # no exception diff --git a/mock_tests/test_exception.py b/mock_tests/test_exception.py deleted file mode 100644 index 1b2fc6507..000000000 --- a/mock_tests/test_exception.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest - -import weaviate -from mock_tests.conftest import MOCK_SERVER_URL -from weaviate.exceptions import ResponseCannotBeDecodedException - - -def test_json_decode_exception_dict(weaviate_mock): - """Tests that JsonDecodeException is raised containing the correct error.""" - - weaviate_mock.expect_request("/v1/schema").respond_with_data("JsonCannotParseThis") - - client = weaviate.Client(MOCK_SERVER_URL) - with pytest.raises(ResponseCannotBeDecodedException) as e: - client.schema.get() - - assert "JsonCannotParseThis" in e.value - - -def test_json_decode_exception_list(weaviate_mock): - """Tests that JsonDecodeException is raised containing the correct error.""" - - weaviate_mock.expect_request("/v1/schema/Test/shards").respond_with_data("JsonCannotParseThis") - - client = weaviate.Client(MOCK_SERVER_URL) - with pytest.raises(ResponseCannotBeDecodedException) as e: - client.schema.get_class_shards("Test") - assert "JsonCannotParseThis" in e.value diff --git a/mock_tests/test_graphql.py b/mock_tests/test_graphql.py deleted file mode 100644 index 9423b8a4a..000000000 --- a/mock_tests/test_graphql.py +++ /dev/null @@ -1,21 +0,0 @@ -from http.server import HTTPServer - -import pytest as pytest - -import weaviate -from mock_tests.conftest import MOCK_SERVER_URL - - -@pytest.mark.parametrize( - "version,warning", [("1.16.0", True), ("1.17.2", True), ("1.17.3", False), ("1.18.0", False)] -) -def test_warning_old_weaviate(recwarn, ready_mock: HTTPServer, version: str, warning: bool): - ready_mock.expect_request("/v1/meta").respond_with_json({"version": version}) - client = weaviate.Client(MOCK_SERVER_URL) - - client.query.get("Class", ["Property"]).with_generate(single_prompt="something") - - if warning: - assert any(str(w.message).startswith("Dep003") for w in recwarn) - else: - assert not any(str(w.message).startswith("Dep003") for w in recwarn) diff --git a/mock_tests/test_resend.py b/mock_tests/test_resend.py deleted file mode 100644 index ff2d74a88..000000000 --- a/mock_tests/test_resend.py +++ /dev/null @@ -1,122 +0,0 @@ -import json -import re -import time - -import pytest -import uuid -from requests import ReadTimeout -from werkzeug.wrappers import Request, Response - -import weaviate -from mock_tests.conftest import MOCK_SERVER_URL - - -def test_no_retry_on_timeout(weaviate_no_auth_mock): - """Tests that expected timeout exception is raised.""" - - def handler(request: Request): - time.sleep(1.5) # cause timeout - return Response(json.dumps({})) - - weaviate_no_auth_mock.expect_request("/v1/batch/objects").respond_with_handler(handler) - - client = weaviate.Client(MOCK_SERVER_URL, timeout_config=(1, 1)) - - n = 10 - with pytest.raises(ReadTimeout): - with client.batch(batch_size=n, timeout_retries=0, dynamic=False) as batch: - for _ in range(n): - batch.add_data_object({"name": "test"}, "test", uuid.uuid4()) - - -def test_retry_on_timeout(weaviate_no_auth_mock): - """Tests that clients resends objects that haven't been added due to a timeout. - - After the timeout, the client checks if - - An object with the given UUID already exists (using HEAD). Here 50% return that they do NOT exist, eg have to be - resent. - - If an object exists, it is checked if the current version in weaviate is identical to the one that is - sent in the batch. If not, the object in the batch is an update and has to be resent again - - In total 75% are resend. - """ - added_uuids = [] - first_request = True - n = 20 # needs to be divisible by 4 - - def handler_batch_objects(request: Request): - nonlocal first_request, n - if first_request: - assert len(request.json["objects"]) == n # all objects are send the first time - time.sleep(1) # cause timeout - first_request = False - else: - # 75% of objects have to be resent - assert len(request.json["objects"]) == n / 4 * 3 - - return Response(json.dumps([])) - - weaviate_no_auth_mock.expect_request("/v1/batch/objects").respond_with_handler( - handler_batch_objects - ) - - # 50% of objects have not been added - def handler_exists(request: Request): - if added_uuids.index(request.url.split("/")[-1]) % 2 == 0: - return Response(json.dumps({}), status=404) - else: - return Response(json.dumps({}), status=200) - - weaviate_no_auth_mock.expect_request( - re.compile("^/v1/objects/Test/"), method="HEAD" - ).respond_with_handler(handler_exists) - - # 50% of objects are an update to an existing objects and have to be resent - flip = False - - def handler_get_object(request: Request): - nonlocal flip - val = "test" if flip else "other" - flip = not flip - return Response(json.dumps({"properties": {"name": val}})) - - weaviate_no_auth_mock.expect_request( - re.compile("^/v1/objects/Test/"), method="GET" - ).respond_with_handler(handler_get_object) - - client = weaviate.Client(MOCK_SERVER_URL, timeout_config=(1, 1)) - with client.batch(batch_size=n, timeout_retries=1, dynamic=False) as batch: - for _ in range(n): - added_uuids.append(str(uuid.uuid4())) - batch.add_data_object({"name": "test"}, "test", added_uuids[-1]) - weaviate_no_auth_mock.check_assertions() - - -def test_retry_on_timeout_all_succesfull(weaviate_no_auth_mock): - """Test that the client does not resend an empty batch.""" - n = 20 - - # handler only responds once => error if a batch is resent - def handler_batch_objects(request: Request): - nonlocal n - assert len(request.json["objects"]) == n # all objects are send the first time - time.sleep(1) # cause timeout - return Response(json.dumps([])) - - weaviate_no_auth_mock.expect_oneshot_request("/v1/batch/objects").respond_with_handler( - handler_batch_objects - ) - - # return that all objects are already added successful - weaviate_no_auth_mock.expect_request( - re.compile("^/v1/objects/Test/"), method="HEAD" - ).respond_with_response(Response(json.dumps({}), status=200)) - weaviate_no_auth_mock.expect_request( - re.compile("^/v1/objects/Test/"), method="GET" - ).respond_with_json({"properties": {"name": "test"}}) - - client = weaviate.Client(MOCK_SERVER_URL, timeout_config=(1, 1)) - with client.batch(batch_size=n, timeout_retries=1, dynamic=False) as batch: - for _ in range(n): - batch.add_data_object({"name": "test"}, "test", uuid.uuid4()) - weaviate_no_auth_mock.check_assertions() diff --git a/mock_tests/test_schema.py b/mock_tests/test_schema.py deleted file mode 100644 index cc52656d8..000000000 --- a/mock_tests/test_schema.py +++ /dev/null @@ -1,56 +0,0 @@ -import time - -import pytest -from requests import ReadTimeout -from werkzeug.wrappers import Request, Response - -import weaviate -from mock_tests.conftest import MOCK_SERVER_URL - - -def test_schema_timeout_error(weaviate_mock): - """Tests that expected timeout exception is raised.""" - - def handler(request: Request): - time.sleep(1.5) # cause timeout - return Response(status=200) - - weaviate_mock.expect_request("/v1/schema/Test").respond_with_handler(handler) - client = weaviate.Client(MOCK_SERVER_URL, timeout_config=(1, 1)) - - with pytest.raises(ReadTimeout): - client.schema.exists("Test") - - -def test_schema_unknown_status_code(weaviate_mock): - """Tests that expected UnexpectedStatusCodeException exception is raised.""" - - def handler(request: Request): - return Response(status=403) - - weaviate_mock.expect_request("/v1/schema/Test").respond_with_handler(handler) - client = weaviate.Client(MOCK_SERVER_URL) - - with pytest.raises(weaviate.UnexpectedStatusCodeException): - client.schema.exists("Test") - - -def test_schema_exists(weaviate_mock): - """Tests correct behaviour.""" - - def handler(request: Request, status: int): - return Response(status=status) - - weaviate_mock.expect_request("/v1/schema/Exists").respond_with_handler( - lambda r: handler(r, 200) - ) - weaviate_mock.expect_request("/v1/schema/DoesNotExists").respond_with_handler( - lambda r: handler(r, 404) - ) - client = weaviate.Client(MOCK_SERVER_URL) - - assert client.schema.exists("Exists") is True - assert client.schema.exists("DoesNotExists") is False - - assert client.schema.exists("exists") is True - assert client.schema.exists("doesNotExists") is False diff --git a/profiling/test_import_and_query.py b/profiling/test_import_and_query.py index 624e15c88..87d84d1c6 100644 --- a/profiling/test_import_and_query.py +++ b/profiling/test_import_and_query.py @@ -1,14 +1,12 @@ import time -from typing import List import uuid +from typing import List import h5py # type: ignore +from _pytest.fixtures import SubRequest import weaviate import weaviate.classes as wvc - -from _pytest.fixtures import SubRequest - from weaviate.collections.collection import Collection from .conftest import get_file_path @@ -68,25 +66,6 @@ def load_records_v4(collection: Collection, vectors: List[List[float]]) -> None: print(f"V4: Finished writing {len(vectors)} records in {time.time()-start}s") -def load_records_v3(client: weaviate.Client, vectors: List[List[float]], name: str) -> None: - start = time.time() - - client.batch.configure(batch_size=1000, num_workers=2) - - with client.batch as batch: - for i, vector in enumerate(vectors): - data_object = {"i": i} - - batch.add_data_object( - data_object=data_object, - vector=vector, - class_name=name, - uuid=uuid.UUID(int=i), - ) - - print(f"V3: Finished writing {len(vectors)} records in {time.time()-start}s") - - def query_v4( collection: Collection, vectors: List[List[float]], neighbours: List[List[int]], ef: int ) -> None: @@ -107,40 +86,6 @@ def query_v4( ) -def query_v3( - collection: Collection, - client: weaviate.Client, - vectors: List[List[float]], - neighbours: List[List[int]], - ef: int, -) -> None: - collection.config.update(vector_index_config=wvc.config.Reconfigure.VectorIndex.hnsw(ef=ef)) - start = time.time() - recall = 0.0 - - for i, vec in enumerate(vectors): - res = ( - client.query.get(collection.name, ["i _additional{id}"]) - .with_near_vector( - { - "vector": vec, - } - ) - .with_limit(LIMIT) - .do() - ) - res_ids = [ - uuid.UUID(res["_additional"]["id"]).int for res in res["data"]["Get"][collection.name] - ] - ideal_neighbors = set(neighbours[i][:LIMIT]) - - recall += len(ideal_neighbors.intersection(res_ids)) / LIMIT - - print( - f"V3: Querying {len(vectors)} records with ef {ef} in {time.time()-start}s with recall {recall/len(vectors)}" - ) - - def run_v4(file: str, name: str, efc: int, m: int) -> None: sift_file = get_file_path(file) @@ -156,32 +101,6 @@ def run_v4(file: str, name: str, efc: int, m: int) -> None: query_v4(collection, vectors_test, ideal_neighbors, ef) -def run_v3(file: str, name: str, efc: int, m: int) -> None: - sift_file = get_file_path(file) - - f = h5py.File(sift_file) - vectors_import = f["train"] - vectors_test = f["test"] - ideal_neighbors = f["neighbors"] - - client = weaviate.Client(url="http://localhost:8080") - - # use v4 client to create schema to avoid duplicate code - clientv4 = weaviate.connect_to_local() - collection = create_schema(clientv4, name, efc, m, 1, "l2-squared") - load_records_v3(client, vectors_import, name) - for ef in EF_VALUES: - query_v3(collection, client, vectors_test, ideal_neighbors, ef) - - -def test_sift_v3(request: SubRequest) -> None: - run_v3(file="sift-128-euclidean.hdf5", name=request.node.name, efc=128, m=32) - - -def test_dbpedia_v3(request: SubRequest) -> None: - run_v3(file="dbpedia-openai-1000k-angular.hdf5", name=request.node.name, efc=384, m=20) - - def test_sift_v4(request: SubRequest) -> None: run_v4(file="sift-128-euclidean.hdf5", name=request.node.name, efc=128, m=32) diff --git a/requirements-devel.txt b/requirements-devel.txt index 6c50db6f6..127cfab68 100644 --- a/requirements-devel.txt +++ b/requirements-devel.txt @@ -1,4 +1,3 @@ -requests==2.32.3 httpx==0.25.2 validators==0.34.0 authlib==1.3.1 diff --git a/setup.cfg b/setup.cfg index 2f00324fe..c021c6df9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,26 +18,17 @@ dynamic = ["version"] zip_safe = False packages = weaviate - weaviate.connect - weaviate.collections - weaviate.schema - weaviate.schema.properties - weaviate.batch weaviate.backup - weaviate.classification - weaviate.contextionary - weaviate.data - weaviate.data.references - weaviate.data.replication - weaviate.gql weaviate.cluster + weaviate.collections + weaviate.connect + weaviate.gql weaviate.proto weaviate.proto.v1 platforms = any include_package_data = True install_requires = - requests>=2.30.0,<3.0.0 httpx>=0.25.0,<=0.27.0 validators==0.34.0 authlib>=1.2.1,<1.3.2 diff --git a/test/batch/__init__.py b/test/batch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/batch/test_requests.py b/test/batch/test_requests.py deleted file mode 100644 index 446e027b8..000000000 --- a/test/batch/test_requests.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Test the 'weaviate.batch.requests' functions/classes. -""" - -import unittest -from unittest.mock import patch - -from test.util import check_error_message -from weaviate.batch.requests import ReferenceBatchRequest, ObjectsBatchRequest - - -class TestBatchReferences(unittest.TestCase): - """ - Test the `ReferenceBatchRequest` class. - """ - - @patch("weaviate.batch.requests.get_valid_uuid", side_effect=lambda x: x) - def test_add_and_get_request_body(self, mock_get_valid_uuid): - """ - Test the all the ReferenceBatchRequest's methods. - """ - - batch = ReferenceBatchRequest() - - ####################################################################### - # invalid calls - ####################################################################### - ## error messages - type_error_message_1 = "'from_object_class_name' argument must be of type str" - type_error_message_2 = "'from_property_name' argument must be of type str" - type_error_message_3 = "'to_object_class_name' argument must be of type str" - - with self.assertRaises(TypeError) as error: - batch.add(10, "some_str", "some_str", "some_str") - check_error_message(self, error, type_error_message_1) - - with self.assertRaises(TypeError) as error: - batch.add("some_str", "some_str", True, "some_str") - check_error_message(self, error, type_error_message_2) - - with self.assertRaises(TypeError) as error: - batch.add("some_str", "some_str", "some_str", "some_uuid", 1.0) - check_error_message(self, error, type_error_message_3) - - ####################################################################### - # valid calls - ####################################################################### - batch = ReferenceBatchRequest() - - ####################################################################### - # test initial values - self.assertEqual(len(batch), 0) - self.assertTrue(batch.is_empty()) - self.assertEqual(mock_get_valid_uuid.call_count, 0) - - ####################################################################### - # add first reference - batch.add("Alpha", "UUID_1", "a", "UUID_2") - self.assertEqual(len(batch), 1) - self.assertFalse(batch.is_empty()) - expected_item_1 = { - "from": "weaviate://localhost/Alpha/UUID_1/a", - "to": "weaviate://localhost/UUID_2", - } - self.assertEqual(batch.get_request_body(), [expected_item_1]) - self.assertEqual(mock_get_valid_uuid.call_count, 2) - - ####################################################################### - # add second reference - batch.add("Beta", "UUID_2", "b", "UUID_3") - self.assertEqual(len(batch), 2) - self.assertFalse(batch.is_empty()) - expected_item_2 = { - "from": "weaviate://localhost/Beta/UUID_2/b", - "to": "weaviate://localhost/UUID_3", - } - self.assertEqual(batch.get_request_body(), [expected_item_1, expected_item_2]) - self.assertEqual(mock_get_valid_uuid.call_count, 4) - - ####################################################################### - # pop first reference - self.assertEqual(batch.pop(0), expected_item_1) - self.assertEqual(len(batch), 1) - - ####################################################################### - # add one reference and pop it pop last reference - batch.add("Beta", "UUID_3", "b", "UUID_4") - expected_item_3 = { - "from": "weaviate://localhost/Beta/UUID_3/b", - "to": "weaviate://localhost/UUID_4", - } - self.assertEqual(len(batch), 2) - self.assertFalse(batch.is_empty()) - self.assertEqual(batch.pop(), expected_item_3) - self.assertEqual(len(batch), 1) - self.assertFalse(batch.is_empty()) - - ####################################################################### - # add 2 more references and then empty the batch - batch.add("Beta", "UUID_4", "b", "UUID_5") - batch.add("Beta", "UUID_5", "b", "UUID_4") - self.assertEqual(len(batch), 3) - self.assertFalse(batch.is_empty()) - batch.empty() - self.assertEqual(len(batch), 0) - self.assertTrue(batch.is_empty()) - - -class TestBatchObjects(unittest.TestCase): - """ - Test the `ObjectsBatchRequest` class. - """ - - @patch("weaviate.batch.requests.uuid4", side_effect=lambda: "d087b7c6a1155c898cb2f25bdeb9bf92") - @patch("weaviate.batch.requests.get_vector", side_effect=lambda x: x) - @patch("weaviate.batch.requests.get_valid_uuid", side_effect=lambda x: x) - def test_add_and_get_request_body(self, mock_get_valid_uuid, mock_get_vector, mock_uuid4): - """ - Test the all the ObjectsBatchRequest's methods. - """ - - batch = ObjectsBatchRequest() - ####################################################################### - # invalid calls - ####################################################################### - ## error messages - data_type_error_message = "Object must be of type dict" - class_type_error_message = "Class name must be of type str" - - ####################################################################### - # wrong data_object - with self.assertRaises(TypeError) as error: - batch.add( - data_object=None, - class_name="Class", - ) - check_error_message(self, error, data_type_error_message) - - with self.assertRaises(TypeError) as error: - batch.add( - data_object=224345, - class_name="Class", - ) - check_error_message(self, error, data_type_error_message) - - ####################################################################### - # wrong class_name - with self.assertRaises(TypeError) as error: - batch.add( - data_object={"name": "Optimus Prime"}, - class_name=None, - ) - check_error_message(self, error, class_type_error_message) - - with self.assertRaises(TypeError) as error: - batch.add( - data_object={"name": "Optimus Prime"}, - class_name=["Transformer"], - ) - check_error_message(self, error, class_type_error_message) - - ####################################################################### - # valid calls - ####################################################################### - ## test initial values - self.assertEqual(len(batch), 0) - self.assertTrue(batch.is_empty()) - self.assertEqual(mock_get_valid_uuid.call_count, 0) - self.assertEqual(mock_get_vector.call_count, 0) - expected_return = {"fields": ["ALL"], "objects": []} - self.assertEqual(batch.get_request_body(), expected_return) - - ####################################################################### - # add an object without 'uuid' and 'vector' - obj = {"class": "Philosopher", "properties": {"name": "Socrates"}} - expected_return["objects"].append( - { - "class": "Philosopher", - "properties": {"name": "Socrates"}, - "id": "d087b7c6a1155c898cb2f25bdeb9bf92", - } - ) - res_uuid = batch.add( - data_object=obj["properties"], - class_name=obj["class"], - ) - self.assertEqual(len(batch), 1) - self.assertFalse(batch.is_empty()) - self.assertEqual(mock_get_valid_uuid.call_count, 1) - self.assertEqual(mock_get_vector.call_count, 0) - self.assertEqual(batch.get_request_body(), expected_return) - self.assertEqual(res_uuid, "d087b7c6a1155c898cb2f25bdeb9bf92") - ## change obj and check if batch does not reflect this change - obj["properties"]["name"] = "Test" - self.assertEqual(batch.get_request_body(), expected_return) - - ####################################################################### - # add an object without 'vector' - obj = { - "class": "Chemist", - "properties": {"name": "Marie Curie"}, - "id": "d087b7c6-a115-5c89-8cb2-f25bdeb9bf93", - } - expected_return["objects"].append( - { - "class": "Chemist", - "properties": {"name": "Marie Curie"}, - "id": "d087b7c6-a115-5c89-8cb2-f25bdeb9bf93", - } - ) - res_uuid = batch.add( - data_object=obj["properties"], - class_name=obj["class"], - uuid=obj["id"], - ) - self.assertEqual(len(batch), 2) - self.assertFalse(batch.is_empty()) - self.assertEqual(mock_get_valid_uuid.call_count, 2) - self.assertEqual(mock_get_vector.call_count, 0) - self.assertEqual(batch.get_request_body(), expected_return) - self.assertEqual(res_uuid, "d087b7c6-a115-5c89-8cb2-f25bdeb9bf93") - ## change obj and check if batch does not reflect this change - obj["properties"]["name"] = "Test" - self.assertEqual(batch.get_request_body(), expected_return) - - ####################################################################### - # add an object without 'uuid' - obj = {"class": "Writer", "properties": {"name": "Stephen King"}, "vector": [1, 2, 3]} - expected_return["objects"].append( - { - "class": "Writer", - "properties": {"name": "Stephen King"}, - "vector": [1, 2, 3], - "id": "d087b7c6a1155c898cb2f25bdeb9bf92", - } - ) - res_uuid = batch.add( - data_object=obj["properties"], - class_name=obj["class"], - vector=obj["vector"], - ) - self.assertEqual(len(batch), 3) - self.assertFalse(batch.is_empty()) - self.assertEqual(mock_get_valid_uuid.call_count, 3) - self.assertEqual(mock_get_vector.call_count, 1) - self.assertEqual(batch.get_request_body(), expected_return) - self.assertEqual(res_uuid, "d087b7c6a1155c898cb2f25bdeb9bf92") - ## change obj and check if batch does not reflect this change - obj["properties"]["name"] = "Test" - self.assertEqual(batch.get_request_body(), expected_return) - - ####################################################################### - # add an object with all arguments - obj = { - "class": "Inventor", - "properties": {"name": "Nikola Tesla"}, - "id": "d087b7c6-a115-5c89-8cb2-f25bdeb9bf95", - "vector": [1, 2, 3], - } - expected_return["objects"].append( - { - "class": "Inventor", - "properties": {"name": "Nikola Tesla"}, - "id": "d087b7c6-a115-5c89-8cb2-f25bdeb9bf95", - "vector": [1, 2, 3], - } - ) - res_uuid = batch.add( - data_object=obj["properties"], - class_name=obj["class"], - uuid=obj["id"], - vector=obj["vector"], - ) - self.assertEqual(len(batch), 4) - self.assertFalse(batch.is_empty()) - self.assertEqual(mock_get_valid_uuid.call_count, 4) - self.assertEqual(mock_get_vector.call_count, 2) - self.assertEqual(batch.get_request_body(), expected_return) - self.assertEqual(res_uuid, "d087b7c6-a115-5c89-8cb2-f25bdeb9bf95") - ## change obj and check if batch does not reflect this change - obj["properties"]["name"] = "Test" - self.assertEqual(batch.get_request_body(), expected_return) - - ####################################################################### - # pop one object with index=1 - - self.assertEqual(batch.pop(0), expected_return["objects"][0]) - self.assertEqual(len(batch), 3) - self.assertFalse(batch.is_empty()) - expected_return["objects"] = expected_return["objects"][1:] - - ####################################################################### - # pop last object - - self.assertEqual(batch.pop(), expected_return["objects"][-1]) - self.assertEqual(len(batch), 2) - self.assertFalse(batch.is_empty()) - expected_return["objects"] = expected_return["objects"][:-1] - - ####################################################################### - # empty the batch request - - self.assertFalse(batch.is_empty()) - batch.empty() - self.assertEqual(len(batch), 0) - self.assertTrue(batch.is_empty()) diff --git a/test/classification/__init__.py b/test/classification/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/classification/test_classification.py b/test/classification/test_classification.py deleted file mode 100644 index 32ce40714..000000000 --- a/test/classification/test_classification.py +++ /dev/null @@ -1,531 +0,0 @@ -import unittest -from unittest.mock import patch, Mock - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.classification.classification import Classification, ConfigBuilder -from weaviate.exceptions import UnexpectedStatusCodeException - - -class TestClassification(unittest.TestCase): - def test_schedule(self): - """ - Test the `schedule` method. - """ - - self.assertIsInstance(Classification(None).schedule(), ConfigBuilder) - - def test_get(self): - """ - Test the `get` method. - """ - - # error messages - uuid_type_error = lambda dt: f"'uuid' must be of type str or uuid.UUID, but was: {dt}" - value_error = "Not valid 'uuid' or 'uuid' can not be extracted from value" - requests_error_message = "Classification status could not be retrieved." - unexpected_error_message = "Get classification status" - - # invalid calls - with self.assertRaises(TypeError) as error: - Classification(None).get(123) - check_error_message(self, error, uuid_type_error(int)) - - with self.assertRaises(ValueError) as error: - Classification(None).get("123") - check_error_message(self, error, value_error) - - mock_conn = mock_connection_func("get", side_effect=RequestsConnectionError("Test!")) - with self.assertRaises(RequestsConnectionError) as error: - Classification(mock_conn).get("d087b7c6-a115-5c89-8cb2-f25bdeb9bf92") - check_error_message(self, error, requests_error_message) - - mock_conn = mock_connection_func("get", status_code=404) - with self.assertRaises(UnexpectedStatusCodeException) as error: - Classification(mock_conn).get("d087b7c6-a115-5c89-8cb2-f25bdeb9bf92") - check_startswith_error_message(self, error, unexpected_error_message) - - # valid calls - mock_conn = mock_connection_func("get", return_json={"OK": "GOOD"}, status_code=200) - result = Classification(mock_conn).get("d087b7c6-a115-5c89-8cb2-f25bdeb9bf92") - self.assertEqual(result, {"OK": "GOOD"}) - - @patch("weaviate.classification.classification.Classification._check_status") - def test_is_complete(self, mock_check_status): - """ - Test the `is_complete` method. - """ - - mock_check_status.return_value = "OK!" - result = Classification(None).is_complete("Test!") - self.assertEqual(result, "OK!") - mock_check_status.assert_called_with("Test!", "completed") - - @patch("weaviate.classification.classification.Classification._check_status") - def test_is_failed(self, mock_check_status): - """ - Test the `is_failed` method. - """ - - mock_check_status.return_value = "OK!" - result = Classification(None).is_failed("Test!") - self.assertEqual(result, "OK!") - mock_check_status.assert_called_with("Test!", "failed") - - @patch("weaviate.classification.classification.Classification._check_status") - def test_is_running(self, mock_check_status): - """ - Test the `is_running` method. - """ - - mock_check_status.return_value = "OK!" - result = Classification(None).is_running("Test!") - self.assertEqual(result, "OK!") - mock_check_status.assert_called_with("Test!", "running") - - @patch("weaviate.classification.classification.Classification.get") - def test__check_status(self, mock_get): - """ - Test the `_check_status` method. - """ - - mock_get.return_value = {"status": "failed"} - - result = Classification(None)._check_status("uuid", "running") - self.assertFalse(result) - - result = Classification(None)._check_status("uuid", "failed") - self.assertTrue(result) - - mock_get.side_effect = RequestsConnectionError("Test!") - result = Classification(None)._check_status("uuid", "running") - self.assertFalse(result) - - -class TestConfigBuilder(unittest.TestCase): - def test_with_type(self): - """ - Test the `with_type` method. - """ - - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_type("test_type") - - self.assertEqual(config._config, {"type": "test_type"}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_k(self): - """ - Test the `with_k` method. - """ - - # without `with_settings` called - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_k(4) - - self.assertEqual(config._config, {"settings": {"k": 4}}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - # with `with_settings` called - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_k(5).with_settings({"test": "OK!"}) - - self.assertEqual(config._config, {"settings": {"k": 5, "test": "OK!"}}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_class_name(self): - """ - Test the `with_class_name` method. - """ - - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - # Correct class name format (capitalized) - result = config.with_class_name("TestClass") - - self.assertEqual(config._config, {"class": "TestClass"}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - # Incorrect class name format (capitalized), should be capitalized by the client - result = config.with_class_name("testClass") - - self.assertEqual(config._config, {"class": "TestClass"}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_classify_properties(self): - """ - Test the `with_classify_properties` method. - """ - - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_classify_properties(["test1", "test2"]) - - self.assertEqual(config._config, {"classifyProperties": ["test1", "test2"]}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_based_on_properties(self): - """ - Test the `with_based_on_properties` method. - """ - - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_based_on_properties(["test1", "test2"]) - - self.assertEqual(config._config, {"basedOnProperties": ["test1", "test2"]}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_source_where_filter(self): - """ - Test the `with_source_where_filter` method. - """ - - # without other filters set before - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_source_where_filter({"test": "OK!"}) - - self.assertEqual(config._config, {"filters": {"sourceWhere": {"test": "OK!"}}}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - # with other filters set before - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_training_set_where_filter({"test": "OK!"}).with_source_where_filter( - {"test": "OK!"} - ) - - self.assertEqual( - config._config, - {"filters": {"sourceWhere": {"test": "OK!"}, "trainingSetWhere": {"test": "OK!"}}}, - ) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_training_set_where_filter(self): - """ - Test the `with_training_set_where_filter` method. - """ - - # without other filters set before - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_training_set_where_filter({"test": "OK!"}) - - self.assertEqual(config._config, {"filters": {"trainingSetWhere": {"test": "OK!"}}}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - # with other filters set before - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_target_where_filter({"test": "OK!"}).with_training_set_where_filter( - {"test": "OK!"} - ) - - self.assertEqual( - config._config, - {"filters": {"trainingSetWhere": {"test": "OK!"}, "targetWhere": {"test": "OK!"}}}, - ) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_target_where_filter(self): - """ - Test the `with_target_where_filter` method. - """ - - # without other filters set before - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_target_where_filter({"test": "OK!"}) - - self.assertEqual(config._config, {"filters": {"targetWhere": {"test": "OK!"}}}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - # with other filters set before - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_source_where_filter({"test": "OK!"}).with_target_where_filter( - {"test": "OK!"} - ) - - self.assertEqual( - config._config, - {"filters": {"targetWhere": {"test": "OK!"}, "sourceWhere": {"test": "OK!"}}}, - ) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_wait_for_completion(self): - """ - Test the `with_wait_for_completion` method. - """ - - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_wait_for_completion() - - self.assertEqual(config._config, {}) - self.assertTrue(config._wait_for_completion) - self.assertIs(result, config) - - def test_with_settings(self): - """ - Test the `with_settings` method. - """ - - # without `with_k` called - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_settings({"test": "OK!"}) - - self.assertEqual(config._config, {"settings": {"test": "OK!"}}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - # with `with_k` called - config = ConfigBuilder(None, None) - self.assertEqual(config._config, {}) - self.assertFalse(config._wait_for_completion) - - result = config.with_settings({"test": "OK!"}).with_k(7) - - self.assertEqual(config._config, {"settings": {"k": 7, "test": "OK!"}}) - self.assertFalse(config._wait_for_completion) - self.assertIs(result, config) - - def test__validate_config(self): - """ - Test the `_validate_config` method. - """ - - # error messages - field_error_message = lambda f: f"{f} is not set for this classification" - settings_error_message = '"settings" should be of type dict' - k_error_message = "k is not set for this classification" - - # test required fields without "classifyProperties" - config = ( - ConfigBuilder(None, None) - .with_type("Test!") - .with_class_name("Test!") - .with_based_on_properties(["Test!"]) - ) - with self.assertRaises(ValueError) as error: - config._validate_config() - check_error_message(self, error, field_error_message("classifyProperties")) - - # test required fields without "basedOnProperties" - config = ( - ConfigBuilder(None, None) - .with_type("Test!") - .with_class_name("Test!") - .with_classify_properties(["Test!"]) - ) - with self.assertRaises(ValueError) as error: - config._validate_config() - check_error_message(self, error, field_error_message("basedOnProperties")) - - # test required fields without "class" - config = ( - ConfigBuilder(None, None) - .with_type("Test!") - .with_based_on_properties(["Test!"]) - .with_classify_properties(["Test!"]) - ) - with self.assertRaises(ValueError) as error: - config._validate_config() - check_error_message(self, error, field_error_message("class")) - - # test required fields without "type" - config = ( - ConfigBuilder(None, None) - .with_based_on_properties(["Test!"]) - .with_class_name("Test!") - .with_classify_properties(["Test!"]) - ) - with self.assertRaises(ValueError) as error: - config._validate_config() - check_error_message(self, error, field_error_message("type")) - - # test required fields with all required - config = ( - ConfigBuilder(None, None) - .with_based_on_properties(["Test!"]) - .with_type("Test!") - .with_class_name("Test!") - .with_classify_properties(["Test!"]) - ) - config._validate_config() - - # test settings - config = ( - ConfigBuilder(None, None) - .with_based_on_properties(["Test!"]) - .with_type("Test!") - .with_class_name("Test!") - .with_classify_properties(["Test!"]) - .with_settings(["Test!"]) - ) - with self.assertRaises(TypeError) as error: - config._validate_config() - check_error_message(self, error, settings_error_message) - - # test knn without k - config = ( - ConfigBuilder(None, None) - .with_based_on_properties(["Test!"]) - .with_type("knn") - .with_class_name("Test!") - .with_classify_properties(["Test!"]) - ) - with self.assertRaises(ValueError) as error: - config._validate_config() - check_error_message(self, error, k_error_message) - - # test knn with k - config = ( - ConfigBuilder(None, None) - .with_based_on_properties(["Test!"]) - .with_type("knn") - .with_class_name("Test!") - .with_classify_properties(["Test!"]) - .with_k(4) - ) - config._validate_config() - - def test__start(self): - """ - Test the `_start` method. - """ - - # error messages - requests_error_message = "Classification may not started." - unexpected_error_message = "Start classification" - - # invalid calls - mock_conn = mock_connection_func("post", side_effect=RequestsConnectionError("Test!")) - config = ConfigBuilder(mock_conn, None) - with self.assertRaises(RequestsConnectionError) as error: - config._start() - check_error_message(self, error, requests_error_message) - mock_conn.post.assert_called_with(path="/classifications", weaviate_object={}) - - mock_conn = mock_connection_func("post", status_code=200) - config = ConfigBuilder(mock_conn, None).with_class_name("Test!") - with self.assertRaises(UnexpectedStatusCodeException) as error: - config._start() - check_startswith_error_message(self, error, unexpected_error_message) - mock_conn.post.assert_called_with( - path="/classifications", weaviate_object={"class": "Test!"} - ) - - # valid calls - mock_conn = mock_connection_func("post", status_code=201, return_json="OK!") - config = ConfigBuilder(mock_conn, None).with_class_name("TestClass").with_type("TestType") - self.assertEqual(config._start(), "OK!") - mock_conn.post.assert_called_with( - path="/classifications", weaviate_object={"class": "TestClass", "type": "TestType"} - ) - - @patch("weaviate.classification.config_builder.ConfigBuilder._start") - @patch( - "weaviate.classification.config_builder.ConfigBuilder._validate_config", return_value=None - ) - def test_do(self, mock_validate_config, mock_start): - """ - Test the `do` method. - """ - - mock_start.return_value = {"status": "test"} - config = ConfigBuilder(None, None) - self.assertEqual(config.do(), {"status": "test"}) - - mock_start.return_value = {"status": "test", "id": "test_id"} - mock_classification = Mock() # mock self._classification instance - - def mock_waiting(test): - if mock_waiting.called: - return False - mock_waiting.called = True - return True - - mock_waiting.called = False # initialize static variable - mock_classification.is_running.side_effect = mock_waiting - mock_classification.get.return_value = "test" - config = ConfigBuilder(None, mock_classification).with_wait_for_completion() - self.assertEqual(config.do(), "test") - - def test_integration_config(self): - """ - Test all `with_` methods together that change the configuration. - """ - - config = ( - ConfigBuilder(None, None) - .with_type("test_type") - .with_k(4) - .with_class_name("TestClass") - .with_classify_properties(["Test1!"]) - .with_based_on_properties(["Test2!"]) - .with_source_where_filter({"test": "OK1!"}) - .with_training_set_where_filter({"test": "OK2!"}) - .with_target_where_filter({"test": "OK3!"}) - .with_settings({"additional": "test_settings"}) - ) - expected_config = { - "type": "test_type", - "settings": {"k": 4, "additional": "test_settings"}, - "class": "TestClass", - "classifyProperties": ["Test1!"], - "basedOnProperties": ["Test2!"], - "filters": { - "sourceWhere": {"test": "OK1!"}, - "trainingSetWhere": {"test": "OK2!"}, - "targetWhere": {"test": "OK3!"}, - }, - } - self.assertEqual(config._config, expected_config) diff --git a/test/cluster/__init__.py b/test/cluster/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/cluster/test_cluster.py b/test/cluster/test_cluster.py deleted file mode 100644 index cda56e9ad..000000000 --- a/test/cluster/test_cluster.py +++ /dev/null @@ -1,53 +0,0 @@ -import unittest - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.cluster.cluster import Cluster -from weaviate.exceptions import ( - UnexpectedStatusCodeException, - EmptyResponseException, -) - - -class TestCluster(unittest.TestCase): - def test_get_nodes_status(self): - # error messages - - unexpected_err_msg = "Nodes status" - empty_response_err_msg = "Nodes status response returned empty" - connection_err_msg = "Get nodes status failed due to connection error" - - # expected failure - mock_conn = mock_connection_func("get", status_code=500) - with self.assertRaises(UnexpectedStatusCodeException) as error: - Cluster(mock_conn).get_nodes_status() - check_startswith_error_message(self, error, unexpected_err_msg) - - mock_conn = mock_connection_func("get", status_code=200, return_json={"nodes": []}) - with self.assertRaises(EmptyResponseException) as error: - Cluster(mock_conn).get_nodes_status() - check_error_message(self, error, empty_response_err_msg) - - mock_conn = mock_connection_func("get", side_effect=RequestsConnectionError) - with self.assertRaises(RequestsConnectionError) as error: - Cluster(mock_conn).get_nodes_status() - check_error_message(self, error, connection_err_msg) - - # expected success - expected_resp = { - "nodes": [ - { - "gitHash": "abcd123", - "name": "node1", - "shards": [{"class": "SomeClass", "name": "1qa2ws3ed", "objectCount": 100}], - "stats": {"objectCount": 100, "shardCount": 1}, - "status": "", - "version": "x.x.x", - } - ] - } - mock_conn = mock_connection_func("get", status_code=200, return_json=expected_resp) - result = Cluster(mock_conn).get_nodes_status() - self.assertListEqual(result, expected_resp.get("nodes")) - mock_conn.get.assert_called_with(path="/nodes") diff --git a/test/connection/__init__.py b/test/connection/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/connection/test_connection.py b/test/connection/test_connection.py deleted file mode 100644 index d44b390ca..000000000 --- a/test/connection/test_connection.py +++ /dev/null @@ -1,167 +0,0 @@ -import unittest -from unittest.mock import patch - -from test.util import check_error_message -from weaviate.connect.base import _get_proxies -from weaviate.connect.v3 import Connection -from weaviate.util import _get_valid_timeout_config - - -class TestConnection(unittest.TestCase): - def check_connection_attributes( - self, - connection: Connection, - url="test_url", - timeout_config=(2, 20), - oidc_auth_flow=False, - headers=None, - ): - """ - Check the attributes of the connection value. Assign 'skip' to - an attribute to skip testing. The attributes have the default constructor values. - """ - - if headers is None: - headers = {"content-type": "application/json"} - if url != "skip": - self.assertEqual(connection.url, url) - if timeout_config != "skip": - self.assertEqual(connection.timeout_config, timeout_config) - if oidc_auth_flow != "skip": - if oidc_auth_flow is True: - self.assertIsNotNone(connection._auth) - else: - self.assertIsNone(connection._auth) - if headers != "skip": - self.assertEqual(connection._headers, headers) - - @patch("weaviate.connect.base.datetime") - def test_get_epoch_time(self, mock_datetime): - """ - Test the `get_epoch_time` function. - """ - - import datetime - from weaviate.connect.base import _get_epoch_time - - zero_epoch = datetime.datetime.fromtimestamp(0) - mock_datetime.datetime.utcnow.return_value = zero_epoch - self.assertEqual(_get_epoch_time(), 0) - - epoch = datetime.datetime.fromtimestamp(110.56) - mock_datetime.datetime.utcnow.return_value = epoch - self.assertEqual(_get_epoch_time(), 111) - - epoch = datetime.datetime.fromtimestamp(110.46) - mock_datetime.datetime.utcnow.return_value = epoch - self.assertEqual(_get_epoch_time(), 110) - - @patch("weaviate.connect.base.os") - def test_get_proxies(self, os_mock): - """ - Test the `_get_proxies` function. - """ - - error_msg = lambda dt: ( - "If 'proxies' is not None, it must be of type dict, str, or wvc.init.Proxies. " - f"Given type: {dt}." - ) - with self.assertRaises(TypeError) as error: - proxies = _get_proxies([], False) - check_error_message(self, error, error_msg(list)) - - proxies = _get_proxies({}, False) - self.assertEqual(proxies, {}) - - proxies = _get_proxies({"test": True}, False) - self.assertEqual(proxies, {"test": True}) - - proxies = _get_proxies({"test": True}, True) - self.assertEqual(proxies, {"test": True}) - - proxies = _get_proxies("test", True) - self.assertEqual(proxies, {"http": "test", "https": "test", "grpc": "test"}) - - os_mock.environ.get.return_value = None - proxies = _get_proxies(None, True) - self.assertEqual(proxies, {}) - - os_mock.environ.get.return_value = "test" - proxies = _get_proxies(None, True) - self.assertEqual(proxies, {"http": "test", "https": "test", "grpc": "test"}) - - def test__get_valid_timeout_config(self): - """ - Test the `_get_valid_timeout_config` function. - """ - - # incalid calls - negative_num_error_message = "'timeout_config' cannot be non-positive number/s!" - type_error_message = "'timeout_config' should be a (or tuple of) positive number/s!" - value_error_message = "'timeout_config' must be of length 2!" - value_types_error_message = "'timeout_config' must be tuple of numbers" - - ## wrong type - with self.assertRaises(TypeError) as error: - _get_valid_timeout_config(None) - check_error_message(self, error, type_error_message) - - with self.assertRaises(TypeError) as error: - _get_valid_timeout_config(True) - check_error_message(self, error, type_error_message) - - with self.assertRaises(TypeError) as error: - _get_valid_timeout_config("(2, 13)") - check_error_message(self, error, type_error_message) - - ## wrong tuple length 3 - with self.assertRaises(ValueError) as error: - _get_valid_timeout_config((1, 2, 3)) - check_error_message(self, error, value_error_message) - - ## wrong value types - with self.assertRaises(TypeError) as error: - _get_valid_timeout_config((None, None)) - check_error_message(self, error, value_types_error_message) - - with self.assertRaises(TypeError) as error: - _get_valid_timeout_config(("1", 10)) - check_error_message(self, error, value_types_error_message) - - with self.assertRaises(TypeError) as error: - _get_valid_timeout_config(("1", "10")) - check_error_message(self, error, value_types_error_message) - - with self.assertRaises(TypeError) as error: - _get_valid_timeout_config((True, False)) - check_error_message(self, error, value_types_error_message) - - ## non-positive numbers - with self.assertRaises(ValueError) as error: - _get_valid_timeout_config(0) - check_error_message(self, error, negative_num_error_message) - - with self.assertRaises(ValueError) as error: - _get_valid_timeout_config(-1) - check_error_message(self, error, negative_num_error_message) - - with self.assertRaises(ValueError) as error: - _get_valid_timeout_config(-4.134) - check_error_message(self, error, negative_num_error_message) - - with self.assertRaises(ValueError) as error: - _get_valid_timeout_config((-3.5, 1.5)) - check_error_message(self, error, negative_num_error_message) - - with self.assertRaises(ValueError) as error: - _get_valid_timeout_config((3, -1.5)) - check_error_message(self, error, negative_num_error_message) - - with self.assertRaises(ValueError) as error: - _get_valid_timeout_config((0, 0)) - check_error_message(self, error, negative_num_error_message) - - # valid calls - self.assertEqual(_get_valid_timeout_config((2, 20)), (2, 20)) - self.assertEqual(_get_valid_timeout_config((3.5, 2.34)), (3.5, 2.34)) - self.assertEqual(_get_valid_timeout_config(4.32), (4.32, 4.32)) diff --git a/test/contextionary/__init__.py b/test/contextionary/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/contextionary/test_text2vec_contextionary.py b/test/contextionary/test_text2vec_contextionary.py deleted file mode 100644 index 87e30764d..000000000 --- a/test/contextionary/test_text2vec_contextionary.py +++ /dev/null @@ -1,119 +0,0 @@ -import unittest -from unittest.mock import Mock - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.contextionary import Contextionary -from weaviate.exceptions import UnexpectedStatusCodeException - - -class TestText2VecContextionary(unittest.TestCase): - def test_extend(self): - """ - Test `extend` method. - """ - - contextionary = Contextionary(Mock()) - - some_concept = { - "concept": "lsd", - "definition": "In probability and statistics, the logarithmic series distribution is a discrete probability distribution derived from the Maclaurin series expansion", - } - - # error messages - concept_type_error_message = "Concept must be string" - definition_type_error_message = "Definition must be string" - weight_type_error_message = "Weight must be float" - weight_value_error_message = "Weight out of limits 0.0 <= weight <= 1.0" - requests_error_message = "text2vec-contextionary could not be extended." - unexpected_error_message = "Extend text2vec-contextionary" - - ## test exceptions - with self.assertRaises(TypeError) as error: - contextionary.extend(concept=None, definition=some_concept["definition"], weight=1.0) - check_error_message(self, error, concept_type_error_message) - - with self.assertRaises(TypeError) as error: - contextionary.extend(concept=some_concept["concept"], definition=None, weight=1.0) - check_error_message(self, error, definition_type_error_message) - - with self.assertRaises(TypeError) as error: - contextionary.extend(**some_concept, weight=None) - check_error_message(self, error, weight_type_error_message) - - with self.assertRaises(ValueError) as error: - contextionary.extend(**some_concept, weight=1.1) - check_error_message(self, error, weight_value_error_message) - - with self.assertRaises(ValueError) as error: - contextionary.extend(**some_concept, weight=-1.0) - check_error_message(self, error, weight_value_error_message) - - ## test UnexpectedStatusCodeException - contextionary = Contextionary(mock_connection_func("post", status_code=404)) - with self.assertRaises(UnexpectedStatusCodeException) as error: - contextionary.extend(**some_concept) - check_startswith_error_message(self, error, unexpected_error_message) - - ## test requests error - contextionary = Contextionary( - mock_connection_func("post", side_effect=RequestsConnectionError("Test!")) - ) - with self.assertRaises(RequestsConnectionError) as error: - contextionary.extend(**some_concept) - check_error_message(self, error, requests_error_message) - - ## test valid call without specifying 'weight' - some_concept["weight"] = 1.0 - connection_mock = mock_connection_func("post", status_code=200) - contextionary = Contextionary(connection_mock) - contextionary.extend(**some_concept) - connection_mock.post.assert_called_with( - path="/modules/text2vec-contextionary/extensions", - weaviate_object=some_concept, - ) - - ## test valid call with specifying 'weight as error' - connection_mock = mock_connection_func("post", status_code=200) - contextionary = Contextionary(connection_mock) - # add weight to 'some_concept' - some_concept["weight"] = 0.1234 - contextionary.extend(**some_concept) - connection_mock.post.assert_called_with( - path="/modules/text2vec-contextionary/extensions", - weaviate_object=some_concept, - ) - - def test_get_concept_vector(self): - """ - Test `get_concept_vector` method. - """ - - # test valid call - connection_mock = mock_connection_func("get", return_json={"A": "B"}) - contextionary = Contextionary(connection_mock) - self.assertEqual("B", contextionary.get_concept_vector("sauce")["A"]) - connection_mock.get.assert_called_with( - path="/modules/text2vec-contextionary/concepts/sauce", - ) - - # test exceptions - - # error messages - requests_error_message = "text2vec-contextionary vector was not retrieved." - unexpected_exception_error_message = "text2vec-contextionary vector" - - ## test UnexpectedStatusCodeException - contextionary = Contextionary(mock_connection_func("get", status_code=404)) - with self.assertRaises(UnexpectedStatusCodeException) as error: - contextionary.get_concept_vector("Palantir") - check_startswith_error_message(self, error, unexpected_exception_error_message) - - ## test requests error - contextionary = Contextionary( - mock_connection_func("get", side_effect=RequestsConnectionError("Test!")) - ) - with self.assertRaises(RequestsConnectionError) as error: - contextionary.get_concept_vector("Palantir") - check_error_message(self, error, requests_error_message) diff --git a/test/data/__init__.py b/test/data/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/data/references/__init__.py b/test/data/references/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/data/references/test_crud_references.py b/test/data/references/test_crud_references.py deleted file mode 100644 index 2a1a8ca2a..000000000 --- a/test/data/references/test_crud_references.py +++ /dev/null @@ -1,311 +0,0 @@ -import unittest -from unittest.mock import Mock - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.data.references import Reference -from weaviate.data.replication import ConsistencyLevel -from weaviate.exceptions import UnexpectedStatusCodeException - - -class TestReference(unittest.TestCase): - def setUp(self): - self.uuid_1 = "b36268d4-a6b5-5274-985f-45f13ce0c642" - self.uuid_2 = "a36268d4-a6b5-5274-985f-45f13ce0c642" - self.uuid_error_message = f"'uuid' must be of type str or uuid.UUID, but was: {int}" - self.valid_uuid_error_message = "Not valid 'uuid' or 'uuid' can not be extracted from value" - self.name_error_message = ( - lambda p: f"'from_property_name' must be of type 'str'. Given type: {p}" - ) - - def test_delete(self): - """ - Test `delete` method`. - """ - - connection_mock = Mock() - connection_mock.server_version = "1.13.2" - reference = Reference(connection_mock) - - # error messages - unexpected_error_msg = "Delete property reference to object" - connection_error_msg = "Reference was not deleted." - - # invalid calls - with self.assertRaises(TypeError) as error: - reference.delete(1, "myProperty", self.uuid_2) - check_error_message(self, error, self.uuid_error_message) - - with self.assertRaises(TypeError) as error: - reference.delete(self.uuid_1, "myProperty", 2) - check_error_message(self, error, self.uuid_error_message) - - with self.assertRaises(TypeError) as error: - reference.delete(self.uuid_1, 3, self.uuid_2) - check_error_message(self, error, self.name_error_message(int)) - - with self.assertRaises(ValueError) as error: - reference.delete("str", "myProperty", self.uuid_2) - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.delete(self.uuid_1, "myProperty", "str") - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.delete(self.uuid_1, "myProperty", self.uuid_2, consistency_level=1) - - mock_obj = mock_connection_func("delete", status_code=200) - reference = Reference(mock_obj) - with self.assertRaises(UnexpectedStatusCodeException) as error: - reference.delete(self.uuid_1, "myProperty", self.uuid_2) - check_startswith_error_message(self, error, unexpected_error_msg) - - mock_obj = mock_connection_func("delete", side_effect=RequestsConnectionError("Test!")) - reference = Reference(mock_obj) - with self.assertRaises(RequestsConnectionError) as error: - reference.delete(self.uuid_1, "myProperty", self.uuid_2) - check_error_message(self, error, connection_error_msg) - - # test valid calls - connection_mock = mock_connection_func("delete", status_code=204) - reference = Reference(connection_mock) - - reference.delete(self.uuid_1, "myProperty", self.uuid_2) - - connection_mock.delete.assert_called_with( - path=f"/objects/{self.uuid_1}/references/myProperty", - weaviate_object={"beacon": f"weaviate://localhost/{self.uuid_2}"}, - params={}, - ) - - reference.delete( - self.uuid_1, - "hasItem", - f"http://localhost:8080/v1/objects/{self.uuid_2}", - consistency_level="ONE", - ) - - connection_mock.delete.assert_called_with( - path=f"/objects/{self.uuid_1}/references/hasItem", - weaviate_object={"beacon": f"weaviate://localhost/{self.uuid_2}"}, - params={"consistency_level": "ONE"}, - ) - - def test_add(self): - """ - Test the `add` method. - """ - - connection_mock = Mock() - connection_mock.server_version = "1.13.2" - reference = Reference(connection_mock) - - # error messages - unexpected_error_msg = "Add property reference to object" - connection_error_msg = "Reference was not added." - - # test exceptions - with self.assertRaises(TypeError) as error: - reference.add(1, "prop", self.uuid_1) - check_error_message(self, error, self.uuid_error_message) - - with self.assertRaises(TypeError) as error: - reference.add(self.uuid_1, 1, self.uuid_2) - check_error_message(self, error, self.name_error_message(int)) - - with self.assertRaises(TypeError) as error: - reference.add(self.uuid_1, "prop", 1) - check_error_message(self, error, self.uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.add("my UUID", "prop", self.uuid_2) - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.add(self.uuid_1, "prop", "my uuid") - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.add(self.uuid_1, "prop", self.uuid_2, consistency_level=1) - - with self.assertRaises(ValueError) as error: - reference.add( - f"http://localhost:8080/v1/objects/{self.uuid_1}", - "prop", - "http://localhost:8080/v1/objects/MY_UUID", - ) - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.add( - "http://localhost:8080/v1/objects/My-UUID", - "prop", - f"http://localhost:8080/v1/objects/{self.uuid_2}", - ) - check_error_message(self, error, self.valid_uuid_error_message) - - mock_obj = mock_connection_func("post", status_code=204) - reference = Reference(mock_obj) - with self.assertRaises(UnexpectedStatusCodeException) as error: - reference.add(self.uuid_1, "myProperty", self.uuid_2) - check_startswith_error_message(self, error, unexpected_error_msg) - - mock_obj = mock_connection_func("post", side_effect=RequestsConnectionError("Test!")) - reference = Reference(mock_obj) - with self.assertRaises(RequestsConnectionError) as error: - reference.add(self.uuid_1, "myProperty", self.uuid_2) - check_error_message(self, error, connection_error_msg) - - # test valid calls - connection_mock = mock_connection_func("post") - reference = Reference(connection_mock) - - # 1. Plain - reference.add( - "3250b0b8-eaf7-499b-ac68-9084c9c82d0f", - "hasItem", - "99725f35-f12a-4f36-a2e2-0d41501f4e0e", - ) - connection_mock.post.assert_called_with( - path="/objects/3250b0b8-eaf7-499b-ac68-9084c9c82d0f/references/hasItem", - weaviate_object={"beacon": "weaviate://localhost/99725f35-f12a-4f36-a2e2-0d41501f4e0e"}, - params={}, - ) - - # 2. using url - reference.add( - "http://localhost:8080/v1/objects/7591be77-5959-4386-9828-423fc5096e87", - "hasItem", - "http://localhost:8080/v1/objects/1cd80c11-29f0-453f-823c-21547b1511f0", - ) - connection_mock.post.assert_called_with( - path="/objects/7591be77-5959-4386-9828-423fc5096e87/references/hasItem", - weaviate_object={"beacon": "weaviate://localhost/1cd80c11-29f0-453f-823c-21547b1511f0"}, - params={}, - ) - - # 3. using weaviate url - reference.add( - "weaviate://localhost/f8def983-87e7-4e21-bf10-e32e2de3efcf", - "hasItem", - "weaviate://localhost/e40aaef5-d3e5-44f1-8ec4-3eafc8475078", - consistency_level="ALL", - ) - connection_mock.post.assert_called_with( - path="/objects/f8def983-87e7-4e21-bf10-e32e2de3efcf/references/hasItem", - weaviate_object={"beacon": "weaviate://localhost/e40aaef5-d3e5-44f1-8ec4-3eafc8475078"}, - params={"consistency_level": "ALL"}, - ) - - def test_update(self): - """ - Test the `update` method. - """ - - connection_mock = Mock() - connection_mock.server_version = "1.13.2" - reference = Reference(connection_mock) - - # error messages - unexpected_error_msg = "Update property reference to object" - connection_error_msg = "Reference was not updated." - - # test exceptions - with self.assertRaises(TypeError) as error: - reference.update(1, "prop", [self.uuid_1]) - check_error_message(self, error, self.uuid_error_message) - - with self.assertRaises(TypeError) as error: - reference.update(self.uuid_1, 1, [self.uuid_2]) - check_error_message(self, error, self.name_error_message(int)) - - with self.assertRaises(TypeError) as error: - reference.update(self.uuid_1, "prop", 1) - check_error_message(self, error, self.uuid_error_message) - - with self.assertRaises(TypeError) as error: - reference.update(self.uuid_1, "prop", [1]) - check_error_message(self, error, self.uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.update("my UUID", "prop", self.uuid_2) - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.update(self.uuid_1, "prop", "my uuid") - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.update(self.uuid_1, "prop", ["my uuid"]) - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.update(self.uuid_1, "prop", self.uuid_2, consistency_level=1) - check_error_message(self, error, "1 is not a valid ConsistencyLevel") - - with self.assertRaises(ValueError) as error: - reference.update( - f"http://localhost:8080/v1/objects/{self.uuid_1}", - "prop", - "http://localhost:8080/v1/objects/MY_UUID", - ) - check_error_message(self, error, self.valid_uuid_error_message) - - with self.assertRaises(ValueError) as error: - reference.update( - "http://localhost:8080/v1/objects/My-UUID", - "prop", - f"http://localhost:8080/v1/objects/{self.uuid_2}", - ) - check_error_message(self, error, self.valid_uuid_error_message) - - mock_obj = mock_connection_func("put", status_code=204) - reference = Reference(mock_obj) - with self.assertRaises(UnexpectedStatusCodeException) as error: - reference.update(self.uuid_1, "myProperty", self.uuid_2) - check_startswith_error_message(self, error, unexpected_error_msg) - - mock_obj = mock_connection_func("put", side_effect=RequestsConnectionError("Test!")) - reference = Reference(mock_obj) - with self.assertRaises(RequestsConnectionError) as error: - reference.update(self.uuid_1, "myProperty", self.uuid_2) - check_error_message(self, error, connection_error_msg) - - # test valid calls - connection_mock = mock_connection_func("put") - reference = Reference(connection_mock) - - reference.update( - "de998e81-fa66-440e-a1de-2a2013667e77", - "hasAwards", - "fc041624-4ddf-4b76-8e09-a5b0b9f9f832", - ) - connection_mock.put.assert_called_with( - path="/objects/de998e81-fa66-440e-a1de-2a2013667e77/references/hasAwards", - weaviate_object=[ - {"beacon": "weaviate://localhost/fc041624-4ddf-4b76-8e09-a5b0b9f9f832"} - ], - params={}, - ) - - reference.update( - "4e44db9b-7f9c-4cf4-a3a0-b57024eefed0", - "hasAwards", - [ - "17ee17bd-a09a-49ff-adeb-d242f25f390d", - "f8c25386-707c-40c0-b7b9-26cc0e9b2bd1", - "d671dc52-dce4-46e7-8731-b722f19420c8", - ], - consistency_level=ConsistencyLevel.QUORUM, - ) - connection_mock.put.assert_called_with( - path="/objects/4e44db9b-7f9c-4cf4-a3a0-b57024eefed0/references/hasAwards", - weaviate_object=[ - {"beacon": "weaviate://localhost/17ee17bd-a09a-49ff-adeb-d242f25f390d"}, - {"beacon": "weaviate://localhost/f8c25386-707c-40c0-b7b9-26cc0e9b2bd1"}, - {"beacon": "weaviate://localhost/d671dc52-dce4-46e7-8731-b722f19420c8"}, - ], - params={"consistency_level": "QUORUM"}, - ) diff --git a/test/data/test_crud_data.py b/test/data/test_crud_data.py deleted file mode 100644 index 95a65cfba..000000000 --- a/test/data/test_crud_data.py +++ /dev/null @@ -1,779 +0,0 @@ -import unittest -from unittest.mock import patch, Mock - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.data import DataObject -from weaviate.data.replication import ConsistencyLevel -from weaviate.exceptions import ( - UnexpectedStatusCodeException, - ObjectAlreadyExistsException, -) - - -class TestDataObject(unittest.TestCase): - @patch("weaviate.data.crud_data._get_dict_from_object", side_effect=lambda x: x) - @patch("weaviate.data.crud_data.get_valid_uuid", side_effect=lambda x: x) - @patch("weaviate.data.crud_data.get_vector", side_effect=lambda x: x) - def test_create(self, mock_get_vector, mock_get_valid_uuid, mock_get_dict_from_object): - """ - Test the `create` method. - """ - - def reset(): - """ - Reset patched objects - """ - - mock_get_valid_uuid.reset_mock() # reset called - mock_get_vector.reset_mock() # reset called - mock_get_dict_from_object.reset_mock() # reset_called - - data_object = DataObject(Mock()) - - # invalid calls - class_name_error_message = lambda dt: f"Expected class_name of type str but was: {dt}" - requests_error_message = "Object was not added to Weaviate." - - # tests - with self.assertRaises(TypeError) as error: - data_object.create({"name": "Optimus Prime"}, ["Transformer"]) - check_error_message(self, error, class_name_error_message(list)) - mock_get_dict_from_object.assert_not_called() - mock_get_vector.assert_not_called() - mock_get_valid_uuid.assert_not_called() - - # test invalid consistency level - reset() - with self.assertRaises(ValueError) as error: - data_object.create( - {"name": "Optimus Prime"}, "class", "123", None, consistency_level="TWO" - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - mock_get_valid_uuid.assert_called() - - reset() - mock_obj = mock_connection_func("post", side_effect=RequestsConnectionError("Test!")) - data_object = DataObject(mock_obj) - with self.assertRaises(RequestsConnectionError) as error: - data_object.create({"name": "Alan Greenspan"}, "CoolestPersonEver") - check_error_message(self, error, requests_error_message) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - mock_get_valid_uuid.assert_not_called() - - reset() - mock_obj = mock_connection_func("post", status_code=204, return_json={}) - data_object = DataObject(mock_obj) - with self.assertRaises(UnexpectedStatusCodeException) as error: - data_object.create({"name": "Alan Greenspan"}, "CoolestPersonEver") - check_startswith_error_message(self, error, "Creating object") - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - mock_get_valid_uuid.assert_not_called() - - reset() - mock_obj = mock_connection_func( - "post", status_code=204, return_json={"error": [{"message": "already exists"}]} - ) - data_object = DataObject(mock_obj) - with self.assertRaises(ObjectAlreadyExistsException) as error: - data_object.create({"name": "Alan Greenspan"}, "CoolestPersonEver") - check_error_message(self, error, "None") - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - mock_get_valid_uuid.assert_not_called() - - reset() - mock_obj = mock_connection_func("post", status_code=204, return_json={}) - data_object = DataObject(mock_obj) - with self.assertRaises(UnexpectedStatusCodeException) as error: - data_object.create({"name": "Alan Greenspan"}, "CoolestPersonEver") - check_startswith_error_message(self, error, "Creating object") - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - mock_get_valid_uuid.assert_not_called() - - # # test valid calls - ## without vector argument - connection_mock = mock_connection_func("post", return_json={"id": 0}, status_code=200) - data_object = DataObject(connection_mock) - - object_ = { - "lyrics": "da da dadadada dada, da da dadadada da, da da dadadada da, da da dadadada da Tequila" - } - class_name = "KaraokeSongs" - vector = [1.0, 2.0] - id_ = "ae6d51d6-b4ea-5a03-a808-6aae990bdebf" - - rest_object = {"class": class_name, "properties": object_, "id": id_} - - reset() - uuid = data_object.create(object_, class_name, id_) - self.assertEqual(uuid, "0") - connection_mock.post.assert_called_with( - path="/objects", weaviate_object=rest_object, params={} - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - mock_get_valid_uuid.assert_called() - - ## with vector argument - connection_mock = mock_connection_func("post", return_json={"id": 0}, status_code=200) - data_object = DataObject(connection_mock) - - object_ = { - "lyrics": "da da dadadada dada, da da dadadada da, da da dadadada da, da da dadadada da Tequila" - } - class_name = "KaraokeSongs" - vector = [1.0, 2.0] - id_ = "ae6d51d6-b4ea-5a03-a808-6aae990bdebf" - - rest_object = {"class": class_name, "properties": object_, "vector": vector, "id": id_} - - reset() - uuid = data_object.create(object_, class_name, id_, vector, "ALL") - self.assertEqual(uuid, "0") - connection_mock.post.assert_called_with( - path="/objects", weaviate_object=rest_object, params={"consistency_level": "ALL"} - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_called() - mock_get_valid_uuid.assert_called() - - reset() - # uncapitalized class_names should be capitalized - uuid = data_object.create(object_, "karaokeSongs", id_, vector) - self.assertEqual(uuid, "0") - connection_mock.post.assert_called_with( - path="/objects", weaviate_object=rest_object, params={} - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_called() - mock_get_valid_uuid.assert_called() - - @patch("weaviate.data.crud_data._get_dict_from_object", side_effect=lambda x: x) - @patch("weaviate.data.crud_data.get_vector", side_effect=lambda x: x) - def test_update(self, mock_get_vector, mock_get_dict_from_object): - """ - Test the `update` method. - """ - uuid = "ae6d51d6-b4ea-5a03-a808-6aae990bdebf" - data_object = DataObject(Mock()) - - # error messages - class_type_error_message = "Class must be type str" - uuid_type_error_message = ( - lambda dt: f"'uuid' must be of type str or uuid.UUID, but was: {dt}" - ) - uuid_value_error_message = "Not valid 'uuid' or 'uuid' can not be extracted from value" - requests_error_message = "Object was not updated." - unexpected_error_message = "Update of the object not successful" - - with self.assertRaises(ValueError) as error: - data_object.update({"A": "B"}, "class", uuid, consistency_level="Unknown") - check_error_message(self, error, "'Unknown' is not a valid ConsistencyLevel") - mock_get_dict_from_object.assert_not_called() - mock_get_vector.assert_not_called() - - with self.assertRaises(TypeError) as error: - data_object.update({"A": "B"}, 35, uuid) - check_error_message(self, error, class_type_error_message) - mock_get_dict_from_object.assert_not_called() - mock_get_vector.assert_not_called() - - with self.assertRaises(TypeError) as error: - data_object.update({"A": "B"}, "Class", 1238234) - check_error_message(self, error, uuid_type_error_message(int)) - mock_get_dict_from_object.assert_not_called() - mock_get_vector.assert_not_called() - - with self.assertRaises(ValueError) as error: - data_object.update({"A": "B"}, "Class", "NOT-A-valid-uuid") - check_error_message(self, error, uuid_value_error_message) - mock_get_dict_from_object.assert_not_called() - mock_get_vector.assert_not_called() - - mock_obj = mock_connection_func("patch", side_effect=RequestsConnectionError("Test!")) - data_object = DataObject(mock_obj) - with self.assertRaises(RequestsConnectionError) as error: - data_object.update( - {"name": "Alan Greenspan"}, - "CoolestPersonEver", - uuid, - ) - check_error_message(self, error, requests_error_message) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - with self.assertRaises(UnexpectedStatusCodeException) as error: - mock_obj = mock_connection_func("patch", status_code=200, return_json={}) - data_object = DataObject(mock_obj) - data_object.update( - {"name": "Alan Greenspan"}, - "CoolestPersonEver", - uuid, - ) - check_startswith_error_message(self, error, unexpected_error_message) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - # test valid calls - ## without vector argument - connection_mock = mock_connection_func("patch", status_code=204) - data_object = DataObject(connection_mock) - data_object.update({"A": "B"}, "Class", uuid) - weaviate_obj = { - "id": uuid, - "class": "Class", - "properties": {"A": "B"}, - } - connection_mock.patch.assert_called_with( - path="/objects/ae6d51d6-b4ea-5a03-a808-6aae990bdebf", - weaviate_object=weaviate_obj, - params={}, - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - ### with uncapitalized class_name - connection_mock = mock_connection_func("patch", status_code=204) - data_object = DataObject(connection_mock) - data_object.update({"A": "B"}, "class", uuid) - weaviate_obj = { - "id": "ae6d51d6-b4ea-5a03-a808-6aae990bdebf", - "class": "Class", - "properties": {"A": "B"}, - } - connection_mock.patch.assert_called_with( - path="/objects/ae6d51d6-b4ea-5a03-a808-6aae990bdebf", - weaviate_object=weaviate_obj, - params={}, - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - ## with vector argument - connection_mock = mock_connection_func("patch", status_code=204) - data_object = DataObject(connection_mock) - data_object.update( - {"A": "B"}, - "Class", - uuid, - vector=[2.0, 4.0], - consistency_level=ConsistencyLevel.ONE, - ) - weaviate_obj = { - "id": uuid, - "class": "Class", - "properties": {"A": "B"}, - "vector": [2.0, 4.0], - } - connection_mock.patch.assert_called_with( - path="/objects/ae6d51d6-b4ea-5a03-a808-6aae990bdebf", - weaviate_object=weaviate_obj, - params={"consistency_level": "ONE"}, - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_called() - - @patch("weaviate.data.crud_data._get_dict_from_object", side_effect=lambda x: x) - @patch("weaviate.data.crud_data.get_vector", side_effect=lambda x: x) - def test_replace(self, mock_get_vector, mock_get_dict_from_object): - """ - Test the `replace` method. - """ - uuid = "27be9d8d-1da1-4d52-821f-bc7e2a25247d" - # test invalid consistency level - data_object = DataObject(Mock) - with self.assertRaises(ValueError) as error: - data_object.replace( - {"name": "Alan Greenspan"}, - "CoolestPersonEver", - "27be9d8d-1da1-4d52-821f-bc7e2a25247d", - consistency_level="Unknown", - ) - mock_get_dict_from_object.assert_not_called() - mock_get_vector.assert_not_called() - - # error messages - requests_error_message = "Object was not replaced." - unexpected_error_message = "Replace object" - - # test exceptions - mock_obj = mock_connection_func("put", side_effect=RequestsConnectionError("Test!")) - data_object = DataObject(mock_obj) - with self.assertRaises(RequestsConnectionError) as error: - data_object.replace( - {"name": "Alan Greenspan"}, - "CoolestPersonEver", - uuid, - ) - check_error_message(self, error, requests_error_message) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - mock_obj = mock_connection_func("put", status_code=204, return_json={}) - data_object = DataObject(mock_obj) - with self.assertRaises(UnexpectedStatusCodeException) as error: - data_object.replace( - {"name": "Alan Greenspan"}, - "CoolestPersonEver", - uuid, - ) - check_startswith_error_message(self, error, unexpected_error_message) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - # test valid calls - ## without vector argument - connection_mock = mock_connection_func("put") - data_object = DataObject(connection_mock) - data_object.replace({"A": 2}, "Hero", uuid) - weaviate_obj = { - "id": "27be9d8d-1da1-4d52-821f-bc7e2a25247d", - "class": "Hero", - "properties": {"A": 2}, - } - connection_mock.put.assert_called_with( - path="/objects/27be9d8d-1da1-4d52-821f-bc7e2a25247d", - weaviate_object=weaviate_obj, - params={}, - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - ### with uncapitalized class_name - connection_mock = mock_connection_func("put") - data_object = DataObject(connection_mock) - data_object.replace({"A": 2}, "hero", uuid) - weaviate_obj = { - "id": "27be9d8d-1da1-4d52-821f-bc7e2a25247d", - "class": "Hero", - "properties": {"A": 2}, - } - connection_mock.put.assert_called_with( - path="/objects/27be9d8d-1da1-4d52-821f-bc7e2a25247d", - weaviate_object=weaviate_obj, - params={}, - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - # with vector argument - connection_mock = mock_connection_func("put") - data_object = DataObject(connection_mock) - data_object.replace( - {"A": 2}, - "Hero", - uuid, - vector=[3.0, 5, 7], - consistency_level="ONE", - ) - weaviate_obj = { - "id": "27be9d8d-1da1-4d52-821f-bc7e2a25247d", - "class": "Hero", - "properties": {"A": 2}, - "vector": [3.0, 5, 7], - } - connection_mock.put.assert_called_with( - path="/objects/27be9d8d-1da1-4d52-821f-bc7e2a25247d", - weaviate_object=weaviate_obj, - params={"consistency_level": "ONE"}, - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_called() - - def test_delete(self): - """ - Test the `delete` method. - """ - uuid = "b36268d4-a6b5-5274-985f-45f13ce0c642" - connection_mock = Mock() - connection_mock.server_version = "1.13.2" - data_object = DataObject(connection_mock) - - # error messages - uuid_type_error_message = ( - lambda dt: f"'uuid' must be of type str or uuid.UUID, but was: {dt}" - ) - uuid_value_error_message = "Not valid 'uuid' or 'uuid' can not be extracted from value" - requests_error_message = "Object could not be deleted." - unexpected_error_message = "Delete object" - - with self.assertRaises(TypeError) as error: - data_object.delete(4) - check_error_message(self, error, uuid_type_error_message(int)) - - with self.assertRaises(ValueError) as error: - data_object.delete("Hallo World") - check_error_message(self, error, uuid_value_error_message) - - ## test invalid consistency level - with self.assertRaises(ValueError) as error: - data_object.delete(uuid=uuid, consistency_level="Unknown") - - connection_mock = mock_connection_func( - "delete", side_effect=RequestsConnectionError("Test!") - ) - data_object = DataObject(connection_mock) - with self.assertRaises(RequestsConnectionError) as error: - data_object.delete("b36268d4-a6b5-5274-985f-45f13ce0c642") - check_error_message(self, error, requests_error_message) - - connection_mock = mock_connection_func("delete", status_code=405) - data_object = DataObject(connection_mock) - with self.assertRaises(UnexpectedStatusCodeException) as error: - data_object.delete("b36268d4-a6b5-5274-985f-45f13ce0c642") - check_startswith_error_message(self, error, unexpected_error_message) - - # 1. Successfully delete something - connection_mock = mock_connection_func("delete", status_code=204) - data_object = DataObject(connection_mock) - - object_id = "b36268d4-a6b5-5274-985f-45f13ce0c642" - data_object.delete(object_id) - connection_mock.delete.assert_called_with( - path="/objects/" + object_id, - params={}, - ) - - data_object.delete(object_id, None, ConsistencyLevel.ALL) - connection_mock.delete.assert_called_with( - path="/objects/" + object_id, - params={"consistency_level": "ALL"}, - ) - - def test_get_by_id(self): - """ - Test the `get_by_id` method. - """ - - data_object = DataObject(Mock()) - - mock_get = Mock(return_value="Test") - data_object.get = mock_get - data_object.get_by_id( - uuid="UUID", additional_properties=["Test", "Array"], with_vector=True - ) - mock_get.assert_called_with( - uuid="UUID", - class_name=None, - additional_properties=["Test", "Array"], - with_vector=True, - node_name=None, - consistency_level=None, - tenant=None, - ) - - data_object.get_by_id(uuid="UUID2", additional_properties=["Test"], with_vector=False) - mock_get.assert_called_with( - uuid="UUID2", - class_name=None, - additional_properties=["Test"], - with_vector=False, - node_name=None, - consistency_level=None, - tenant=None, - ) - - data_object.get_by_id( - uuid="UUID3", - additional_properties=["Test"], - with_vector=False, - consistency_level=ConsistencyLevel.QUORUM, - ) - mock_get.assert_called_with( - uuid="UUID3", - class_name=None, - additional_properties=["Test"], - with_vector=False, - node_name=None, - consistency_level=ConsistencyLevel.QUORUM, - tenant=None, - ) - - data_object.get_by_id( - uuid="UUID4", additional_properties=["Test"], with_vector=False, node_name="node1" - ) - mock_get.assert_called_with( - uuid="UUID4", - class_name=None, - additional_properties=["Test"], - with_vector=False, - node_name="node1", - consistency_level=None, - tenant=None, - ) - - connection_mock = Mock() - connection_mock.server_version = "1.17.0" - data_object = DataObject(connection_mock) - with self.assertRaises(ValueError) as error: - data_object.get_by_id( - uuid="UUID4", - class_name="SomeClass", - additional_properties=["Test"], - with_vector=False, - consistency_level=12345, - ) - assert "123" in error - with self.assertRaises(ValueError) as error: - data_object.get_by_id( - uuid="UUID4", - class_name="SomeClass", - additional_properties=["Test"], - with_vector=False, - consistency_level="all", - ) - assert "all" in error - with self.assertRaises(ValueError) as error: - data_object.get_by_id( - uuid="UUID4", - class_name="SomeClass", - additional_properties=["Test"], - with_vector=False, - consistency_level={"consistency_level": "ALL"}, - ) - assert "consistency_level" in error - - @patch("weaviate.data.crud_data._get_params") - def test_get(self, mock_get_params): - """ - Test the `get` method. - """ - - # error messages - requests_error_message = "Could not get object/s." - unexpected_error_message = "Get object/s" - - # test exceptions - - data_object = DataObject( - mock_connection_func("get", side_effect=RequestsConnectionError("Test!")) - ) - with self.assertRaises(RequestsConnectionError) as error: - data_object.get() - check_error_message(self, error, requests_error_message) - - data_object = DataObject(mock_connection_func("get", status_code=405)) - with self.assertRaises(UnexpectedStatusCodeException) as error: - data_object.get() - check_startswith_error_message(self, error, unexpected_error_message) - - # test valid calls - return_value_get = {"my_key": 12341} - mock_get_params.return_value = {"include": "test1,test2"} - connection_mock = mock_connection_func("get", return_json=return_value_get, status_code=200) - data_object = DataObject(connection_mock) - result = data_object.get() - self.assertEqual(result, return_value_get) - connection_mock.get.assert_called_with(path="/objects", params={"include": "test1,test2"}) - - return_value_get = {"my_key": "12341"} - mock_get_params.return_value = {"include": "test1,test2"} - connection_mock = mock_connection_func("get", return_json=return_value_get, status_code=200) - data_object = DataObject(connection_mock) - result = data_object.get(uuid="1d420c9c98cb11ec9db61e008a366d49") - self.assertEqual(result, return_value_get) - connection_mock.get.assert_called_with( - path="/objects/1d420c9c-98cb-11ec-9db6-1e008a366d49", params={"include": "test1,test2"} - ) - - def test_exists(self): - """ - Test the `exists` method. - """ - uuid = "1d420c9c-98cb-11ec-9db6-1e008a366d49" - # error messages - requests_error_message = "Could not check if object exist." - unexpected_error_message = "Object exists" - - # test exceptions - data_object = DataObject(mock_connection_func("head")) - with self.assertRaises(ValueError) as error: - data_object.exists(uuid=uuid, consistency_level="Unknown") - - data_object = DataObject( - mock_connection_func("head", side_effect=RequestsConnectionError("Test!")) - ) - with self.assertRaises(RequestsConnectionError) as error: - data_object.exists(uuid="1d420c9c98cb11ec9db61e008a366d49") - check_error_message(self, error, requests_error_message) - - data_object = DataObject(mock_connection_func("head", status_code=200)) - with self.assertRaises(UnexpectedStatusCodeException) as error: - data_object.exists(uuid="1d420c9c98cb11ec9db61e008a366d49") - check_startswith_error_message(self, error, unexpected_error_message) - - # test valid calls - connection_mock = mock_connection_func("head", status_code=204) - data_object = DataObject(connection_mock) - result = data_object.exists(uuid="1d420c9c98cb11ec9db61e008a366d49") - self.assertEqual(result, True) - connection_mock.head.assert_called_with( - path="/objects/1d420c9c-98cb-11ec-9db6-1e008a366d49", - params={}, - ) - - connection_mock = mock_connection_func("head", status_code=404) - data_object = DataObject(connection_mock) - result = data_object.exists(uuid, None, ConsistencyLevel.QUORUM) - self.assertEqual(result, False) - connection_mock.head.assert_called_with( - path="/objects/1d420c9c-98cb-11ec-9db6-1e008a366d49", - params={"consistency_level": "QUORUM"}, - ) - - @patch("weaviate.data.crud_data._get_dict_from_object", side_effect=lambda x: x) - @patch("weaviate.data.crud_data.get_vector", side_effect=lambda x: x) - def test_validate(self, mock_get_vector, mock_get_dict_from_object): - """ - Test the `validate` method. - """ - - data_object = DataObject(Mock()) - - # error messages - uuid_type_error_message = ( - lambda dt: f"'uuid' must be of type str or uuid.UUID, but was: {dt}" - ) - class_name_error_message = lambda dt: f"Expected class_name of type `str` but was: {dt}" - requests_error_message = "Object was not validated against Weaviate." - unexpected_error_message = "Validate object" - - # test exceptions - with self.assertRaises(TypeError) as error: - data_object.validate({}, "Name", 1) - check_error_message(self, error, uuid_type_error_message(int)) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - with self.assertRaises(TypeError) as error: - data_object.validate({}, ["Name"], "73802305-c0da-427e-b21c-d6779a22f35f") - check_error_message(self, error, class_name_error_message(list)) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - data_object = DataObject( - mock_connection_func("post", side_effect=RequestsConnectionError("Test!")) - ) - with self.assertRaises(RequestsConnectionError) as error: - data_object.validate( - {"name": "Alan Greenspan"}, - "CoolestPersonEver", - "73802305-c0da-427e-b21c-d6779a22f35f", - ) - check_error_message(self, error, requests_error_message) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - data_object = DataObject(mock_connection_func("post", status_code=204, return_json={})) - with self.assertRaises(UnexpectedStatusCodeException) as error: - data_object.validate( - {"name": "Alan Greenspan"}, - "CoolestPersonEver", - "73802305-c0da-427e-b21c-d6779a22f35f", - ) - check_startswith_error_message(self, error, unexpected_error_message) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - # test valid calls - # test for status_code 200 without vector argument - connection_mock = mock_connection_func("post", status_code=200) - data_object = DataObject(connection_mock) - - response = data_object.validate({"A": 2}, "Hero", "27be9d8d-1da1-4d52-821f-bc7e2a25247d") - self.assertEqual(response, {"error": None, "valid": True}) - - weaviate_obj = { - "id": "27be9d8d-1da1-4d52-821f-bc7e2a25247d", - "class": "Hero", - "properties": {"A": 2}, - } - connection_mock.post.assert_called_with( - path="/objects/validate", weaviate_object=weaviate_obj - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - ### with uncapitalized class_name - connection_mock = mock_connection_func("post", status_code=200) - data_object = DataObject(connection_mock) - - response = data_object.validate({"A": 2}, "hero", "27be9d8d-1da1-4d52-821f-bc7e2a25247d") - self.assertEqual(response, {"error": None, "valid": True}) - - weaviate_obj = { - "id": "27be9d8d-1da1-4d52-821f-bc7e2a25247d", - "class": "Hero", - "properties": {"A": 2}, - } - connection_mock.post.assert_called_with( - path="/objects/validate", weaviate_object=weaviate_obj - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - # test for status_code 422 - connection_mock = mock_connection_func( - "post", status_code=422, return_json={"error": "Not OK!"} - ) - data_object = DataObject(connection_mock) - - response = data_object.validate({"A": 2}, "Hero", "27be9d8d-1da1-4d52-821f-bc7e2a25247d") - self.assertEqual(response, {"error": "Not OK!", "valid": False}) - - weaviate_obj = { - "id": "27be9d8d-1da1-4d52-821f-bc7e2a25247d", - "class": "Hero", - "properties": {"A": 2}, - } - connection_mock.post.assert_called_with( - path="/objects/validate", weaviate_object=weaviate_obj - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_not_called() - - # test for status_code 200 with vector argument - connection_mock = mock_connection_func("post", status_code=200) - data_object = DataObject(connection_mock) - - response = data_object.validate( - {"A": 2}, "Hero", "27be9d8d-1da1-4d52-821f-bc7e2a25247d", vector=[-9.8, 6.66] - ) - self.assertEqual(response, {"error": None, "valid": True}) - - weaviate_obj = { - "id": "27be9d8d-1da1-4d52-821f-bc7e2a25247d", - "class": "Hero", - "properties": {"A": 2}, - "vector": [-9.8, 6.66], - } - connection_mock.post.assert_called_with( - path="/objects/validate", weaviate_object=weaviate_obj - ) - mock_get_dict_from_object.assert_called() - mock_get_vector.assert_called() - - def test__get_params(self): - """ - Test the `_get_params` function. - """ - - from weaviate.data.crud_data import _get_params - - # error messages - type_error_message = lambda dt: f"Additional properties must be of type list but are {dt}" - - with self.assertRaises(TypeError) as error: - _get_params("Test", False) - check_error_message(self, error, type_error_message(str)) - - self.assertEqual(_get_params(["test1", "test2"], False), {"include": "test1,test2"}) - self.assertEqual(_get_params(None, True), {"include": "vector"}) - self.assertEqual(_get_params([], True), {"include": "vector"}) - self.assertEqual(_get_params(["test1", "test2"], True), {"include": "test1,test2,vector"}) diff --git a/test/gql/__init__.py b/test/gql/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/gql/test_aggregate.py b/test/gql/test_aggregate.py deleted file mode 100644 index 1fcd59035..000000000 --- a/test/gql/test_aggregate.py +++ /dev/null @@ -1,206 +0,0 @@ -import unittest -from typing import List, Callable, Tuple -from unittest.mock import patch - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.gql.aggregate import AggregateBuilder - - -class TestAggregateBuilder(unittest.TestCase): - def setUp(self): - self.aggregate = AggregateBuilder("Object", None) - - def test_with_meta_count(self): - """ - Test the `with_meta_count` method. - """ - - query = self.aggregate.with_meta_count().build() - self.assertEqual("{Aggregate{Object{meta{count}}}}", query) - - def test_with_fields(self): - """ - Test the `with_fields` method. - """ - - query = self.aggregate.with_fields("size { mean }").build() - self.assertEqual("{Aggregate{Object{size { mean }}}}", query) - - def test_with_where(self): - """ - Test the `with_where` method. - """ - - query = ( - self.aggregate.with_meta_count() - .with_where({"operator": "Equal", "valueString": "B", "path": ["name"]}) - .build() - ) - self.assertEqual( - '{Aggregate{Object(where: {path: ["name"] operator: Equal valueString: "B"} ){meta{count}}}}', - query, - ) - - def test_group_by_filter(self): - """ - Test the `with_group_by_filter` method. - """ - - query = ( - self.aggregate.with_group_by_filter(["name"]) - .with_fields("groupedBy { value }") - .with_fields("name { count }") - .build() - ) - self.assertEqual( - '{Aggregate{Object(groupBy: ["name"]){groupedBy { value }name { count }}}}', query - ) - - def test_with_limit(self): - """ - Test the `with_limit` method. - """ - - query = ( - self.aggregate.with_meta_count() - .with_where({"operator": "Equal", "valueString": "B", "path": ["name"]}) - .with_limit(10) - .build() - ) - self.assertEqual( - '{Aggregate{Object(where: {path: ["name"] operator: Equal valueString: "B"} limit: 10){meta{count}}}}', - query, - ) - - test_near_media_param_list: List[ - Tuple[str, str, str, Callable[[AggregateBuilder, str, str, bool], AggregateBuilder]] - ] = [ - ( - "audio", - "test_audio", - "nearAudio", - lambda b, k, v, e: b.with_near_audio({k: v, "certainty": 0.55}, encode=e), - ), - ( - "video", - "test_video", - "nearVideo", - lambda b, k, v, e: b.with_near_video({k: v, "certainty": 0.55}, encode=e), - ), - ( - "depth", - "test_depth", - "nearDepth", - lambda b, k, v, e: b.with_near_depth({k: v, "certainty": 0.55}, encode=e), - ), - ( - "thermal", - "test_thermal", - "nearThermal", - lambda b, k, v, e: b.with_near_thermal({k: v, "certainty": 0.55}, encode=e), - ), - ( - "imu", - "test_imu", - "nearIMU", - lambda b, k, v, e: b.with_near_imu({k: v, "certainty": 0.55}, encode=e), - ), - ] - - def test_near_media(self): - """ - Test the `with_near_` method. - """ - for key, value, type_, fn in self.test_near_media_param_list: - with self.subTest(key=key, value=value): - with patch( - "weaviate.gql.aggregate.file_encoder_b64", side_effect=lambda x: "test_call" - ) as mocked: - # valid calls - ## encode False - query = fn( - AggregateBuilder("Person", None).with_fields("name"), key, value, False - ).build() - self.assertEqual( - f'{{Aggregate{{Person({type_}: {{{key}: "{value}" certainty: 0.55}} ){{name}}}}}}', - query, - ) - mocked.assert_not_called() - - ## encode True - query = fn( - AggregateBuilder("Person", None).with_fields("name"), key, value, True - ).build() - self.assertEqual( - f'{{Aggregate{{Person({type_}: {{{key}: "test_call" certainty: 0.55}} ){{name}}}}}}', - query, - ) - mocked.assert_called() - - # invalid calls - near_error_msg = "Cannot use multiple 'near' filters, or a 'near' filter along with a 'ask' filter!" - - near_text = { - "concepts": "computer", - "moveTo": {"concepts": ["science"], "force": 0.5}, - } - with self.assertRaises(AttributeError) as error: - fn( - AggregateBuilder("Person", None) - .with_fields("name") - .with_near_text(near_text), - key, - value, - True, - ) - check_error_message(self, error, near_error_msg) - - def test_do(self): - """ - Test the `do` method. - """ - - # test exceptions - requests_error_message = "Query was not successful." - - # requests.exceptions.ConnectionError - mock_obj = mock_connection_func("post", side_effect=RequestsConnectionError("Test")) - self.aggregate._connection = mock_obj - with self.assertRaises(RequestsConnectionError) as error: - self.aggregate.do() - check_error_message(self, error, requests_error_message) - - # weaviate.UnexpectedStatusCodeException - mock_obj = mock_connection_func("post", status_code=404) - self.aggregate._connection = mock_obj - with self.assertRaises(UnexpectedStatusCodeException) as error: - self.aggregate.do() - check_startswith_error_message(self, error, "Query was not successful") - - filter_name = {"path": ["name"], "operator": "Equal", "valueString": "B"} - - self.aggregate.with_group_by_filter(["name"]).with_fields( - "groupedBy { value }" - ).with_fields("name { count }").with_where(filter_name) - expected_gql_clause = '{Aggregate{Object(where: {path: ["name"] operator: Equal valueString: "B"} groupBy: ["name"]){groupedBy { value }name { count }}}}' - - mock_obj = mock_connection_func("post", status_code=200, return_json={"status": "OK!"}) - self.aggregate._connection = mock_obj - self.assertEqual(self.aggregate.do(), {"status": "OK!"}) - mock_obj.post.assert_called_with( - path="/graphql", weaviate_object={"query": expected_gql_clause} - ) - - def test_uncapitalized_class_name(self): - """ - Test the uncapitalized class_name. - """ - - aggregate = AggregateBuilder("Test", None) - self.assertEqual(aggregate._class_name, "Test") - - aggregate = AggregateBuilder("test", None) - self.assertEqual(aggregate._class_name, "Test") diff --git a/test/gql/test_filter.py b/test/gql/test_filter.py deleted file mode 100644 index d1c38fd25..000000000 --- a/test/gql/test_filter.py +++ /dev/null @@ -1,1277 +0,0 @@ -import unittest - -from test.util import check_error_message, check_startswith_error_message -from weaviate.gql.filter import ( - NearText, - NearVector, - NearObject, - NearImage, - MediaType, - NearVideo, - NearAudio, - NearDepth, - NearThermal, - NearIMU, - Where, - Ask, - WHERE_OPERATORS, - VALUE_TYPES, -) - - -def helper_get_test_filter(filter_type, value): - return {"path": ["name"], "operator": "Equal", filter_type: value} - - -class TestNearText(unittest.TestCase): - def move_x_test_case(self, move: str): - """ - Test the "moveTo" or the "moveAwayFrom" clause. - - Parameters - ---------- - move : str - The "moveTo" or the "moveAwayFrom" clause name. - """ - - type_error_msg = ( - lambda dt: f"'moveXXX' key-value is expected to be of type but is {dt}!" - ) - concepts_objects_error_msg = "The 'move' clause should contain `concepts` OR/AND `objects`!" - objects_type_error_msg = ( - lambda dt: f"'objects' key-value is expected to be of type (, ) but is {dt}!" - ) - object_value_error_msg = ( - "Each object from the `move` clause should have ONLY `id` OR `beacon`!" - ) - concept_value_error_msg = lambda dt: ( - f"'concepts' key-value is expected to be of type (, ) but is {dt}!" - ) - force_error_msg = "'move' clause needs to state a 'force'" - force_type_error_msg = lambda dt: ( - f"'force' key-value is expected to be of type but is {dt}!" - ) - - with self.assertRaises(TypeError) as error: - NearText({"concepts": "Some_concept", move: "0.5"}) - check_error_message(self, error, type_error_msg(str)) - - with self.assertRaises(ValueError) as error: - NearText({"concepts": "Some_concept", move: {}}) - check_error_message(self, error, concepts_objects_error_msg) - - with self.assertRaises(TypeError) as error: - NearText({"concepts": "Some_concept", move: {"concepts": set("something")}}) - check_error_message(self, error, concept_value_error_msg(set)) - - with self.assertRaises(ValueError) as error: - NearText( - { - "concepts": "Some_concept", - move: { - "concepts": "something", - }, - } - ) - check_error_message(self, error, force_error_msg) - - with self.assertRaises(TypeError) as error: - NearText( - { - "concepts": "Some_concept", - move: { - "objects": 1234, - }, - } - ) - check_error_message(self, error, objects_type_error_msg(int)) - - with self.assertRaises(ValueError) as error: - NearText( - { - "concepts": "Some_concept", - move: { - "objects": {}, - }, - } - ) - check_error_message(self, error, object_value_error_msg) - - with self.assertRaises(ValueError) as error: - NearText( - { - "concepts": "Some_concept", - move: { - "objects": {"id": 1, "beacon": 2}, - }, - } - ) - check_error_message(self, error, object_value_error_msg) - - with self.assertRaises(ValueError) as error: - NearText( - { - "concepts": "Some_concept", - move: { - "objects": {"test_id": 1}, - }, - } - ) - check_error_message(self, error, object_value_error_msg) - - with self.assertRaises(TypeError) as error: - NearText( - { - "concepts": "Some_concept", - move: {"concepts": "something", "objects": [{"id": 1}], "force": True}, - } - ) - check_error_message(self, error, force_type_error_msg(bool)) - - def test___init__(self): - """ - Test the `__init__` method. - """ - - # invalid calls - content_error_msg = f"NearText filter is expected to be type dict but is {list}" - concept_error_msg = "No concepts in content" - concept_value_error_msg = lambda actual_type: ( - f"'concepts' key-value is expected to be of type (, ) but is {actual_type}!" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - autocorrect_error_msg = lambda dtype: ( - f"'autocorrect' key-value is expected to be of type but is {dtype}!" - ) - - ## test "concepts" - with self.assertRaises(TypeError) as error: - NearText(["concepts", "Some_concept"]) - check_error_message(self, error, content_error_msg) - - with self.assertRaises(ValueError) as error: - NearText({"INVALID": "Some_concept"}) - check_error_message(self, error, concept_error_msg) - - with self.assertRaises(TypeError) as error: - NearText({"concepts": set("Some_concept")}) - check_error_message(self, error, concept_value_error_msg(set)) - - ## test "certainty" - with self.assertRaises(TypeError) as error: - NearText({"concepts": "Some_concept", "certainty": "0.5"}) - check_error_message(self, error, certainty_error_msg(str)) - - ## test "certainty" - with self.assertRaises(TypeError) as error: - NearText({"concepts": "Some_concept", "autocorrect": [True]}) - check_error_message(self, error, autocorrect_error_msg(list)) - - ## test "moveTo" - self.move_x_test_case("moveTo") - ## test "moveAwayFrom" - self.move_x_test_case("moveAwayFrom") - - # test valid calls - NearText({"concepts": "Some_concept"}) - NearText({"concepts": ["Some_concept", "Some_concept_2"]}) - NearText({"concepts": "Some_concept", "certainty": 0.75}) - NearText({"concepts": "Some_concept", "certainty": 0.75, "autocorrect": True}) - NearText( - {"concepts": "Some_concept", "moveTo": {"concepts": "moveToConcepts", "force": 0.75}} - ) - NearText( - { - "concepts": "Some_concept", - "moveAwayFrom": {"concepts": "moveAwayFromConcepts", "force": 0.75}, - } - ) - NearText( - { - "concepts": "Some_concept", - "certainty": 0.75, - "moveAwayFrom": {"concepts": "moveAwayFromConcepts", "force": 0.75}, - "moveTo": {"concepts": "moveToConcepts", "force": 0.75}, - "autocorrect": False, - } - ) - - NearText( - { - "concepts": "Some_concept", - "certainty": 0.75, - "moveAwayFrom": {"objects": {"id": "test_id"}, "force": 0.75}, - "moveTo": { - "concepts": "moveToConcepts", - "objects": [{"id": "test_id"}, {"beacon": "Test_beacon"}], - "force": 0.75, - }, - "autocorrect": True, - } - ) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_text = NearText({"concepts": "Some_concept"}) - self.assertEqual(str(near_text), 'nearText: {concepts: ["Some_concept"]} ') - - near_text = NearText({"concepts": ["Some_concept", "Some_concept_2"]}) - self.assertEqual( - str(near_text), 'nearText: {concepts: ["Some_concept", "Some_concept_2"]} ' - ) - near_text = NearText({"concepts": "Some_concept", "certainty": 0.75}) - self.assertEqual(str(near_text), 'nearText: {concepts: ["Some_concept"] certainty: 0.75} ') - near_text = NearText({"concepts": "Some_concept", "autocorrect": True}) - self.assertEqual( - str(near_text), 'nearText: {concepts: ["Some_concept"] autocorrect: true} ' - ) - near_text = NearText({"concepts": "Some_concept", "autocorrect": False}) - self.assertEqual( - str(near_text), 'nearText: {concepts: ["Some_concept"] autocorrect: false} ' - ) - near_text = NearText( - {"concepts": "Some_concept", "moveTo": {"concepts": "moveToConcepts", "force": 0.75}} - ) - self.assertEqual( - str(near_text), - 'nearText: {concepts: ["Some_concept"] moveTo: {force: 0.75 concepts: ["moveToConcepts"]}} ', - ) - near_text = NearText( - { - "concepts": "Some_concept", - "moveTo": { - "concepts": "moveToConcepts", - "force": 0.75, - "objects": {"id": "SOME_ID"}, - }, - } - ) - self.assertEqual( - str(near_text), - 'nearText: {concepts: ["Some_concept"] moveTo: {force: 0.75 concepts: ["moveToConcepts"] objects: [{id: "SOME_ID"} ]}} ', - ) - - near_text = NearText( - { - "concepts": "Some_concept", - "moveTo": { - "force": 0.75, - "objects": [{"id": "SOME_ID"}, {"beacon": "SOME_BEACON"}], - }, - } - ) - self.assertEqual( - str(near_text), - 'nearText: {concepts: ["Some_concept"] moveTo: {force: 0.75 objects: [{id: "SOME_ID"} {beacon: "SOME_BEACON"} ]}} ', - ) - - near_text = NearText( - { - "concepts": "Some_concept", - "moveAwayFrom": { - "concepts": "moveAwayFromConcepts", - "force": 0.75, - "objects": {"id": "SOME_ID"}, - }, - } - ) - self.assertEqual( - str(near_text), - 'nearText: {concepts: ["Some_concept"] moveAwayFrom: {force: 0.75 concepts: ["moveAwayFromConcepts"] objects: [{id: "SOME_ID"} ]}} ', - ) - - near_text = NearText( - { - "concepts": "Some_concept", - "moveAwayFrom": { - "force": 0.75, - "objects": [{"id": "SOME_ID"}, {"beacon": "SOME_BEACON"}], - }, - } - ) - self.assertEqual( - str(near_text), - 'nearText: {concepts: ["Some_concept"] moveAwayFrom: {force: 0.75 objects: [{id: "SOME_ID"} {beacon: "SOME_BEACON"} ]}} ', - ) - - near_text = NearText( - { - "concepts": "Some_concept", - "moveAwayFrom": {"concepts": "moveAwayFromConcepts", "force": 0.25}, - } - ) - self.assertEqual( - str(near_text), - 'nearText: {concepts: ["Some_concept"] moveAwayFrom: {force: 0.25 concepts: ["moveAwayFromConcepts"]}} ', - ) - near_text = NearText( - { - "concepts": "Some_concept", - "certainty": 0.95, - "moveAwayFrom": {"concepts": "moveAwayFromConcepts", "force": 0.75}, - "moveTo": {"concepts": "moveToConcepts", "force": 0.25}, - "autocorrect": True, - } - ) - self.assertEqual( - str(near_text), - 'nearText: {concepts: ["Some_concept"] certainty: 0.95 moveTo: {force: 0.25 concepts: ["moveToConcepts"]} moveAwayFrom: {force: 0.75 concepts: ["moveAwayFromConcepts"]} autocorrect: true} ', - ) - - # test it with references of objects - concepts = ["con1", "con2"] - move = {"concepts": "moveToConcepts", "force": 0.75} - - near_text = NearText({"concepts": concepts, "moveTo": move}) - concepts.append("con3") # should not be appended to the nearText clause - move["force"] = 2.00 # should not be appended to the nearText clause - self.assertEqual( - str(near_text), - 'nearText: {concepts: ["con1", "con2"] moveTo: {force: 0.75 concepts: ["moveToConcepts"]}} ', - ) - - -class TestNearVector(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # test exceptions - content_error_msg = "NearVector filter is expected to " f"be type dict but is {list}" - vector_error_msg = "\"No 'vector' key in `content` argument.\"" - vector_value_error_msg = ( - "The type of the 'vector' argument is not supported!\n" - "Supported types are `list`, 'numpy.ndarray`, `torch.Tensor`, `tf.Tensor`, `pd.Series`, and `pl.Series`" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - - ## test "concepts" - with self.assertRaises(TypeError) as error: - NearVector(["concepts", "Some_concept"]) - check_error_message(self, error, content_error_msg) - - with self.assertRaises(KeyError) as error: - NearVector({"INVALID": "Some_concept"}) - check_error_message(self, error, vector_error_msg) - - with self.assertRaises(TypeError) as error: - NearVector({"vector": set("Some_concept")}) - check_error_message(self, error, vector_value_error_msg) - - ## test "certainty" - with self.assertRaises(TypeError) as error: - NearVector({"vector": [1.0, 2.0, 3.0, 4.0], "certainty": "0.5"}) - check_error_message(self, error, certainty_error_msg(str)) - - # test valid calls - NearVector({"vector": [1.0, 2.0, 3.0, 4.0]}) - NearVector({"vector": [1.0, 2.0, 3.0, 4.0], "certainty": 0.75}) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_vector = NearVector({"vector": [1.0, 2.0, 3.0, 4.0]}) - self.assertEqual(str(near_vector), "nearVector: {vector: [1.0, 2.0, 3.0, 4.0]} ") - near_vector = NearVector({"vector": [1.0, 2.0, 3.0, 4.0], "certainty": 0.75}) - self.assertEqual( - str(near_vector), "nearVector: {vector: [1.0, 2.0, 3.0, 4.0] certainty: 0.75} " - ) - - -class TestNearObject(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # invalid calls - - ## error messages - content_error_msg = lambda dt: f"NearObject filter is expected to be type dict but is {dt}" - beacon_id_error_msg = "The 'content' argument should contain EITHER `id` OR `beacon`!" - beacon_id_type_error_msg = lambda what, dt: ( - f"'{what}' key-value is expected to be of type but is {dt}!" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - - with self.assertRaises(TypeError) as error: - NearObject(123, is_server_version_14=False) - check_error_message(self, error, content_error_msg(int)) - - with self.assertRaises(ValueError) as error: - NearObject({"id": 123, "beacon": 456}, is_server_version_14=False) - check_error_message(self, error, beacon_id_error_msg) - - with self.assertRaises(TypeError) as error: - NearObject( - { - "id": 123, - }, - is_server_version_14=False, - ) - check_error_message(self, error, beacon_id_type_error_msg("id", int)) - - with self.assertRaises(TypeError) as error: - NearObject( - { - "beacon": {123}, - }, - is_server_version_14=False, - ) - check_error_message(self, error, beacon_id_type_error_msg("beacon", set)) - - with self.assertRaises(TypeError) as error: - NearObject({"beacon": "test_beacon", "certainty": False}, is_server_version_14=False) - check_error_message(self, error, certainty_error_msg(bool)) - - # valid calls - - NearObject( - { - "id": "test_id", - }, - is_server_version_14=False, - ) - - NearObject({"beacon": "test_beacon", "certainty": 0.7}, is_server_version_14=False) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_object = NearObject( - { - "id": "test_id", - }, - is_server_version_14=False, - ) - self.assertEqual(str(near_object), 'nearObject: {id: "test_id"} ') - - near_object = NearObject({"id": "test_id", "certainty": 0.7}, is_server_version_14=False) - self.assertEqual(str(near_object), 'nearObject: {id: "test_id" certainty: 0.7} ') - - near_object = NearObject( - { - "beacon": "test_beacon", - }, - is_server_version_14=False, - ) - self.assertEqual(str(near_object), 'nearObject: {beacon: "test_beacon"} ') - - near_object = NearObject( - {"beacon": "test_beacon", "certainty": 0.0}, is_server_version_14=False - ) - self.assertEqual(str(near_object), 'nearObject: {beacon: "test_beacon" certainty: 0.0} ') - - -class TestNearImage(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # invalid calls - - ## error messages - content_error_msg = lambda dt: f"NearImage filter is expected to be type dict but is {dt}" - image_key_error_msg = '"content" is missing the mandatory key "image"!' - image_value_error_msg = ( - lambda dt: f"'image' key-value is expected to be of type but is {dt}!" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - - with self.assertRaises(TypeError) as error: - NearImage(123) - check_error_message(self, error, content_error_msg(int)) - - with self.assertRaises(ValueError) as error: - NearImage({"id": "image_path.png", "certainty": 456}) - check_error_message(self, error, image_key_error_msg) - - with self.assertRaises(TypeError) as error: - NearImage({"image": True}) - check_error_message(self, error, image_value_error_msg(bool)) - - with self.assertRaises(TypeError) as error: - NearImage({"image": b"True"}) - check_error_message(self, error, image_value_error_msg(bytes)) - - with self.assertRaises(TypeError) as error: - NearImage({"image": "the_encoded_image", "certainty": False}) - check_error_message(self, error, certainty_error_msg(bool)) - - # valid calls - - NearImage( - { - "image": "test_image", - } - ) - - NearImage({"image": "test_image_2", "certainty": 0.7}) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_object = NearImage( - { - "image": "test_image", - } - ) - self.assertEqual(str(near_object), 'nearImage: {image: "test_image"} ') - - near_object = NearImage({"image": "test_image", "certainty": 0.7}) - self.assertEqual(str(near_object), 'nearImage: {image: "test_image" certainty: 0.7} ') - - -class TestWhere(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # test exceptions - content_error_msg = lambda dt: f"Where filter is expected to be type dict but is {dt}" - content_key_error_msg = "Filter is missing required fields `path` or `operands`. Given: " - path_key_error = "Filter is missing required field `operator`. Given: " - dtype_no_value_error_msg = "'value' field is either missing or incorrect: " - dtype_multiple_value_error_msg = "Multiple fields 'value' are not supported: " - operator_error_msg = ( - lambda op: f"Operator {op} is not allowed. Allowed operators are: {', '.join(WHERE_OPERATORS)}" - ) - geo_operator_value_type_mismatch_msg = ( - lambda op, vt: f"Operator {op} requires a value of type valueGeoRange. Given value type: {vt}" - ) - - with self.assertRaises(TypeError) as error: - Where(None) - check_error_message(self, error, content_error_msg(type(None))) - - with self.assertRaises(TypeError) as error: - Where("filter") - check_error_message(self, error, content_error_msg(str)) - - with self.assertRaises(ValueError) as error: - Where({}) - check_startswith_error_message(self, error, content_key_error_msg) - - with self.assertRaises(ValueError) as error: - Where({"path": "some_path"}) - check_startswith_error_message(self, error, path_key_error) - - with self.assertRaises(ValueError) as error: - Where({"path": "some_path", "operator": "Like"}) - check_startswith_error_message(self, error, dtype_no_value_error_msg) - - with self.assertRaises(ValueError) as error: - Where({"path": "some_path", "operator": "Like", "valueBoolean": True, "valueInt": 1}) - check_startswith_error_message(self, error, dtype_multiple_value_error_msg) - - with self.assertRaises(ValueError) as error: - Where({"operands": "some_path"}) - check_startswith_error_message(self, error, path_key_error) - - with self.assertRaises(TypeError) as error: - Where({"operands": "some_path", "operator": "Like"}) - check_error_message(self, error, content_error_msg(str)) - - with self.assertRaises(TypeError) as error: - Where({"operands": ["some_path"], "operator": "Like"}) - check_error_message(self, error, content_error_msg(str)) - - with self.assertRaises(ValueError) as error: - Where({"path": "some_path", "operator": "NotValid"}) - check_error_message(self, error, operator_error_msg("NotValid")) - - with self.assertRaises(ValueError) as error: - Where({"path": "some_path", "operator": "WithinGeoRange", "valueBoolean": True}) - check_error_message( - self, error, geo_operator_value_type_mismatch_msg("WithinGeoRange", "valueBoolean") - ) - - # test valid calls - Where({"path": "hasTheOneRing", "operator": "Equal", "valueBoolean": False}) - Where( - { - "operands": [ - {"path": "hasTheOneRing", "operator": "Equal", "valueBoolean": False}, - {"path": "hasFriend", "operator": "Equal", "valueText": "Samwise Gamgee"}, - ], - "operator": "And", - } - ) - - def test___str__(self) -> None: - """ - Test the `__str__` method. - """ - value_is_not_list_err = ( - lambda v, t: f"Must provide a list when constructing where filter for {t} with {v}" - ) - - test_filter = {"path": ["name"], "operator": "Equal", "valueString": "A"} - result = str(Where(test_filter)) - self.assertEqual('where: {path: ["name"] operator: Equal valueString: "A"} ', result) - - test_filter = { - "operator": "Or", - "operands": [ - {"path": ["name"], "operator": "Equal", "valueString": "Alan Truing"}, - {"path": ["name"], "operator": "Equal", "valueString": "John von Neumann"}, - ], - } - result = str(Where(test_filter)) - self.assertEqual( - 'where: {operator: Or operands: [{path: ["name"] operator: Equal valueString: "Alan Truing"}, {path: ["name"] operator: Equal valueString: "John von Neumann"}]} ', - result, - ) - - # test dataTypes - test_filter = helper_get_test_filter("valueText", "Test") - result = str(Where(test_filter)) - self.assertEqual('where: {path: ["name"] operator: Equal valueText: "Test"} ', result) - - test_filter = helper_get_test_filter("valueString", "Test") - result = str(Where(test_filter)) - self.assertEqual('where: {path: ["name"] operator: Equal valueString: "Test"} ', result) - - test_filter = helper_get_test_filter("valueText", "n\n") - result = str(Where(test_filter)) - self.assertEqual('where: {path: ["name"] operator: Equal valueText: "n "} ', result) - - test_filter = helper_get_test_filter("valueString", 'what is an "airport"?') - result = str(Where(test_filter)) - self.assertEqual( - 'where: {path: ["name"] operator: Equal valueString: "what is an \\"airport\\"?"} ', - result, - ) - - # test_filter = helper_get_test_filter("valueString", "this is an escape sequence \a") - # result = str(Where(test_filter)) - # self.assertEqual( - # 'where: {path: ["name"] operator: Equal valueString: "this is an escape sequence \\u0007"} ', - # result, - # ) - - # test_filter = helper_get_test_filter("valueString", "this is a hex value \u03A9") - # result = str(Where(test_filter)) - # self.assertEqual( - # 'where: {path: ["name"] operator: Equal valueString: "this is a hex value \\u03a9"} ', - # result, - # ) - - test_filter = helper_get_test_filter("valueText", "what is an 'airport'?") - result = str(Where(test_filter)) - self.assertEqual( - 'where: {path: ["name"] operator: Equal valueText: "what is an \'airport\'?"} ', result - ) - - test_filter = helper_get_test_filter("valueInt", 1) - result = str(Where(test_filter)) - self.assertEqual('where: {path: ["name"] operator: Equal valueInt: 1} ', result) - - test_filter = helper_get_test_filter("valueNumber", 1.4) - result = str(Where(test_filter)) - self.assertEqual('where: {path: ["name"] operator: Equal valueNumber: 1.4} ', result) - - test_filter = helper_get_test_filter("valueBoolean", True) - result = str(Where(test_filter)) - self.assertEqual('where: {path: ["name"] operator: Equal valueBoolean: true} ', result) - - test_filter = helper_get_test_filter("valueDate", "test-2021-02-02") - result = str(Where(test_filter)) - self.assertEqual( - 'where: {path: ["name"] operator: Equal valueDate: "test-2021-02-02"} ', result - ) - - geo_range = { - "geoCoordinates": {"latitude": 51.51, "longitude": -0.09}, - "distance": {"max": 2000}, - } - test_filter = helper_get_test_filter("valueGeoRange", geo_range) - result = str(Where(test_filter)) - self.assertEqual( - 'where: {path: ["name"] operator: Equal valueGeoRange: { geoCoordinates: { latitude: 51.51 longitude: -0.09 } distance: { max: 2000 }}} ', - str(result), - ) - - test_filter = { - "path": ["name"], - "operator": "ContainsAny", - "valueTextArray": ["A", "B\n"], - } - result = str(Where(test_filter)) - self.assertEqual( - 'where: {path: ["name"] operator: ContainsAny valueText: ["A","B "]} ', str(result) - ) - - test_filter = { - "path": ["name"], - "operator": "ContainsAll", - "valueStringArray": ["A", '"B"'], - } - result = str(Where(test_filter)) - self.assertEqual( - 'where: {path: ["name"] operator: ContainsAll valueString: ["A","\\"B\\""]} ', - str(result), - ) - - test_filter = {"path": ["name"], "operator": "ContainsAny", "valueIntArray": [1, 2]} - result = str(Where(test_filter)) - self.assertEqual( - 'where: {path: ["name"] operator: ContainsAny valueInt: [1, 2]} ', str(result) - ) - - test_filter = { - "path": ["name"], - "operator": "ContainsAny", - "valueStringArray": "A", - } - with self.assertRaises(TypeError) as error: - str(Where(test_filter)) - check_error_message(self, error, value_is_not_list_err("A", "valueStringArray")) - - test_filter = { - "path": ["name"], - "operator": "ContainsAll", - "valueTextArray": "A", - } - with self.assertRaises(TypeError) as error: - str(Where(test_filter)) - check_error_message(self, error, value_is_not_list_err("A", "valueTextArray")) - - test_filter = { - "path": ["name"], - "operator": "Equal", - "valueWrong": "whatever", - } - with self.assertRaises(ValueError) as error: - str(Where(test_filter)) - assert ( - error.exception.args[0] - == f"'value' field is either missing or incorrect: {test_filter}. Valid values are: {VALUE_TYPES}." - ) - - -class TestAskFilter(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # test exceptions - ## error messages - content_type_msg = lambda dt: f"Ask filter is expected to be type dict but is {dt}" - question_value_msg = 'Mandatory "question" key not present in the "content"!' - question_type_msg = ( - lambda dt: f"'question' key-value is expected to be of type but is {dt}!" - ) - certainty_type_msg = ( - lambda dt: f"'certainty' key-value is expected to be of type but is {dt}!" - ) - properties_type_msg = ( - lambda dt: f"'properties' key-value is expected to be of type (, ) but is {dt}!" - ) - autocorrect_type_msg = ( - lambda dt: f"'autocorrect' key-value is expected to be of type but is {dt}!" - ) - - with self.assertRaises(TypeError) as error: - Ask(None) - check_error_message(self, error, content_type_msg(type(None))) - - with self.assertRaises(ValueError) as error: - Ask({"certainty": 0.1}) - check_error_message(self, error, question_value_msg) - - with self.assertRaises(TypeError) as error: - Ask({"question": ["Who is the president of USA?"]}) - check_error_message(self, error, question_type_msg(list)) - - with self.assertRaises(TypeError) as error: - Ask({"question": "Who is the president of USA?", "certainty": "1.0"}) - check_error_message(self, error, certainty_type_msg(str)) - - with self.assertRaises(TypeError) as error: - Ask({"question": "Who is the president of USA?", "autocorrect": {"True"}}) - check_error_message(self, error, autocorrect_type_msg(set)) - - with self.assertRaises(TypeError) as error: - Ask( - { - "question": "Who is the president of USA?", - "certainty": 0.8, - "properties": ("prop1", "prop2"), - } - ) - check_error_message(self, error, properties_type_msg(tuple)) - - # valid calls - - content = { - "question": "Who is the president of USA?", - } - ask = Ask(content=content) - self.assertEqual(str(ask), f"ask: {{question: \"{content['question']}\"}} ") - - content = { - "question": "Who is the president of USA?", - "certainty": 0.8, - } - ask = Ask(content=content) - self.assertEqual( - str(ask), - ( - f"ask: {{question: \"{content['question']}\"" - f' certainty: {content["certainty"]}}} ' - ), - ) - - content = { - "question": 'Who is the president of "USA"?', - "certainty": 0.8, - } - ask = Ask(content=content) - self.assertEqual( - str(ask), - ( - f'ask: {{question: "Who is the president of \\"USA\\"?"' - f' certainty: {content["certainty"]}}} ' - ), - ) - - content = { - "question": "Who is the president of USA?", - "certainty": 0.8, - "properties": "prop1", - } - ask = Ask(content=content) - self.assertEqual( - str(ask), - ( - f"ask: {{question: \"{content['question']}\"" - f' certainty: {content["certainty"]}' - f' properties: ["prop1"]}} ' - ), - ) - - content = { - "question": "Who is the president of USA?", - "certainty": 0.8, - "properties": ["prop1", "prop2"], - } - ask = Ask(content=content) - self.assertEqual( - str(ask), - ( - f"ask: {{question: \"{content['question']}\"" - f' certainty: {content["certainty"]}' - f' properties: ["prop1", "prop2"]}} ' - ), - ) - - content = { - "question": "Who is the president of USA?", - "certainty": 0.8, - "properties": ["prop1", "prop2"], - "autocorrect": True, - } - ask = Ask(content=content) - self.assertEqual( - str(ask), - ( - f"ask: {{question: \"{content['question']}\"" - f' certainty: {content["certainty"]}' - ' properties: ["prop1", "prop2"] autocorrect: true} ' - ), - ) - - content = {"question": "Who is the president of USA?", "autocorrect": False} - ask = Ask(content=content) - self.assertEqual( - str(ask), (f"ask: {{question: \"{content['question']}\" autocorrect: false}} ") - ) - - content = { - "question": "Who is the president of USA?", - "certainty": 0.8, - "properties": ["prop1", "prop2"], - "autocorrect": True, - "rerank": True, - } - ask = Ask(content=content) - self.assertEqual( - str(ask), - ( - f"ask: {{question: \"{content['question']}\"" - f' certainty: {content["certainty"]}' - ' properties: ["prop1", "prop2"] autocorrect: true' - " rerank: true} " - ), - ) - - content = { - "question": "Who is the president of USA?", - "rerank": False, - } - ask = Ask(content=content) - self.assertEqual( - str(ask), (f"ask: {{question: \"{content['question']}\"" " rerank: false} ") - ) - - -class TestNearVideo(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # invalid calls - - ## error messages - media_type = MediaType.VIDEO - arg_name = media_type.value.capitalize() - content_error_msg = ( - lambda dt: f"Near{arg_name} filter is expected to be type dict but is {dt}" - ) - media_key_error_msg = f'"content" is missing the mandatory key "{media_type.value}"!' - media_value_error_msg = ( - lambda dt: f"'{media_type.value}' key-value is expected to be of type but is {dt}!" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - - with self.assertRaises(TypeError) as error: - NearVideo(123) - check_error_message(self, error, content_error_msg(int)) - - with self.assertRaises(ValueError) as error: - NearVideo({"id": "video_path.avi", "certainty": 456}) - check_error_message(self, error, media_key_error_msg) - - with self.assertRaises(TypeError) as error: - NearVideo({"video": True}) - check_error_message(self, error, media_value_error_msg(bool)) - - with self.assertRaises(TypeError) as error: - NearVideo({"video": b"True"}) - check_error_message(self, error, media_value_error_msg(bytes)) - - with self.assertRaises(TypeError) as error: - NearVideo({"video": "the_encoded_video", "certainty": False}) - check_error_message(self, error, certainty_error_msg(bool)) - - # valid calls - - NearVideo( - { - "video": "test_video", - } - ) - - NearVideo({"video": "test_video_2", "certainty": 0.7}) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_object = NearVideo( - { - "video": "test_video", - } - ) - self.assertEqual(str(near_object), 'nearVideo: {video: "test_video"} ') - - near_object = NearVideo({"video": "test_video", "certainty": 0.7}) - self.assertEqual(str(near_object), 'nearVideo: {video: "test_video" certainty: 0.7} ') - - -class TestNearAudio(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # invalid calls - - ## error messages - media_type = MediaType.AUDIO - arg_name = media_type.value.capitalize() - content_error_msg = ( - lambda dt: f"Near{arg_name} filter is expected to be type dict but is {dt}" - ) - media_key_error_msg = f'"content" is missing the mandatory key "{media_type.value}"!' - media_value_error_msg = ( - lambda dt: f"'{media_type.value}' key-value is expected to be of type but is {dt}!" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - - with self.assertRaises(TypeError) as error: - NearAudio(123) - check_error_message(self, error, content_error_msg(int)) - - with self.assertRaises(ValueError) as error: - NearAudio({"id": "audio_path.wav", "certainty": 456}) - check_error_message(self, error, media_key_error_msg) - - with self.assertRaises(TypeError) as error: - NearAudio({"audio": True}) - check_error_message(self, error, media_value_error_msg(bool)) - - with self.assertRaises(TypeError) as error: - NearAudio({"audio": b"True"}) - check_error_message(self, error, media_value_error_msg(bytes)) - - with self.assertRaises(TypeError) as error: - NearAudio({"audio": "the_encoded_audio", "certainty": False}) - check_error_message(self, error, certainty_error_msg(bool)) - - # valid calls - - NearAudio( - { - "audio": "test_audio", - } - ) - - NearAudio({"audio": "test_audio_2", "certainty": 0.7}) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_object = NearAudio( - { - "audio": "test_audio", - } - ) - self.assertEqual(str(near_object), 'nearAudio: {audio: "test_audio"} ') - - near_object = NearAudio({"audio": "test_audio", "certainty": 0.7}) - self.assertEqual(str(near_object), 'nearAudio: {audio: "test_audio" certainty: 0.7} ') - - -class TestNearDepth(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # invalid calls - - ## error messages - media_type = MediaType.DEPTH - arg_name = media_type.value.capitalize() - content_error_msg = ( - lambda dt: f"Near{arg_name} filter is expected to be type dict but is {dt}" - ) - media_key_error_msg = f'"content" is missing the mandatory key "{media_type.value}"!' - media_value_error_msg = ( - lambda dt: f"'{media_type.value}' key-value is expected to be of type but is {dt}!" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - - with self.assertRaises(TypeError) as error: - NearDepth(123) - check_error_message(self, error, content_error_msg(int)) - - with self.assertRaises(ValueError) as error: - NearDepth({"id": "depth_path.png", "certainty": 456}) - check_error_message(self, error, media_key_error_msg) - - with self.assertRaises(TypeError) as error: - NearDepth({"depth": True}) - check_error_message(self, error, media_value_error_msg(bool)) - - with self.assertRaises(TypeError) as error: - NearDepth({"depth": b"True"}) - check_error_message(self, error, media_value_error_msg(bytes)) - - with self.assertRaises(TypeError) as error: - NearDepth({"depth": "the_encoded_depth", "certainty": False}) - check_error_message(self, error, certainty_error_msg(bool)) - - # valid calls - - NearDepth( - { - "depth": "test_depth", - } - ) - - NearDepth({"depth": "test_depth_2", "certainty": 0.7}) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_object = NearDepth( - { - "depth": "test_depth", - } - ) - self.assertEqual(str(near_object), 'nearDepth: {depth: "test_depth"} ') - - near_object = NearDepth({"depth": "test_depth", "certainty": 0.7}) - self.assertEqual(str(near_object), 'nearDepth: {depth: "test_depth" certainty: 0.7} ') - - -class TestNearThermal(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # invalid calls - - ## error messages - media_type = MediaType.THERMAL - arg_name = media_type.value.capitalize() - content_error_msg = ( - lambda dt: f"Near{arg_name} filter is expected to be type dict but is {dt}" - ) - media_key_error_msg = f'"content" is missing the mandatory key "{media_type.value}"!' - media_value_error_msg = ( - lambda dt: f"'{media_type.value}' key-value is expected to be of type but is {dt}!" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - - with self.assertRaises(TypeError) as error: - NearThermal(123) - check_error_message(self, error, content_error_msg(int)) - - with self.assertRaises(ValueError) as error: - NearThermal({"id": "thermal_path.png", "certainty": 456}) - check_error_message(self, error, media_key_error_msg) - - with self.assertRaises(TypeError) as error: - NearThermal({"thermal": True}) - check_error_message(self, error, media_value_error_msg(bool)) - - with self.assertRaises(TypeError) as error: - NearThermal({"thermal": b"True"}) - check_error_message(self, error, media_value_error_msg(bytes)) - - with self.assertRaises(TypeError) as error: - NearThermal({"thermal": "the_encoded_thermal", "certainty": False}) - check_error_message(self, error, certainty_error_msg(bool)) - - # valid calls - - NearThermal( - { - "thermal": "test_thermal", - } - ) - - NearThermal({"thermal": "test_thermal_2", "certainty": 0.7}) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_object = NearThermal( - { - "thermal": "test_thermal", - } - ) - self.assertEqual(str(near_object), 'nearThermal: {thermal: "test_thermal"} ') - - near_object = NearThermal({"thermal": "test_thermal", "certainty": 0.7}) - self.assertEqual(str(near_object), 'nearThermal: {thermal: "test_thermal" certainty: 0.7} ') - - -class TestNearIMU(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - # invalid calls - - ## error messages - media_type = MediaType.IMU - arg_name = media_type.value.upper() - content_error_msg = ( - lambda dt: f"Near{arg_name} filter is expected to be type dict but is {dt}" - ) - media_key_error_msg = f'"content" is missing the mandatory key "{media_type.value}"!' - media_value_error_msg = ( - lambda dt: f"'{media_type.value}' key-value is expected to be of type but is {dt}!" - ) - certainty_error_msg = lambda dtype: ( - f"'certainty' key-value is expected to be of type but is {dtype}!" - ) - - with self.assertRaises(TypeError) as error: - NearIMU(123) - check_error_message(self, error, content_error_msg(int)) - - with self.assertRaises(ValueError) as error: - NearIMU({"id": "imu_path.txt", "certainty": 456}) - check_error_message(self, error, media_key_error_msg) - - with self.assertRaises(TypeError) as error: - NearIMU({"imu": True}) - check_error_message(self, error, media_value_error_msg(bool)) - - with self.assertRaises(TypeError) as error: - NearIMU({"imu": b"True"}) - check_error_message(self, error, media_value_error_msg(bytes)) - - with self.assertRaises(TypeError) as error: - NearIMU({"imu": "the_encoded_imu", "certainty": False}) - check_error_message(self, error, certainty_error_msg(bool)) - - # valid calls - - NearIMU( - { - "imu": "test_imu", - } - ) - - NearIMU({"imu": "test_imu_2", "certainty": 0.7}) - - def test___str__(self): - """ - Test the `__str__` method. - """ - - near_object = NearIMU( - { - "imu": "test_imu", - } - ) - self.assertEqual(str(near_object), 'nearIMU: {imu: "test_imu"} ') - - near_object = NearIMU({"imu": "test_imu", "certainty": 0.7}) - self.assertEqual(str(near_object), 'nearIMU: {imu: "test_imu" certainty: 0.7} ') diff --git a/test/gql/test_get.py b/test/gql/test_get.py deleted file mode 100644 index bfc9bfcd2..000000000 --- a/test/gql/test_get.py +++ /dev/null @@ -1,883 +0,0 @@ -import unittest -from typing import List, Optional, Callable, Tuple -from unittest.mock import patch, Mock - -import pytest - -from test.util import check_error_message -from weaviate.data.replication import ConsistencyLevel -from weaviate.gql.get import ( - GetBuilder, - BM25, - Hybrid, - LinkTo, - GroupBy, - AdditionalProperties, - HybridFusion, -) - -mock_connection_v117 = Mock() -mock_connection_v117.server_version = "1.17.4" - - -@pytest.mark.parametrize( - "props,expected", - [ - (AdditionalProperties(uuid=True), " _additional{id} "), - ( - AdditionalProperties(uuid=True, vector=True, explainScore=True), - " _additional{id vector explainScore} ", - ), - ( - AdditionalProperties(uuid=False, vector=True, explainScore=True, score=True), - " _additional{vector score explainScore} ", - ), - ], -) -def test_additional_props(props: AdditionalProperties, expected: str): - assert str(props) == expected - - -@pytest.mark.parametrize( - "additional_props,expected", - [ - (AdditionalProperties(uuid=True), " _additional{id} "), - ], -) -def test_getbuilder_with_additional_props(additional_props: AdditionalProperties, expected: str): - query = ( - GetBuilder("TestClass", "name", mock_connection_v117) - .with_additional(additional_props) - .build() - ) - expected_query = "{Get{TestClass{name" + expected + "}}}" - assert ( - "name_" not in expected_query - ) # Check that the prop name is not being concatnated to _additional - assert query == expected_query - - -@pytest.mark.parametrize( - "query,properties,expected", - [ - ( - "query", - ["title", "document", "date"], - 'bm25:{query: "query", properties: ["title","document","date"]}', - ), - ("other query", [], 'bm25:{query: "other query"}'), - ("other query", None, 'bm25:{query: "other query"}'), - ('what is an "airport"', None, 'bm25:{query: "what is an \\"airport\\""}'), - ("what is an 'airport'", None, """bm25:{query: "what is an 'airport'"}"""), - ], -) -def test_bm25(query: str, properties: List[str], expected: str): - bm25 = BM25(query, properties) - assert str(bm25) == expected - - -@pytest.mark.parametrize( - "property_name,in_class,properties,expected", - [ - ( - "property", - "class", - ["title"], - "property{... on class{title}}", - ), - ( - "property", - "class", - ["title", "document", "date"], - "property{... on class{title document date}}", - ), - ], -) -def test_get_references(property_name: str, in_class: str, properties: List[str], expected: str): - ref = LinkTo(property_name, in_class, properties) - assert str(ref) == expected - - -@pytest.mark.parametrize( - "query,vector,alpha,properties,fusion_type,expected", - [ - ( - "query", - [1, 2, 3], - 0.5, - None, - None, - 'hybrid:{query: "query", vector: [1, 2, 3], alpha: 0.5}', - ), - ("query", None, None, None, None, 'hybrid:{query: "query"}'), - ('query "query2"', None, None, None, None, 'hybrid:{query: "query \\"query2\\""}'), - ("query 'query2'", None, None, None, None, """hybrid:{query: "query 'query2'"}"""), - ("query", None, None, ["prop1"], None, 'hybrid:{query: "query", properties: ["prop1"]}'), - ( - "query", - None, - None, - ["prop1", "prop2"], - None, - 'hybrid:{query: "query", properties: ["prop1","prop2"]}', - ), - ( - "query", - None, - None, - None, - HybridFusion.RANKED, - 'hybrid:{query: "query", fusionType: rankedFusion}', - ), - ( - "query", - None, - None, - None, - HybridFusion.RELATIVE_SCORE, - 'hybrid:{query: "query", fusionType: relativeScoreFusion}', - ), - ( - "query", - None, - None, - None, - "relativeScoreFusion", - 'hybrid:{query: "query", fusionType: relativeScoreFusion}', - ), - ], -) -def test_hybrid( - query: str, - vector: Optional[List[float]], - alpha: Optional[float], - properties: Optional[List[str]], - fusion_type: HybridFusion, - expected: str, -): - hybrid = Hybrid(query, alpha, vector, properties, fusion_type) - assert str(hybrid) == expected - - -@pytest.mark.parametrize( - "single_prompt,grouped_task,grouped_properties,expected", - [ - ( - "What is the meaning of life?", - None, - None, - """generate(singleResult:{prompt:"What is the meaning of life?"}){error singleResult} """, - ), - ( - 'What is the meaning of "life"?', - None, - None, - """generate(singleResult:{prompt:"What is the meaning of \\"life\\"?"}){error singleResult} """, - ), - ( - "What is the meaning of 'life'?", - None, - None, - """generate(singleResult:{prompt:"What is the meaning of 'life'?"}){error singleResult} """, - ), - ( - None, - "Explain why these magazines or newspapers are about finance", - None, - """generate(groupedResult:{task:"Explain why these magazines or newspapers are about finance"}){error groupedResult} """, - ), - ( - None, - 'Explain why these magazines or newspapers are about "finance"', - None, - """generate(groupedResult:{task:"Explain why these magazines or newspapers are about \\"finance\\""}){error groupedResult} """, - ), - ( - None, - "Explain why these magazines or newspapers are about 'finance'", - None, - """generate(groupedResult:{task:"Explain why these magazines or newspapers are about 'finance'"}){error groupedResult} """, - ), - ( - "What is the meaning of life?", - "Explain why these magazines or newspapers are about finance", - None, - """generate(singleResult:{prompt:"What is the meaning of life?"}groupedResult:{task:"Explain why these magazines or newspapers are about finance"}){error singleResult groupedResult} """, - ), - ( - None, - "Explain why these magazines or newspapers are about finance", - ["description"], - """generate(groupedResult:{task:"Explain why these magazines or newspapers are about finance",properties:["description"]}){error groupedResult} """, - ), - ( - "What is the meaning of life?", - "Explain why these magazines or newspapers are about finance", - ["title", "description"], - """generate(singleResult:{prompt:"What is the meaning of life?"}groupedResult:{task:"Explain why these magazines or newspapers are about finance",properties:["title","description"]}){error singleResult groupedResult} """, - ), - ], -) -def test_generative( - single_prompt: str, grouped_task: str, grouped_properties: List[str], expected: str -): - query = ( - GetBuilder("Person", "name", mock_connection_v117) - .with_generate(single_prompt, grouped_task, grouped_properties) - .build() - ) - expected_query = "{Get{Person{name _additional {" + expected + "}}}}" - assert query == expected_query - - -@pytest.mark.parametrize("single_prompt,grouped_task", [(123, None), (None, None), (None, 123)]) -def test_generative_type(single_prompt: str, grouped_task: str): - with pytest.raises(TypeError): - GetBuilder("Person", "name", mock_connection_v117).with_generate( - single_prompt, grouped_task - ).build() - - -@pytest.mark.parametrize( - "properties,groups,max_groups,expected", - [ - ( - ["prop1", "prop2"], - 2, - 3, - 'groupBy:{path:["prop1","prop2"], groups:2, objectsPerGroup:3}', - ), - ( - ["prop1"], - 4, - 5, - 'groupBy:{path:["prop1"], groups:4, objectsPerGroup:5}', - ), - ], -) -def test_groupy(properties: List[str], groups: int, max_groups: int, expected: str): - assert str(GroupBy(properties, groups, max_groups)) == expected - - -class TestGetBuilder(unittest.TestCase): - def test___init__(self): - """ - Test the `__init__` method. - """ - - class_name_error_msg = f"class name must be of type str but was {int}" - properties_error_msg = ( - "properties must be of type str, " f"list of str or None but was {int}" - ) - property_error_msg = "All the `properties` must be of type `str` or Reference!" - - # invalid calls - with self.assertRaises(TypeError) as error: - GetBuilder(1, ["a"], None) - check_error_message(self, error, class_name_error_msg) - - with self.assertRaises(TypeError) as error: - GetBuilder("A", 2, None) - check_error_message(self, error, properties_error_msg) - - with self.assertRaises(TypeError) as error: - GetBuilder("A", [True], None) - check_error_message(self, error, property_error_msg) - - # valid calls - GetBuilder("name", "prop", None) - GetBuilder("name", ["prop1", "prop2"], None) - - def test_build_with_limit(self): - """ - Test the `with_limit` method. - """ - - # valid calls - query = GetBuilder("Person", "name", None).with_limit(20).build() - self.assertEqual("{Get{Person(limit: 20 ){name}}}", query) - - # invalid calls - limit_error_msg = "limit cannot be non-positive (limit >=1)." - with self.assertRaises(ValueError) as error: - GetBuilder("A", ["str"], None).with_limit(0) - check_error_message(self, error, limit_error_msg) - - with self.assertRaises(ValueError) as error: - GetBuilder("A", ["str"], None).with_limit(-1) - check_error_message(self, error, limit_error_msg) - - def test_build_with_offset(self): - """ - Test the `with_limit` method. - """ - - # valid calls - query = GetBuilder("Person", "name", None).with_offset(20).build() - self.assertEqual("{Get{Person(offset: 20 ){name}}}", query) - - query = GetBuilder("Person", "name", None).with_offset(0).build() - self.assertEqual("{Get{Person(offset: 0 ){name}}}", query) - - # invalid calls - limit_error_msg = "offset cannot be non-positive (offset >=0)." - with self.assertRaises(ValueError) as error: - GetBuilder("A", ["str"], None).with_offset(-1) - check_error_message(self, error, limit_error_msg) - - def test_build_with_consistency_level(self): - """ - Test the `with_consistency_level` method - """ - - query = ( - GetBuilder("Person", "name", None).with_consistency_level(ConsistencyLevel.ONE).build() - ) - self.assertEqual("{Get{Person(consistencyLevel: ONE ){name}}}", query) - - query = ( - GetBuilder("Person", "name", None) - .with_consistency_level(ConsistencyLevel.QUORUM) - .build() - ) - self.assertEqual("{Get{Person(consistencyLevel: QUORUM ){name}}}", query) - - query = ( - GetBuilder("Person", "name", None).with_consistency_level(ConsistencyLevel.ALL).build() - ) - self.assertEqual("{Get{Person(consistencyLevel: ALL ){name}}}", query) - - def test_build_with_where(self): - """ - Test the ` with_where` method. - """ - - filter_name = {"path": ["name"], "operator": "Equal", "valueString": "A"} - query = GetBuilder("Person", "name", None).with_where(filter_name).build() - self.assertEqual( - '{Get{Person(where: {path: ["name"] operator: Equal valueString: "A"} ){name}}}', query - ) - - def test_build_with_near_text(self): - """ - Test the `with_near_text` method. - """ - - near_text = { - "concepts": "computer", - "moveTo": {"concepts": ["science"], "force": 0.5}, - "autocorrect": True, - } - - # valid calls - query = GetBuilder("Person", "name", None).with_near_text(near_text).build() - self.assertEqual( - '{Get{Person(nearText: {concepts: ["computer"] moveTo: {force: 0.5 concepts: ["science"]} autocorrect: true} ){name}}}', - query, - ) - - # invalid calls - near_error_msg = ( - "Cannot use multiple 'near' filters, or a 'near' filter along with a 'ask' filter!" - ) - - near_vector = {"vector": [1, 2, 3, 4, 5, 6, 7, 8, 9], "certainty": 0.55} - with self.assertRaises(AttributeError) as error: - GetBuilder("Person", "name", None).with_near_vector(near_vector).with_near_text( - near_text - ) - check_error_message(self, error, near_error_msg) - - def test_build_near_vector(self): - """ - Test the `with_near_vector` method. - """ - - near_vector = {"vector": [1, 2, 3, 4, 5, 6, 7, 8, 9], "certainty": 0.55} - - # valid calls - mock_connection = Mock() - mock_connection.server_version = "1.14.0" - query = GetBuilder("Person", "name", mock_connection).with_near_vector(near_vector).build() - self.assertEqual( - "{Get{Person(nearVector: {vector: [1, 2, 3, 4, 5, 6, 7, 8, 9] certainty: 0.55} ){name}}}", - query, - ) - - # invalid calls - near_error_msg = ( - "Cannot use multiple 'near' filters, or a 'near' filter along with a 'ask' filter!" - ) - - near_object = {"id": "test_id", "certainty": 0.55} - with self.assertRaises(AttributeError) as error: - GetBuilder("Person", "name", mock_connection).with_near_object( - near_object - ).with_near_vector(near_vector) - check_error_message(self, error, near_error_msg) - - def test_build_near_object(self): - """ - Test the `with_near_object` method. - """ - - near_object = {"id": "test_id", "certainty": 0.55} - - # valid calls - mock_connection = Mock() - mock_connection.server_version = "1.14.0" - query = GetBuilder("Person", "name", mock_connection).with_near_object(near_object).build() - self.assertEqual('{Get{Person(nearObject: {id: "test_id" certainty: 0.55} ){name}}}', query) - - # invalid calls - near_error_msg = ( - "Cannot use multiple 'near' filters, or a 'near' filter along with a 'ask' filter!" - ) - - near_text = { - "concepts": "computer", - "moveTo": {"concepts": ["science"], "force": 0.5}, - } - with self.assertRaises(AttributeError) as error: - GetBuilder("Person", "name", mock_connection).with_near_text( - near_text - ).with_near_object(near_object) - check_error_message(self, error, near_error_msg) - - @patch("weaviate.gql.get.image_encoder_b64", side_effect=lambda x: "test_call") - def test_build_near_image(self, mock_image_encoder_b64: Mock): - """ - Test the `with_near_object` method. - """ - - near_image = {"image": "test_image", "certainty": 0.55} - - # valid calls - ## encode False - query = GetBuilder("Person", "name", None).with_near_image(near_image, encode=False).build() - self.assertEqual( - '{Get{Person(nearImage: {image: "test_image" certainty: 0.55} ){name}}}', query - ) - mock_image_encoder_b64.assert_not_called() - - ## encode True - query = GetBuilder("Person", "name", None).with_near_image(near_image, encode=True).build() - self.assertEqual( - '{Get{Person(nearImage: {image: "test_call" certainty: 0.55} ){name}}}', query - ) - mock_image_encoder_b64.assert_called() - - # invalid calls - near_error_msg = ( - "Cannot use multiple 'near' filters, or a 'near' filter along with a 'ask' filter!" - ) - - near_text = { - "concepts": "computer", - "moveTo": {"concepts": ["science"], "force": 0.5}, - } - with self.assertRaises(AttributeError) as error: - GetBuilder("Person", "name", None).with_near_text(near_text).with_near_image(near_image) - check_error_message(self, error, near_error_msg) - - test_build_near_media_param_list: List[ - Tuple[str, str, str, Callable[[GetBuilder, str, str, bool], GetBuilder]] - ] = [ - ( - "audio", - "test_audio", - "nearAudio", - lambda b, k, v, e: b.with_near_audio({k: v, "certainty": 0.55}, encode=e), - ), - ( - "video", - "test_video", - "nearVideo", - lambda b, k, v, e: b.with_near_video({k: v, "certainty": 0.55}, encode=e), - ), - ( - "depth", - "test_depth", - "nearDepth", - lambda b, k, v, e: b.with_near_depth({k: v, "certainty": 0.55}, encode=e), - ), - ( - "thermal", - "test_thermal", - "nearThermal", - lambda b, k, v, e: b.with_near_thermal({k: v, "certainty": 0.55}, encode=e), - ), - ( - "imu", - "test_imu", - "nearIMU", - lambda b, k, v, e: b.with_near_imu({k: v, "certainty": 0.55}, encode=e), - ), - ] - - def test_build_near_media(self): - """ - Test the `with_near_` method. - """ - for key, value, type_, fn in self.test_build_near_media_param_list: - with self.subTest(key=key, value=value): - with patch( - "weaviate.gql.get.file_encoder_b64", side_effect=lambda x: "test_call" - ) as mocked: - # valid calls - ## encode False - query = fn(GetBuilder("Person", "name", None), key, value, False).build() - self.assertEqual( - f'{{Get{{Person({type_}: {{{key}: "{value}" certainty: 0.55}} ){{name}}}}}}', - query, - ) - mocked.assert_not_called() - - ## encode True - query = fn(GetBuilder("Person", "name", None), key, value, True).build() - self.assertEqual( - f'{{Get{{Person({type_}: {{{key}: "test_call" certainty: 0.55}} ){{name}}}}}}', - query, - ) - mocked.assert_called() - - # invalid calls - near_error_msg = "Cannot use multiple 'near' filters, or a 'near' filter along with a 'ask' filter!" - - near_text = { - "concepts": "computer", - "moveTo": {"concepts": ["science"], "force": 0.5}, - } - with self.assertRaises(AttributeError) as error: - fn( - GetBuilder("Person", "name", None).with_near_text(near_text), - key, - value, - True, - ) - check_error_message(self, error, near_error_msg) - - def test_build_ask(self): - """ - Test the `with_ask` method. - """ - - ask = { - "question": "What is k8s?", - "certainty": 0.55, - "autocorrect": False, - } - - # valid calls - query = GetBuilder("Person", "name", None).with_ask(ask).build() - self.assertEqual( - '{Get{Person(ask: {question: "What is k8s?" certainty: 0.55 autocorrect: false} ){name}}}', - query, - ) - - # invalid calls - near_error_msg = ( - "Cannot use multiple 'near' filters, or a 'near' filter along with a 'ask' filter!" - ) - - near_text = { - "concepts": "computer", - "moveTo": {"concepts": ["science"], "force": 0.5}, - } - with self.assertRaises(AttributeError) as error: - GetBuilder("Person", "name", None).with_near_text(near_text).with_ask(ask) - check_error_message(self, error, near_error_msg) - - def test_build_with_additional(self): - """ - Test the `with_additional` method. - """ - - # valid calls - ## `str` as argument - query = GetBuilder("Person", "name", None).with_additional("id").build() - self.assertEqual("{Get{Person{name _additional {id }}}}", query) - - ## list of `str` as argument - query = ( - GetBuilder("Person", "name", None).with_additional(["id", "certainty", "test"]).build() - ) - self.assertEqual("{Get{Person{name _additional {certainty id test }}}}", query) - - ## dict with value `str` as argument - query = GetBuilder("Person", "name", None).with_additional({"classification": "id"}).build() - self.assertEqual("{Get{Person{name _additional {classification {id } }}}}", query) - - ## dict with value list of `str` as argument - query = ( - GetBuilder("Person", "name", None) - .with_additional({"classification": ["basedOn", "classifiedFields", "completed", "id"]}) - .build() - ) - self.assertEqual( - "{Get{Person{name _additional {classification {basedOn classifiedFields completed id } }}}}", - query, - ) - - ## dict with value list of `tuple` as argument - clause = {"token": ["entity", "word"]} - settings = {"test1": 1, "test3": [True], "test2": 10.0} - query = GetBuilder("Person", "name", None).with_additional((clause, settings)).build() - self.assertEqual( - "{Get{Person{name _additional {token(test1: 1 test2: 10.0 test3: [true] ) {entity word } }}}}", - query, - ) - - ## dict with value list of `tuple` as argument - clause = {"token": "certainty"} - settings = {"test1": ["TEST"]} - query = GetBuilder("Person", "name", None).with_additional((clause, settings)).build() - self.assertEqual( - '{Get{Person{name _additional {token(test1: ["TEST"] ) {certainty } }}}}', query - ) - - ## multiple calls - clause = {"token": "certainty"} - settings = {"test1": ["TEST"]} - query = ( - GetBuilder("Person", "name", None) - .with_additional("test") - .with_additional(["id", "certainty"]) - .with_additional({"classification": ["completed", "id"]}) - .with_additional((clause, settings)) - .build() - ) - self.assertEqual( - '{Get{Person{name _additional {certainty id test classification {completed id } token(test1: ["TEST"] ) {certainty } }}}}', - query, - ) - - ## multiple calls - query = ( - GetBuilder("Person", None, None) - .with_additional("test") - .with_additional(["id", "certainty"]) - .with_additional({"classification": ["completed", "id"]}) - .with_additional("id") - .with_additional("test") - .build() - ) - self.assertEqual( - "{Get{Person{ _additional {certainty id test classification {completed id } }}}}", query - ) - - # invalid calls - # error messages - prop_type_msg = lambda dt: ( - "The 'properties' argument must be either of type `str`, `list`, `dict` or `tuple`! " - f"Given: {dt}" - ) - prop_list_msg = "If type of 'properties' is `list` then all items must be of type `str`!" - prop_dict_key_msg = "If type of 'properties' is `dict` then all keys must be of type `str`!" - prop_dict_value_msg = lambda dt: ( - "If type of 'properties' is `dict` then all the values must be either of type " - f"`str` or `list` of `str`! Given: {dt}!" - ) - prop_dict_value_len = ( - "If type of 'properties' is `dict` and a value is of type `list` then at least" - " one element should be present!" - ) - prop_dict_value_item_msg = ( - "If type of 'properties' is `dict` and a value is of type `list` then all " - "items must be of type `str`!" - ) - prop_tuple_len_msg = ( - "If type of 'properties' is `tuple` then it should have length 2: " - "(clause: , settings: )" - ) - prop_tuple_type_msg = ( - "If type of 'properties' is `tuple` then it should have this data type: " - "(, )" - ) - prop_tuple_clause_len_msg = lambda clause: ( - "If type of 'properties' is `tuple` then the 'clause' (first element) should " - f"have only one key. Given: {len(clause)}" - ) - prop_tuple_settings_len_msg = lambda settings: ( - "If type of 'properties' is `tuple` then the 'settings' (second element) should " - f"have at least one key. Given: {len(settings)}" - ) - prop_tuple_clause_key_type_msg = ( - "If type of 'properties' is `tuple` then first element's key should be of type " - "`str`!" - ) - prop_tuple_settings_keys_type_msg = ( - "If type of 'properties' is `tuple` then the second elements () should " - "have all the keys of type `str`!" - ) - prop_tuple_clause_value_type_msg = lambda dt: ( - "If type of 'properties' is `tuple` then first element's dict values must be " - f"either of type `str` or `list` of `str`! Given: {dt}!" - ) - prop_tuple_clause_value_len_msg = ( - "If type of 'properties' is `tuple` and first element's dict value is of type " - "`list` then at least one element should be present!" - ) - prop_tuple_clause_values_items_type_msg = ( - "If type of 'properties' is `tuple` and first element's dict value is of type " - " `list` then all items must be of type `str`!" - ) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional(123) - check_error_message(self, error, prop_type_msg(int)) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional([123]) - check_error_message(self, error, prop_list_msg) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional({123: "Test"}) - check_error_message(self, error, prop_dict_key_msg) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional({"test": True}) - check_error_message(self, error, prop_dict_value_msg(bool)) - - with self.assertRaises(ValueError) as error: - GetBuilder("Person", "name", None).with_additional({"test": []}) - check_error_message(self, error, prop_dict_value_len) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional({"test": [True]}) - check_error_message(self, error, prop_dict_value_item_msg) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional({"test": [True]}) - check_error_message(self, error, prop_dict_value_item_msg) - - with self.assertRaises(ValueError) as error: - GetBuilder("Person", "name", None).with_additional((1,)) - check_error_message(self, error, prop_tuple_len_msg) - - with self.assertRaises(ValueError) as error: - GetBuilder("Person", "name", None).with_additional((1, 2, 3)) - check_error_message(self, error, prop_tuple_len_msg) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional(({1: "1"}, ["test"])) - check_error_message(self, error, prop_tuple_type_msg) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional(([{1: "1"}], ["test"])) - check_error_message(self, error, prop_tuple_type_msg) - - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional((["test"], {1: "1"})) - check_error_message(self, error, prop_tuple_type_msg) - - clause = {"test1": 1, "test2": 2} - with self.assertRaises(ValueError) as error: - GetBuilder("Person", "name", None).with_additional((clause, {"1": "1"})) - check_error_message(self, error, prop_tuple_clause_len_msg(clause)) - - clause = {"test1": "1"} - settings = {} - with self.assertRaises(ValueError) as error: - GetBuilder("Person", "name", None).with_additional((clause, settings)) - check_error_message(self, error, prop_tuple_settings_len_msg(settings)) - - clause = {1: "1"} - settings = {"test": 1} - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional((clause, settings)) - check_error_message(self, error, prop_tuple_clause_key_type_msg) - - clause = {"test": "1"} - settings = {"test": 1, 2: 2} - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional((clause, settings)) - check_error_message(self, error, prop_tuple_settings_keys_type_msg) - - clause = {"test": "1"} - settings = {2: 2} - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional((clause, settings)) - check_error_message(self, error, prop_tuple_settings_keys_type_msg) - - clause = {"test": True} - settings = {"test": 2} - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional((clause, settings)) - check_error_message(self, error, prop_tuple_clause_value_type_msg(bool)) - - clause = {"test": []} - settings = {"test": 2} - with self.assertRaises(ValueError) as error: - GetBuilder("Person", "name", None).with_additional((clause, settings)) - check_error_message(self, error, prop_tuple_clause_value_len_msg) - - clause = {"test": ["1", "2", 3]} - settings = {"test": 2} - with self.assertRaises(TypeError) as error: - GetBuilder("Person", "name", None).with_additional((clause, settings)) - check_error_message(self, error, prop_tuple_clause_values_items_type_msg) - - def test_build(self): - """ - Test the `build` method. (without filters) - """ - - error_message = ( - "No 'properties' or 'additional properties' specified to be returned. " - "At least one should be included." - ) - - with self.assertRaises(AttributeError) as error: - query = GetBuilder("Group", [], None).build() - check_error_message(self, error, error_message) - - query = GetBuilder("Group", "name", None).build() - self.assertEqual("{Get{Group{name}}}", query) - - query = GetBuilder("Group", ["name", "uuid"], None).build() - self.assertEqual("{Get{Group{name uuid}}}", query) - - query = GetBuilder("Group", None, None).with_additional("distance").build() - self.assertEqual("{Get{Group{ _additional {distance }}}}", query) - - near_text = { - "concepts": ["computer"], - "moveTo": {"concepts": "science", "force": 0.1}, - "moveAwayFrom": {"concepts": ["airplane"], "force": 0.2}, - "certainty": 0.3, - } - or_filter = { - "operator": "Or", - "operands": [ - { - "path": ["name"], - "operator": "Equal", - "valueString": "Alan Turing", - }, - {"path": ["name"], "operator": "Equal", "valueString": "John von Neumann"}, - ], - } - query = ( - GetBuilder("Person", ["name", "uuid"], None) - .with_near_text(near_text) - .with_where(or_filter) - .with_limit(2) - .with_offset(10) - .build() - ) - self.assertEqual( - '{Get{Person(where: {operator: Or operands: [{path: ["name"] operator: Equal valueString: "Alan Turing"}, {path: ["name"] operator: Equal valueString: "John von Neumann"}]} limit: 2 offset: 10 nearText: {concepts: ["computer"] certainty: 0.3 moveTo: {force: 0.1 concepts: ["science"]} moveAwayFrom: {force: 0.2 concepts: ["airplane"]}} ){name uuid}}}', - query, - ) - - def test_capitalized_class_name(self): - """ - Test the capitalized class_name. - """ - - get = GetBuilder("Test", ["prop"], None) - self.assertEqual(get._class_name, "Test") - - get = GetBuilder("test", ["prop"], None) - self.assertEqual(get._class_name, "Test") diff --git a/test/gql/test_query.py b/test/gql/test_query.py deleted file mode 100644 index f009332a8..000000000 --- a/test/gql/test_query.py +++ /dev/null @@ -1,66 +0,0 @@ -import unittest -from unittest.mock import Mock - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.gql.query import Query - - -class TestQuery(unittest.TestCase): - def test_get(self): - """ - Test the `get` attribute. - """ - - query = Query(Mock()) - - gql = query.get("Group", ["name", "uuid"]).build() - self.assertEqual("{Get{Group{name uuid}}}", gql) - - def test_aggregate(self): - """ - Test the `aggregate` attribute. - """ - - query = Query(Mock()) - - gql = query.aggregate("Group").build() - self.assertEqual("{Aggregate{Group{}}}", gql) - - def test_raw(self): - """ - Test the `raw` method. - """ - - # valid calls - connection_mock = mock_connection_func("post", return_json={}) - query = Query(connection_mock) - - gql_query = "{Get {Group {name Members {... on Person {name}}}}}" - query.raw(gql_query) - - connection_mock.post.assert_called_with( - path="/graphql", weaviate_object={"query": gql_query} - ) - - # invalid calls - - type_error_message = "Query is expected to be a string" - requests_error_message = "Query not executed." - query_error_message = "GQL query failed" - - with self.assertRaises(TypeError) as error: - query.raw(["TestQuery"]) - check_error_message(self, error, type_error_message) - - query = Query(mock_connection_func("post", side_effect=RequestsConnectionError("Test!"))) - with self.assertRaises(RequestsConnectionError) as error: - query.raw("TestQuery") - check_error_message(self, error, requests_error_message) - - query = Query(mock_connection_func("post", status_code=404)) - with self.assertRaises(UnexpectedStatusCodeException) as error: - query.raw("TestQuery") - check_startswith_error_message(self, error, query_error_message) diff --git a/test/schema/__init__.py b/test/schema/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/schema/properties/__init__.py b/test/schema/properties/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/schema/properties/test_properties.py b/test/schema/properties/test_properties.py deleted file mode 100644 index f742b0392..000000000 --- a/test/schema/properties/test_properties.py +++ /dev/null @@ -1,63 +0,0 @@ -import unittest -from unittest.mock import Mock - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.exceptions import ( - UnexpectedStatusCodeException, -) -from weaviate.schema.properties import Property - - -class TestCRUDProperty(unittest.TestCase): - def test_create(self): - """ - Test `create` method. - """ - - prop = Property(Mock()) - - # invalid calls - error_message = "Class name must be of type str but is " - requests_error_message = "Property was created properly." - - with self.assertRaises(TypeError) as error: - prop.create(35, {}) - check_error_message(self, error, error_message + str(int)) - - prop = Property(mock_connection_func("post", side_effect=RequestsConnectionError("Test!"))) - with self.assertRaises(RequestsConnectionError) as error: - prop.create("Class", {"name": "test", "dataType": ["test_type"]}) - check_error_message(self, error, requests_error_message) - - prop = Property(mock_connection_func("post", status_code=404)) - with self.assertRaises(UnexpectedStatusCodeException) as error: - prop.create("Class", {"name": "test", "dataType": ["test_type"]}) - check_startswith_error_message(self, error, "Add property to class") - - # valid calls - connection_mock = mock_connection_func("post") # Mock calling weaviate - prop = Property(connection_mock) - - test_prop = { - "dataType": ["string"], - "description": "my Property", - "moduleConfig": {"text2vec-contextionary": {"vectorizePropertyName": True}}, - "name": "superProp", - "indexInverted": True, - } - - prop.create("TestThing", test_prop) - - connection_mock.post.assert_called_with( - path="/schema/TestThing/properties", - weaviate_object=test_prop, - ) - - prop.create("testThing", test_prop) - - connection_mock.post.assert_called_with( - path="/schema/TestThing/properties", - weaviate_object=test_prop, - ) diff --git a/test/schema/schema_company.json b/test/schema/schema_company.json deleted file mode 100644 index f82099ddb..000000000 --- a/test/schema/schema_company.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "classes": [ - { - "class": "Company", - "description": "A business that acts in the market", - "properties": [ - { - "name": "name", - "description": "The name under which the company is known", - "dataType": ["text"] - }, - { - "name": "legalBody", - "description": "The legal body under which the company maintains its business", - "dataType": ["text"] - }, - { - "name": "hasEmployee", - "description": "The employees of the company", - "dataType": ["Employee"] - } - ] - }, - { - "class": "Employee", - "description": "An employee of the company", - "properties": [ - { - "name": "name", - "description": "The name of the employee", - "dataType": ["text"] - }, - { - "name": "job", - "description": "the job description of the employee", - "dataType": ["text"] - }, - { - "name": "yearsInTheCompany", - "description": "The number of years this employee has worked in the company", - "dataType": ["int"] - } - ] - } - ] -} diff --git a/test/schema/tenants.json b/test/schema/tenants.json deleted file mode 100644 index 6594674eb..000000000 --- a/test/schema/tenants.json +++ /dev/null @@ -1,8 +0,0 @@ -[ - { - "name": "Tenant1" - }, - { - "name": "Tenant2" - } -] \ No newline at end of file diff --git a/test/schema/test_schema.py b/test/schema/test_schema.py deleted file mode 100644 index aaeee9210..000000000 --- a/test/schema/test_schema.py +++ /dev/null @@ -1,675 +0,0 @@ -import os -import unittest -from copy import deepcopy -from unittest.mock import patch, Mock - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message, check_startswith_error_message -from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.schema import Schema -from weaviate.util import _capitalize_first_letter - -company_test_schema = { - "classes": [ - { - "class": "Company", - "description": "A business that acts in the market", - "properties": [ - { - "name": "name", - "description": "The name under which the company is known", - "dataType": ["text"], - }, - { - "name": "legalBody", - "description": "The legal body under which the company maintains its business", - "dataType": ["text"], - }, - { - "name": "hasEmployee", - "description": "The employees of the company", - "dataType": ["Employee"], - }, - ], - }, - { - "class": "Employee", - "description": "An employee of the company", - "properties": [ - { - "name": "name", - "description": "The name of the employee", - "dataType": ["text"], - }, - { - "name": "job", - "description": "the job description of the employee", - "dataType": ["text"], - }, - { - "name": "yearsInTheCompany", - "description": "The number of years this employee has worked in the company", - "dataType": ["int"], - }, - ], - }, - ] -} - -# A test schema as it was returned from a real weaviate instance -persons_return_test_schema = { - "classes": [ - { - "class": "Person", - "description": "A person such as humans or personality known through culture", - "properties": [ - {"dataType": ["text"], "description": "The name of this person", "name": "name"} - ], - }, - { - "class": "Group", - "description": "A set of persons who are associated with each other over some common properties", - "properties": [ - { - "dataType": ["text"], - "description": "The name under which this group is known", - "name": "name", - }, - { - "dataType": ["Person"], - "description": "The persons that are part of this group", - "name": "members", - }, - ], - }, - ], -} - -schema_company_local = { # NOTE: should be the same as file schema_company.json - "classes": [ - { - "class": "Company", - "description": "A business that acts in the market", - "properties": [ - { - "name": "name", - "description": "The name under which the company is known", - "dataType": ["text"], - }, - { - "name": "legalBody", - "description": "The legal body under which the company maintains its business", - "dataType": ["text"], - }, - { - "name": "hasEmployee", - "description": "The employees of the company", - "dataType": ["Employee"], - }, - ], - }, - { - "class": "Employee", - "description": "An employee of the company", - "properties": [ - {"name": "name", "description": "The name of the employee", "dataType": ["text"]}, - { - "name": "job", - "description": "the job description of the employee", - "dataType": ["text"], - }, - { - "name": "yearsInTheCompany", - "description": "The number of years this employee has worked in the company", - "dataType": ["int"], - }, - ], - }, - ] -} - - -class TestSchema(unittest.TestCase): - def test_create(self): - """ - Test the `create` method. - """ - - schema = Schema(Mock()) - - # mock function calls - mock_primitive = Mock() - mock_complex = Mock() - schema._create_classes_with_primitives = mock_primitive - schema._create_complex_properties_from_classes = mock_complex - - schema.create("test/schema/schema_company.json") # with read from file - - mock_primitive.assert_called_with(schema_company_local["classes"]) - mock_complex.assert_called_with(schema_company_local["classes"]) - - def test_create_class(self): - """ - Test the `create_class` method. - """ - - schema = Schema(Mock()) - - # mock function calls - mock_primitive = Mock() - mock_complex = Mock() - schema._create_class_with_primitives = mock_primitive - schema._create_complex_properties_from_class = mock_complex - - schema.create_class(company_test_schema["classes"][0]) - - mock_primitive.assert_called_with(company_test_schema["classes"][0]) - mock_complex.assert_called_with(company_test_schema["classes"][0]) - - @patch("weaviate.schema.crud_schema.Schema.get") - def test_update_config(self, mock_schema): - """ - Test the `update_config` method. - """ - - # invalid calls - requests_error_message = "Class schema configuration could not be updated." - unexpected_error_msg = "Update class schema configuration" - - mock_schema.return_value = { - "class": "Test", - "vectorIndexConfig": {"test1": "Test1", "test2": 2}, - } - mock_conn = mock_connection_func("put", side_effect=RequestsConnectionError("Test!")) - schema = Schema(mock_conn) - with self.assertRaises(RequestsConnectionError) as error: - schema.update_config("Test", {"vectorIndexConfig": {"test2": "Test2"}}) - check_error_message(self, error, requests_error_message) - mock_conn.put.assert_called_with( - path="/schema/Test", - weaviate_object={ - "class": "Test", - "vectorIndexConfig": {"test1": "Test1", "test2": "Test2"}, - }, - ) - - mock_schema.return_value = { - "class": "Test", - "vectorIndexConfig": {"test1": "Test1", "test2": 2}, - } - mock_conn = mock_connection_func("put", status_code=404) - schema = Schema(mock_conn) - with self.assertRaises(UnexpectedStatusCodeException) as error: - schema.update_config("Test", {"vectorIndexConfig": {"test3": True}}) - check_startswith_error_message(self, error, unexpected_error_msg) - mock_conn.put.assert_called_with( - path="/schema/Test", - weaviate_object={ - "class": "Test", - "vectorIndexConfig": {"test1": "Test1", "test2": 2, "test3": True}, - }, - ) - - # valid calls - mock_schema.return_value = { - "class": "Test", - "vectorIndexConfig": {"test1": "Test1", "test2": 2}, - } - mock_conn = mock_connection_func("put") - schema = Schema(mock_conn) - schema.update_config("Test", {}) - mock_conn.put.assert_called_with( - path="/schema/Test", - weaviate_object={"class": "Test", "vectorIndexConfig": {"test1": "Test1", "test2": 2}}, - ) - - # with uncapitalized class_name - mock_schema.return_value = { - "class": "Test", - "vectorIndexConfig": {"test1": "Test1", "test2": 2}, - } - mock_conn = mock_connection_func("put") - schema = Schema(mock_conn) - schema.update_config("test", {}) - mock_conn.put.assert_called_with( - path="/schema/Test", - weaviate_object={"class": "Test", "vectorIndexConfig": {"test1": "Test1", "test2": 2}}, - ) - - def test_get(self): - """ - Test the `get` method. - """ - - # invalid calls - requests_error_message = "Schema could not be retrieved." - unexpected_error_msg = "Get schema" - type_error_msg = lambda dt: f"'class_name' argument must be of type `str`! Given type: {dt}" - - mock_conn = mock_connection_func("get", side_effect=RequestsConnectionError("Test!")) - schema = Schema(mock_conn) - with self.assertRaises(RequestsConnectionError) as error: - schema.get() - check_error_message(self, error, requests_error_message) - - mock_conn = mock_connection_func("get", status_code=404) - schema = Schema(mock_conn) - with self.assertRaises(UnexpectedStatusCodeException) as error: - schema.get() - check_startswith_error_message(self, error, unexpected_error_msg) - - connection_mock_file = mock_connection_func( - "get", status_code=200, return_json={"Test": "OK!"} - ) - schema = Schema(connection_mock_file) - with self.assertRaises(TypeError) as error: - schema.get(1234) - check_error_message(self, error, type_error_msg(int)) - - # valid calls - - self.assertEqual(schema.get(), {"Test": "OK!"}) - connection_mock_file.get.assert_called_with( - path="/schema", - ) - - self.assertEqual(schema.get("Artist"), {"Test": "OK!"}) - connection_mock_file.get.assert_called_with(path="/schema/Artist") - - # with uncapitalized class_name - self.assertEqual(schema.get("artist"), {"Test": "OK!"}) - connection_mock_file.get.assert_called_with(path="/schema/Artist") - - def test_contains(self): - """ - Test the `contains` method. - """ - - # If a schema is present it should return true otherwise false - # 1. test schema is present: - - schema = Schema(mock_connection_func("get", return_json=persons_return_test_schema)) - self.assertTrue(schema.contains()) - - # 2. test no schema is present: - - schema = Schema(mock_connection_func("get", return_json={"classes": []})) - self.assertFalse(schema.contains()) - - # 3. test with 'schema' argument - ## Test weaviate.schema.contains specific schema. - - schema = Schema(mock_connection_func("get", return_json=persons_return_test_schema)) - self.assertFalse(schema.contains(company_test_schema)) - subset_schema = { - "classes": [ - { - "class": "Person", - "description": "", - "properties": [{"dataType": ["text"], "description": "", "name": "name"}], - } - ] - } - self.assertTrue(schema.contains(subset_schema)) - - ## Test weaviate.schema.contains schema from file. - - schema = Schema(mock_connection_func("get", return_json=persons_return_test_schema)) - schema_json_file = os.path.join(os.path.dirname(__file__), "schema_company.json") - self.assertFalse(schema.contains(schema_json_file)) - - schema = Schema(mock_connection_func("get", return_json=company_test_schema)) - self.assertTrue(schema.contains(schema_json_file)) - - def test_delete_class_input(self): - """ - Test the 'delete_class` method. - """ - - schema = Schema(Mock()) - - # invalid calls - type_error_message = lambda t: f"Class name was {t} instead of str" - requests_error_message = "Deletion of class." - - with self.assertRaises(TypeError) as error: - schema.delete_class(1) - check_error_message(self, error, type_error_message(int)) - - schema = Schema( - mock_connection_func("delete", side_effect=RequestsConnectionError("Test!")) - ) - with self.assertRaises(RequestsConnectionError) as error: - schema.delete_class("uuid") - check_error_message(self, error, requests_error_message) - - schema = Schema(mock_connection_func("delete", status_code=404)) - with self.assertRaises(UnexpectedStatusCodeException) as error: - schema.delete_class("uuid") - check_startswith_error_message(self, error, "Delete class from schema") - - # valid calls - mock_conn = mock_connection_func("delete", status_code=200) - schema = Schema(mock_conn) - schema.delete_class("Test") - mock_conn.delete.assert_called_with(path="/schema/Test") - - # with uncapitalized class_name - mock_conn = mock_connection_func("delete", status_code=200) - schema = Schema(mock_conn) - schema.delete_class("test") - mock_conn.delete.assert_called_with(path="/schema/Test") - - def test_delete_everything(self): - """ - Test the `delete_all` method. - """ - - mock_connection = mock_connection_func("get", return_json=company_test_schema) - mock_connection = mock_connection_func("delete", connection_mock=mock_connection) - schema = Schema(mock_connection) - - schema.delete_all() - self.assertEqual(mock_connection.get.call_count, 1) - self.assertEqual(mock_connection.delete.call_count, 2) - - def test__create_complex_properties_from_classes(self): - """ - Test the `_create_complex_properties_from_classes` method. - """ - - schema = Schema(Mock()) - - mock_complex = Mock() - schema._create_complex_properties_from_class = mock_complex - - schema._create_complex_properties_from_classes(list("Test!")) - self.assertEqual(mock_complex.call_count, 5) - - def test__create_complex_properties_from_class(self): - """ - Test the `_create_complex_properties_from_class` method. - """ - - # valid calls - - def helper_test(nr_calls=1): - mock_rest = mock_connection_func("post") - schema = Schema(mock_rest) - schema._create_complex_properties_from_class(properties) - self.assertEqual(mock_rest.post.call_count, nr_calls) - properties_copy = deepcopy(properties["properties"]) - for prop in properties_copy: - prop["dataType"] = [_capitalize_first_letter(dt) for dt in prop["dataType"]] - mock_rest.post.assert_called_with( - path="/schema/" + _capitalize_first_letter(properties["class"]) + "/properties", - weaviate_object=properties_copy[0], - ) - - # no `properties` key - mock_rest = mock_connection_func("post") - schema = Schema(mock_rest) - - schema._create_complex_properties_from_class({}) - self.assertEqual(mock_rest.run_rest.call_count, 0) - - # no COMPLEX properties - properties = {"properties": [{"dataType": ["text"]}]} - schema._create_complex_properties_from_class(properties) - self.assertEqual(mock_rest.post.call_count, 0) - - properties = {"properties": [{"dataType": ["text"]}, {"dataType": ["string"]}]} - schema._create_complex_properties_from_class(properties) - self.assertEqual(mock_rest.post.call_count, 0) - - # COMPLEX properties - properties = { - "class": "TestClass", - "properties": [ - {"dataType": ["Test"], "description": "test description", "name": "test_prop"}, - ], - } - mock_rest = mock_connection_func("post") - schema = Schema(mock_rest) - schema._create_complex_properties_from_class(properties) - self.assertEqual(mock_rest.post.call_count, 1) - - properties = { - "class": "TestClass", - "properties": [ - {"dataType": ["Test"], "description": "test description", "name": "test_prop"}, - ], - } - helper_test() - - properties["properties"][0]["indexInverted"] = True - helper_test() - - properties["properties"][0]["moduleConfig"] = {"test": "ok!"} - helper_test() - - properties["properties"].append(properties["properties"][0]) # add another property - properties["properties"].append(properties["properties"][0]) # add another property - helper_test(3) - - # with uncapitalized class_name - properties["class"] = "testClass" - helper_test(3) - - properties = { - "class": "testClass", - "properties": [ - { - "dataType": ["test", "myTest"], - "description": "test description", - "name": "test_prop", - }, - ], - } - - # invalid calls - requests_error_message = "Property may not have been created properly." - - mock_rest = mock_connection_func("post", side_effect=RequestsConnectionError("TEST1")) - schema = Schema(mock_rest) - with self.assertRaises(RequestsConnectionError) as error: - schema._create_complex_properties_from_class(properties) - check_error_message(self, error, requests_error_message) - - mock_rest = mock_connection_func("post", status_code=404) - schema = Schema(mock_rest) - with self.assertRaises(UnexpectedStatusCodeException) as error: - schema._create_complex_properties_from_class(properties) - check_startswith_error_message(self, error, "Add properties to classes") - - def test__create_class_with_primitives(self): - """ - Test the `_create_class_with_primitives` method. - """ - - # valid calls - def helper_test(test_class, test_class_call): - mock_rest = mock_connection_func("post") - schema = Schema(mock_rest) - schema._create_class_with_primitives(test_class) - self.assertEqual(mock_rest.post.call_count, 1) - mock_rest.post.assert_called_with( - path="/schema", - weaviate_object=test_class_call, - ) - - test_class = { - "class": "TestClass", - "properties": [ - {"dataType": ["int"], "name": "test_prop", "description": "None"}, - {"dataType": ["Test"], "name": "test_prop", "description": "None"}, - ], - } - test_class_call = { - "class": "TestClass", - "properties": [ - {"dataType": ["int"], "name": "test_prop", "description": "None"}, - ], - } - helper_test(test_class, test_class_call) - - test_class["description"] = "description" - test_class_call["description"] = "description" - helper_test(test_class, test_class_call) - - test_class["description"] = "description" - test_class_call["description"] = "description" - helper_test(test_class, test_class_call) - - test_class["vectorIndexType"] = "vectorIndexType" - test_class_call["vectorIndexType"] = "vectorIndexType" - helper_test(test_class, test_class_call) - - test_class["vectorIndexConfig"] = {"vectorIndexConfig": "vectorIndexConfig"} - test_class_call["vectorIndexConfig"] = {"vectorIndexConfig": "vectorIndexConfig"} - helper_test(test_class, test_class_call) - - test_class["vectorizer"] = "test_vectorizer" - test_class_call["vectorizer"] = "test_vectorizer" - helper_test(test_class, test_class_call) - - test_class["moduleConfig"] = {"moduleConfig": "moduleConfig"} - test_class_call["moduleConfig"] = {"moduleConfig": "moduleConfig"} - helper_test(test_class, test_class_call) - - test_class["shardingConfig"] = {"shardingConfig": "shardingConfig"} - test_class_call["shardingConfig"] = {"shardingConfig": "shardingConfig"} - helper_test(test_class, test_class_call) - - # multiple properties do not imply multiple `run_rest` calls - test_class["properties"].append(test_class["properties"][0]) # add another property - test_class["properties"].append(test_class["properties"][0]) # add another property - test_class_call["properties"].append(test_class["properties"][0]) # add another property - test_class_call["properties"].append(test_class["properties"][0]) # add another property - helper_test(test_class, test_class_call) - - # with uncapitalized class_name - test_class["class"] = "testClass" - helper_test(test_class, test_class_call) - - # invalid calls - requests_error_message = "Class may not have been created properly." - - mock_rest = mock_connection_func("post", side_effect=RequestsConnectionError("TEST1")) - schema = Schema(mock_rest) - with self.assertRaises(RequestsConnectionError) as error: - schema._create_class_with_primitives(test_class) - check_error_message(self, error, requests_error_message) - - mock_rest = mock_connection_func("post", status_code=404) - schema = Schema(mock_rest) - with self.assertRaises(UnexpectedStatusCodeException) as error: - schema._create_class_with_primitives(test_class) - check_startswith_error_message(self, error, "Create class") - - def test__create_classes_with_primitives(self): - """ - Test the `_create_classes_with_primitives` method. - """ - - schema = Schema(Mock()) - - mock_primitive = Mock() - schema._create_class_with_primitives = mock_primitive - - schema._create_classes_with_primitives(list("Test!!")) - self.assertEqual(mock_primitive.call_count, 6) - - def test__property_is_primitive(self): - """ - Test the `_property_is_primitive` function. - """ - - from weaviate.schema.crud_schema import _property_is_primitive - - test_types_list = ["NOT Primitive", "Neither this one", "Nor This!"] - self.assertFalse(_property_is_primitive(test_types_list)) - test_types_list = ["NOT Primitive", "boolean", "text"] - self.assertFalse(_property_is_primitive(test_types_list)) - test_types_list = ["text"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["int"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["number"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["string"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["boolean"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["date"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["geoCoordinates"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["blob"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["phoneNumber"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["int[]", "number[]", "text[]", "string[]", "boolean[]", "date[]"] - self.assertTrue(_property_is_primitive(test_types_list)) - test_types_list = ["int()"] - self.assertFalse(_property_is_primitive(test_types_list)) - test_types_list = ["number()"] - self.assertFalse(_property_is_primitive(test_types_list)) - test_types_list = ["text()"] - self.assertFalse(_property_is_primitive(test_types_list)) - test_types_list = ["string()"] - self.assertFalse(_property_is_primitive(test_types_list)) - test_types_list = ["boolean()"] - self.assertFalse(_property_is_primitive(test_types_list)) - test_types_list = ["date()"] - self.assertFalse(_property_is_primitive(test_types_list)) - test_types_list = [ - "string", - "int", - "boolean", - "number", - "date", - "text", - "geoCoordinates", - "blob", - "phoneNumber", - "int[]", - "number[]", - "text[]", - "string[]", - "boolean[]", - "date[]", - ] - self.assertTrue(_property_is_primitive(test_types_list)) - - def test__get_primitive_properties(self): - """ - Test the `_get_primitive_properties` function. - """ - - from weaviate.schema.crud_schema import _get_primitive_properties - - test_func = _get_primitive_properties - - properties_list = [] - self.assertEqual(test_func(properties_list), properties_list) - - properties_list = [{"dataType": ["text"]}] - self.assertEqual(test_func(properties_list), properties_list) - - properties_list = [{"dataType": ["text"]}, {"dataType": ["int"]}] - self.assertEqual(test_func(properties_list), properties_list) - - properties_list = [{"dataType": ["Test1"]}, {"dataType": ["Test2"]}] - self.assertEqual(test_func(properties_list), []) - - properties_list = [ - {"dataType": ["text"]}, - {"dataType": ["int"]}, - {"dataType": ["Test1"]}, - {"dataType": ["Test2"]}, - ] - self.assertEqual( - test_func(properties_list), [{"dataType": ["text"]}, {"dataType": ["int"]}] - ) diff --git a/test/test_client.py b/test/test_client.py deleted file mode 100644 index cde8c90cb..000000000 --- a/test/test_client.py +++ /dev/null @@ -1,246 +0,0 @@ -import unittest -from sys import platform -from unittest.mock import patch, Mock - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from test.util import mock_connection_func, check_error_message -from weaviate import Client -from weaviate.config import ConnectionConfig -from weaviate.embedded import EmbeddedOptions, EmbeddedDB -from weaviate.exceptions import UnexpectedStatusCodeException - - -@patch("weaviate.client.Connection", Mock) -class TestWeaviateClient(unittest.TestCase): - @patch("weaviate.client.Client.get_meta", return_value={"version": "1.13.2"}) - def test___init__(self, mock_get_meta_method): - """ - Test the `__init__` method. - """ - - type_error_message = "Either url or embedded options must be present." - # test invalid calls - with self.assertRaises(TypeError) as error: - Client(None) - check_error_message(self, error, type_error_message) - with self.assertRaises(TypeError) as error: - Client(42) - check_error_message(self, error, "URL is expected to be string but is " + str(int)) - - # test valid calls - with patch( - "weaviate.client.Connection", - Mock(side_effect=lambda **kwargs: Mock(timeout_config=kwargs["timeout_config"])), - ) as mock_obj: - Client( - url="some_URL", - auth_client_secret=None, - timeout_config=(1, 2), - additional_headers=None, - startup_period=None, - ) - mock_obj.assert_called_with( - url="some_URL", - auth_client_secret=None, - timeout_config=(1, 2), - proxies=None, - trust_env=False, - additional_headers=None, - startup_period=None, - embedded_db=None, - grcp_port=None, - connection_config=ConnectionConfig(), - ) - - with patch( - "weaviate.client.Connection", - Mock(side_effect=lambda **kwargs: Mock(timeout_config=kwargs["timeout_config"])), - ) as mock_obj: - Client( - url="some_URL", - auth_client_secret=None, - timeout_config=(1, 2), - additional_headers={"Test": True}, - startup_period=None, - ) - mock_obj.assert_called_with( - url="some_URL", - auth_client_secret=None, - timeout_config=(1, 2), - proxies=None, - trust_env=False, - additional_headers={"Test": True}, - startup_period=None, - embedded_db=None, - grcp_port=None, - connection_config=ConnectionConfig(), - ) - - with patch( - "weaviate.client.Connection", - Mock(side_effect=lambda **kwargs: Mock(timeout_config=kwargs["timeout_config"])), - ) as mock_obj: - Client( - "some_URL/", auth_client_secret=None, timeout_config=(5, 20), startup_period=None - ) - mock_obj.assert_called_with( - url="some_URL", - auth_client_secret=None, - timeout_config=(5, 20), - proxies=None, - trust_env=False, - additional_headers=None, - startup_period=None, - embedded_db=None, - grcp_port=None, - connection_config=ConnectionConfig(), - ) - - with patch( - "weaviate.client.Connection", - Mock(side_effect=lambda **kwargs: Mock(timeout_config=kwargs["timeout_config"])), - ) as mock_obj: - Client( - url="some_URL", - auth_client_secret=None, - timeout_config=(1, 2), - proxies={"http": "test"}, - trust_env=True, - additional_headers=None, - startup_period=None, - ) - mock_obj.assert_called_with( - url="some_URL", - auth_client_secret=None, - timeout_config=(1, 2), - proxies={"http": "test"}, - trust_env=True, - additional_headers=None, - startup_period=None, - embedded_db=None, - grcp_port=None, - connection_config=ConnectionConfig(), - ) - - if platform == "linux": - with patch( - "weaviate.client.Connection", - Mock(side_effect=lambda **kwargs: Mock(timeout_config=kwargs["timeout_config"])), - ) as mock_obj: - with patch("weaviate.embedded.EmbeddedDB.start") as mocked_start: - Client(embedded_options=EmbeddedOptions()) - args, kwargs = mock_obj.call_args_list[0] - self.assertEqual(kwargs["url"], "http://localhost:8079") - self.assertTrue(isinstance(kwargs["embedded_db"], EmbeddedDB)) - self.assertTrue(kwargs["embedded_db"] is not None) - self.assertEqual(kwargs["embedded_db"].options.port, 8079) - mocked_start.assert_called_once() - - @patch("weaviate.client.Client.get_meta", return_value={"version": "1.13.2"}) - def test_is_ready(self, mock_get_meta_method): - """ - Test the `is_ready` method. - """ - - client = Client("http://localhost:8080") - # Request to weaviate returns 200 - connection_mock = mock_connection_func("get") - client._connection = connection_mock - self.assertTrue(client.is_ready()) # Should be true - connection_mock.get.assert_called_with(path="/.well-known/ready") - - # Request to weaviate returns 404 - connection_mock = mock_connection_func("get", status_code=404) - client._connection = connection_mock - self.assertFalse(client.is_ready()) # Should be false - connection_mock.get.assert_called_with(path="/.well-known/ready") - - # Test exception in connect - connection_mock = mock_connection_func("get", side_effect=RequestsConnectionError("Test")) - client._connection = connection_mock - self.assertFalse(client.is_ready()) - connection_mock.get.assert_called_with(path="/.well-known/ready") - - @patch("weaviate.client.Client.get_meta", return_value={"version": "1.13.2"}) - def test_is_live(self, mock_get_meta): - """ - Test the `is_live` method. - """ - - client = Client("http://localhost:8080") - # Request to weaviate returns 200 - connection_mock = mock_connection_func("get") - client._connection = connection_mock - self.assertTrue(client.is_live()) # Should be true - connection_mock.get.assert_called_with(path="/.well-known/live") - - # Request to weaviate returns 404 - connection_mock = mock_connection_func("get", status_code=404) - client._connection = connection_mock - self.assertFalse(client.is_live()) # Should be false - connection_mock.get.assert_called_with(path="/.well-known/live") - - def test_get_meta(self): - """ - Test the `get_meta` method. - """ - - # client = Client("http://localhost:8080") - # # Request to weaviate returns 200 - # connection_mock = mock_connection_func('get', return_json="OK!") - # client._connection = connection_mock - # self.assertEqual(client.get_meta(), "OK!") - # connection_mock.get.assert_called_with( - # path="/meta" - # ) - - # # Request to weaviate returns 404 - # connection_mock = mock_connection_func('get', status_code=404) - # client._connection = connection_mock - # with self.assertRaises(UnexpectedStatusCodeException) as error: - # client.get_meta() - # error_message = "Meta endpoint! Unexpected status code: 404, with response body: None" - # check_error_message(self, error, error_message) - # connection_mock.get.assert_called_with( - # path="/meta" - # ) - - @patch("weaviate.client.Client.get_meta", return_value={"version": "1.13.2"}) - def test_get_open_id_configuration(self, mock_get_meta): - """ - Test the `get_open_id_configuration` method. - """ - - client = Client("http://localhost:8080") - # Request to weaviate returns 200 - connection_mock = mock_connection_func("get", return_json={"status": "OK!"}) - client._connection = connection_mock - self.assertEqual(client.get_open_id_configuration(), {"status": "OK!"}) - connection_mock.get.assert_called_with(path="/.well-known/openid-configuration") - - # Request to weaviate returns 404 - connection_mock = mock_connection_func("get", status_code=404) - client._connection = connection_mock - self.assertIsNone(client.get_open_id_configuration()) - connection_mock.get.assert_called_with(path="/.well-known/openid-configuration") - - # Request to weaviate returns 204 - connection_mock = mock_connection_func("get", status_code=204) - client._connection = connection_mock - with self.assertRaises(UnexpectedStatusCodeException) as error: - client.get_open_id_configuration() - error_message = "Meta endpoint! Unexpected status code: 204, with response body: None." - check_error_message(self, error, error_message) - connection_mock.get.assert_called_with(path="/.well-known/openid-configuration") - - @patch("weaviate.client.Client.get_meta", return_value={"version": "1.13.2"}) - def test_timeout_config(self, mock_get_meta): - """ - Test the `set_timeout_config` method. - """ - - client = Client("http://some_url.com", auth_client_secret=None, timeout_config=(1, 2)) - self.assertEqual(client.timeout_config, (1, 2)) - client.timeout_config = (4, 20) # ;) - self.assertEqual(client.timeout_config, (4, 20)) diff --git a/test/test_embedded.py b/test/test_embedded.py deleted file mode 100644 index f3e498431..000000000 --- a/test/test_embedded.py +++ /dev/null @@ -1,383 +0,0 @@ -import os -import signal -import socket -import tarfile -import time -from pathlib import Path -from sys import platform -from unittest.mock import patch - -import pytest -import requests -import uuid -from pytest_httpserver import HTTPServer -from werkzeug import Request, Response - -import weaviate -from weaviate import embedded -from weaviate.embedded import EmbeddedDB, EmbeddedOptions -from weaviate.exceptions import WeaviateEmbeddedInvalidVersionError, WeaviateStartUpError - -if platform != "linux" and platform != "darwin": - pytest.skip("Currently only supported on linux", allow_module_level=True) - - -def test_embedded__init__(tmp_path): - assert ( - EmbeddedDB(EmbeddedOptions(port=8079, persistence_data_path=tmp_path)).options.port == 8079 - ) - - -def test_embedded__init__non_default_port(tmp_path): - assert ( - EmbeddedDB(EmbeddedOptions(port=30666, persistence_data_path=tmp_path)).options.port - == 30666 - ) - - -def test_embedded_ensure_binary_exists(tmp_path): - bin_path = tmp_path / "notcreated-yet/bin/weaviate-embedded" - assert bin_path.is_file, False - embedded_db = EmbeddedDB( - EmbeddedOptions(binary_path=str(bin_path), persistence_data_path=tmp_path / "2") - ) - embedded_db.ensure_weaviate_binary_exists() - assert Path(embedded_db.options.binary_path).is_file, True - - -def test_version_parsing(tmp_path): - bin_path = tmp_path / "bin" - embedded_db = EmbeddedDB( - EmbeddedOptions( - binary_path=str(bin_path), - version="https://github.com/weaviate/weaviate/releases/download/v1.18.1/weaviate-v1.18.1-linux-amd64.tar.gz", - persistence_data_path=tmp_path / "2", - ) - ) - embedded_db.ensure_weaviate_binary_exists() - embedded_file_name = list(bin_path.iterdir()) - assert len(embedded_file_name) == 1 # .tgz file was deleted - assert "v1.18.1" in str(embedded_file_name[0]) - - -def test_download_no_version_parsing(httpserver: HTTPServer, tmp_path): - """Test downloading weaviate from a non-github url.""" - - def handler(request: Request): - with open(Path(tmp_path, "weaviate"), "w") as _: - with tarfile.open(Path(tmp_path, "tmp_weaviate.tar.gz"), "w:gz") as tar: - tar.add(Path(tmp_path, "weaviate"), arcname="weaviate") - - return Response(open(Path(tmp_path, "tmp_weaviate.tar.gz"), mode="rb")) - - httpserver.expect_request("/tmp_weaviate.tar.gz").respond_with_handler(handler) - - bin_path = tmp_path / "bin" - embedded_db = EmbeddedDB( - EmbeddedOptions( - binary_path=str(bin_path), - version=httpserver.url_for("/tmp_weaviate.tar.gz"), - persistence_data_path=tmp_path / "2", - ) - ) - embedded_db.ensure_weaviate_binary_exists() - embedded_file_name = list(bin_path.iterdir()) - assert len(embedded_file_name) == 1 # .tgz file was deleted - - -def test_embedded_ensure_binary_exists_same_as_tar_binary_name(tmp_path): - bin_path = tmp_path / "notcreated-yet/bin/weaviate" - assert bin_path.is_file, False - embedded_db = EmbeddedDB( - EmbeddedOptions(binary_path=str(bin_path), persistence_data_path=tmp_path) - ) - embedded_db.ensure_weaviate_binary_exists() - assert Path(embedded_db.options.binary_path).is_file, True - - -@pytest.fixture(scope="session") -def embedded_db_binary_path(tmp_path_factory: pytest.TempPathFactory): - embedded.weaviate_binary_path = ( - tmp_path_factory.mktemp("embedded-test") / "weaviate-embedded-binary" - ) - - -@pytest.mark.parametrize( - "options", [EmbeddedOptions(), EmbeddedOptions(port=30666, grpc_port=50046)] -) -def test_embedded_end_to_end(options: EmbeddedDB, tmp_path): - try: - options.binary_path = tmp_path - options.persistence_data_path = tmp_path - embedded_db = EmbeddedDB(options=options) - assert embedded_db.is_listening() is False - with pytest.raises(WeaviateStartUpError): - with patch("time.sleep") as mocked_sleep: - embedded_db.wait_till_listening() - mocked_sleep.assert_has_calls([0.1] * 300) - - embedded_db.ensure_running() - assert embedded_db.is_listening() is True - with patch("weaviate.logger.logger.info") as mocked_print: - embedded_db.start() - mocked_print.assert_called_once_with( - f"embedded weaviate is already listening on port {options.port}" - ) - - # killing the process should restart it again when ensure running is called - os.kill(embedded_db.process.pid, signal.SIGTERM) - time.sleep(0.2) - assert embedded_db.is_listening() is False - embedded_db.ensure_running() - assert embedded_db.is_listening() is True - finally: - embedded_db.stop() - - -def test_embedded_multiple_instances(tmp_path_factory: pytest.TempPathFactory): - embedded_db = EmbeddedDB( - EmbeddedOptions( - port=30662, - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - additional_env_vars={"GRPC_PORT": "50053"}, - grpc_port=50053, - ) - ) - embedded_db2 = EmbeddedDB( - EmbeddedOptions( - port=30663, - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - additional_env_vars={"GRPC_PORT": "50054"}, - grpc_port=50054, - ) - ) - try: - embedded_db.ensure_running() - assert embedded_db.is_listening() is True - embedded_db2.ensure_running() - assert embedded_db2.is_listening() is True - finally: - embedded_db.stop() - embedded_db2.stop() - - -def test_embedded_different_versions(tmp_path_factory: pytest.TempPathFactory): - client1 = weaviate.Client( - embedded_options=EmbeddedOptions( - port=30664, - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version="https://github.com/weaviate/weaviate/releases/download/v1.18.1/weaviate-v1.18.1-linux-amd64.tar.gz", - ) - ) - client2 = weaviate.Client( - embedded_options=EmbeddedOptions( - port=30665, - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version="https://github.com/weaviate/weaviate/releases/download/v1.18.0/weaviate-v1.18.0-linux-amd64.tar.gz", - ) - ) - try: - meta1 = client1.get_meta() - assert meta1["version"] == "1.18.1" - meta2 = client2.get_meta() - assert meta2["version"] == "1.18.0" - finally: - client1._connection.embedded_db.stop() - client2._connection.embedded_db.stop() - - -def test_custom_env_vars(tmp_path_factory: pytest.TempPathFactory): - client = weaviate.Client( - embedded_options=EmbeddedOptions( - binary_path=tmp_path_factory.mktemp("bin"), - additional_env_vars={"ENABLE_MODULES": "", "GRPC_PORT": "50057"}, - persistence_data_path=tmp_path_factory.mktemp("data"), - port=30666, - ) - ) - try: - meta = client.get_meta() - assert len(meta["modules"]) == 0 - finally: - client._connection.embedded_db.stop() - - -def test_weaviate_state(tmp_path_factory: pytest.TempPathFactory) -> None: - """Test that weaviate keeps the state between different runs.""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - port = 36545 - data_path = tmp_path_factory.mktemp("data") - client = weaviate.Client( - embedded_options=EmbeddedOptions( - binary_path=tmp_path_factory.mktemp("bin"), - port=port, - persistence_data_path=data_path, - additional_env_vars={"GRPC_PORT": "50058"}, - grpc_port=50058, - ), - ) - client.data_object.create({"name": "Name"}, "Person", uuid.uuid4()) - assert sock.connect_ex(("127.0.0.1", port)) == 0 # running - - client._connection.embedded_db.stop() - del client - time.sleep(5) # give weaviate time to shut down - - assert sock.connect_ex(("127.0.0.1", port)) != 0 # not running anymore - - client = weaviate.Client( - embedded_options=EmbeddedOptions( - binary_path=tmp_path_factory.mktemp("bin"), - port=port, - persistence_data_path=data_path, - additional_env_vars={"GRPC_PORT": "50059"}, - grpc_port=50059, - ) - ) - count = client.query.aggregate("Person").with_meta_count().do() - assert count["data"]["Aggregate"]["Person"][0]["meta"]["count"] == 1 - - client._connection.embedded_db.stop() - - -def test_version(tmp_path_factory: pytest.TempPathFactory): - client = weaviate.Client( - embedded_options=EmbeddedOptions( - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version="1.18.2", - port=30667, - ) - ) - try: - meta = client.get_meta() - assert meta["version"] == "1.18.2" - finally: - client._connection.embedded_db.stop() - - -def test_latest(tmp_path_factory: pytest.TempPathFactory): - client = weaviate.Client( - embedded_options=EmbeddedOptions( - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version="latest", - port=30668, - additional_env_vars={"GRPC_PORT": "50060"}, - grpc_port=50060, - ) - ) - try: - meta = client.get_meta() - latest = requests.get( - "https://api.github.com/repos/weaviate/weaviate/releases/latest" - ).json() - assert "v" + meta["version"] == latest["tag_name"] - finally: - client._connection.embedded_db.stop() - - -@pytest.mark.parametrize( - "version", - [ - "v1.16.6", - "sdgfsdfposdfjpsdf", - "httttp://github.com/weaviate/weaviate/releases/download/v1.18.0/weaviate-v1.18.0-linux-amd64.tar.gz", - "https://github.com/weaviate/weaviate/releases/download/v1.18.0/weaviate-v1.18.0-linux-amd64.tar", - ], -) -def test_invalid_version(tmp_path_factory: pytest.TempPathFactory, version): - with pytest.raises(WeaviateEmbeddedInvalidVersionError): - weaviate.Client( - embedded_options=EmbeddedOptions( - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version=version, - ) - ) - - -def test_embedded_with_grpc_port(tmp_path_factory: pytest.TempPathFactory): - client = weaviate.Client( - embedded_options=EmbeddedOptions( - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version="latest", - port=30668, - grpc_port=50061, - ) - ) - try: - assert client.is_ready() - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1.0) # we're only pinging the port, 1s is plenty - - assert sock.connect_ex(("127.0.0.1", 50061)) == 0 # running - finally: - client._connection.embedded_db.stop() - - -def test_embedded_v4_with_grpc_port(tmp_path_factory: pytest.TempPathFactory): - client = weaviate.WeaviateClient( - embedded_options=EmbeddedOptions( - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version="latest", - port=30668, - grpc_port=50061, - ) - ) - try: - client.connect() - assert client.is_ready() - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1.0) # we're only pinging the port, 1s is plenty - - assert sock.connect_ex(("127.0.0.1", 50061)) == 0 # running - finally: - client._connection.embedded_db.stop() - - -def test_embedded_with_grpc_port_default(tmp_path_factory: pytest.TempPathFactory): - client = weaviate.Client( - embedded_options=EmbeddedOptions( - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version="latest", - port=30669, - ) - ) - try: - assert client.is_ready() - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1.0) # we're only pinging the port, 1s is plenty - - assert sock.connect_ex(("127.0.0.1", 50060)) == 0 # running - finally: - client._connection.embedded_db.stop() - - -def test_embedded_stop(tmp_path_factory: pytest.TempPathFactory): - client = weaviate.Client( - embedded_options=EmbeddedOptions( - persistence_data_path=tmp_path_factory.mktemp("data"), - binary_path=tmp_path_factory.mktemp("bin"), - version="latest", - port=30668, - grpc_port=50060, - ) - ) - try: - assert client.is_ready() - - assert client._connection.embedded_db.process is not None - client._connection.embedded_db.stop() - assert client._connection.embedded_db.process is None - client._connection.embedded_db.stop() - assert client._connection.embedded_db.process is None - finally: - client._connection.embedded_db.stop() diff --git a/test/test_exceptions.py b/test/test_exceptions.py deleted file mode 100644 index 7e2412486..000000000 --- a/test/test_exceptions.py +++ /dev/null @@ -1,64 +0,0 @@ -import unittest -from unittest.mock import Mock - -from requests import exceptions - -from weaviate.exceptions import ( - UnexpectedStatusCodeException, - ObjectAlreadyExistsException, - AuthenticationFailedException, - SchemaValidationException, -) - - -class TestExceptions(unittest.TestCase): - def test_unexpected_status_code(self): - """ - Test the `UnexpectedStatusCodeException` exception. - """ - - # with .json() exception raised - response = Mock() - response.json = Mock(side_effect=exceptions.JSONDecodeError("test", "", 0)) - response.status_code = 1234 - exception = UnexpectedStatusCodeException(message="Test message", response=response) - self.assertEqual( - str(exception), "Test message! Unexpected status code: 1234, with response body: None." - ) - self.assertEqual(exception.status_code, response.status_code) - - # with .json() value - response = Mock() - response.json = Mock() - response.json.return_value = {"test": "OK!"} - response.status_code = 4321 - exception = UnexpectedStatusCodeException(message="Second test message", response=response) - self.assertEqual( - str(exception), - "Second test message! Unexpected status code: 4321, with response body: {'test': 'OK!'}.", - ) - self.assertEqual(exception.status_code, response.status_code) - - def test_object_already_exists(self): - """ - Test the `ObjectAlreadyExistsException` exception. - """ - - exception = ObjectAlreadyExistsException("Test") - self.assertEqual(str(exception), "Test") - - def test_authentication_failed(self): - """ - Test the `AuthenticationFailedException` exception. - """ - - exception = AuthenticationFailedException("Test") - self.assertEqual(str(exception), "Test") - - def test_schema_validation(self): - """ - Test the `SchemaValidationException` exception. - """ - - exception = SchemaValidationException("Test") - self.assertEqual(str(exception), "Test") diff --git a/test/test_util.py b/test/test_util.py index 25b8ccc04..a491361fd 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -2,6 +2,7 @@ import uuid as uuid_lib from copy import deepcopy from unittest.mock import patch, Mock + import pytest from test.util import check_error_message @@ -10,13 +11,11 @@ generate_uuid5, image_decoder_b64, image_encoder_b64, - generate_local_beacon, is_object_url, is_weaviate_object_url, get_vector, get_valid_uuid, get_domain_from_weaviate_url, - _get_dict_from_object, _is_sub_schema, parse_version_string, is_weaviate_too_old, @@ -113,108 +112,6 @@ class TestUtil(unittest.TestCase): - def test_generate_local_beacon(self): - """ - Test the `generate_local_beacon` function. - """ - - type_error_message = "Expected to_object_uuid of type str or uuid.UUID" - value_error_message = "Uuid does not have the proper form" - # wrong data type - with self.assertRaises(TypeError) as error: - generate_local_beacon(None) - check_error_message(self, error, type_error_message) - # wrong value - with self.assertRaises(ValueError) as error: - generate_local_beacon("Leeroy Jenkins") - check_error_message(self, error, value_error_message) - - beacon = generate_local_beacon("fcf33178-1b5d-5174-b2e7-04a2129dd35a") - self.assertTrue("beacon" in beacon) - self.assertEqual( - beacon["beacon"], "weaviate://localhost/fcf33178-1b5d-5174-b2e7-04a2129dd35a" - ) - - beacon = generate_local_beacon("fcf33178-1b5d-5174-b2e7-04a2129dd35b") - self.assertTrue("beacon" in beacon) - self.assertEqual( - beacon["beacon"], "weaviate://localhost/fcf33178-1b5d-5174-b2e7-04a2129dd35b" - ) - - beacon = generate_local_beacon("fcf331781b5d5174b2e704a2129dd35b") - self.assertTrue("beacon" in beacon) - self.assertEqual( - beacon["beacon"], "weaviate://localhost/fcf33178-1b5d-5174-b2e7-04a2129dd35b" - ) - - beacon = generate_local_beacon(uuid_lib.UUID("fcf33178-1b5d-5174-b2e7-04a2129dd35b")) - self.assertTrue("beacon" in beacon) - self.assertEqual( - beacon["beacon"], "weaviate://localhost/fcf33178-1b5d-5174-b2e7-04a2129dd35b" - ) - - beacon = generate_local_beacon(uuid_lib.UUID("fcf331781b5d5174b2e704a2129dd35b")) - self.assertTrue("beacon" in beacon) - self.assertEqual( - beacon["beacon"], "weaviate://localhost/fcf33178-1b5d-5174-b2e7-04a2129dd35b" - ) - - beacon = generate_local_beacon(uuid_lib.UUID("fcf331781b5d5174b2e704a2129dd35b"), "Test1") - self.assertTrue("beacon" in beacon) - self.assertEqual( - beacon["beacon"], "weaviate://localhost/Test1/fcf33178-1b5d-5174-b2e7-04a2129dd35b" - ) - - beacon = generate_local_beacon(uuid_lib.UUID("fcf331781b5d5174b2e704a2129dd35b"), "test2") - self.assertTrue("beacon" in beacon) - self.assertEqual( - beacon["beacon"], "weaviate://localhost/Test2/fcf33178-1b5d-5174-b2e7-04a2129dd35b" - ) - - def test__get_dict_from_object(self): - """ - Test the `_get_dict_from_object` function. - """ - - none_error_message = "argument is None" - file_error_message = "No file found at location " - url_error_message = "Could not download file " - type_error_message = ( - "Argument is not of the supported types. Supported types are " - "url or file path as string or schema as dict." - ) - # test wrong type None - with self.assertRaises(TypeError) as error: - _get_dict_from_object(None) - check_error_message(self, error, none_error_message) - # wrong data type - with self.assertRaises(TypeError) as error: - _get_dict_from_object([{"key": 1234}]) - check_error_message(self, error, type_error_message) - # wrong path - with self.assertRaises(ValueError) as error: - _get_dict_from_object("not_a_path_or_url.txt") - check_error_message(self, error, file_error_message + "not_a_path_or_url.txt") - # wrong URL or non existing one or failure of requests.get - with patch("weaviate.util.requests") as mock_obj: - result_mock = Mock() - result_mock.status_code = 404 - mock_obj.get.return_value = result_mock - with self.assertRaises(ValueError) as error: - _get_dict_from_object("http://www.url.com") - check_error_message(self, error, url_error_message + "http://www.url.com") - mock_obj.get.assert_called() - - # valid calls - self.assertEqual(_get_dict_from_object({"key": "val"}), {"key": "val"}) - # read from file - path = "/".join(__file__.split("/")[:-1]) - self.assertEqual( - _get_dict_from_object(f"{path}/schema/schema_company.json"), schema_company - ) - # read from URL - path = "https://raw.githubusercontent.com/semi-technologies/weaviate-python-client/main/test/schema/schema_company.json" - self.assertEqual(_get_dict_from_object(path), schema_company) def test_is_weaviate_object_url(self): """ diff --git a/weaviate/__init__.py b/weaviate/__init__.py index 20f68d83b..704561d09 100644 --- a/weaviate/__init__.py +++ b/weaviate/__init__.py @@ -28,18 +28,15 @@ from . import ( auth, backup, - batch, classes, cluster, collections, config, connect, - data, embedded, exceptions, gql, outputs, - schema, types, ) @@ -65,18 +62,15 @@ "connect_to_weaviate_cloud", "auth", "backup", - "batch", "classes", "cluster", "collections", "config", "connect", - "data", "embedded", "exceptions", "gql", "outputs", - "schema", "types", "use_async_with_custom", "use_async_with_embedded", diff --git a/weaviate/backup/__init__.py b/weaviate/backup/__init__.py index 244d32d8c..4a6735237 100644 --- a/weaviate/backup/__init__.py +++ b/weaviate/backup/__init__.py @@ -2,7 +2,6 @@ Module for backup/restore operations """ -__all__ = ["Backup", "BackupStorage"] +__all__ = ["BackupStorage"] -from weaviate.backup.backup import Backup from weaviate.backup.backup import BackupStorage diff --git a/weaviate/backup/backup.py b/weaviate/backup/backup.py index ebcda1391..80a8d06aa 100644 --- a/weaviate/backup/backup.py +++ b/weaviate/backup/backup.py @@ -4,12 +4,11 @@ from enum import Enum from time import sleep -from typing import Optional, Union, List, Tuple, Any, Dict +from typing import Optional, Union, List, Tuple from pydantic import BaseModel, Field -from requests.exceptions import ConnectionError as RequestsConnectionError -from weaviate.connect import Connection, ConnectionV4 +from weaviate.connect import ConnectionV4 from weaviate.connect.v4 import _ExpectedStatusCodes from weaviate.exceptions import ( WeaviateInvalidInputError, @@ -449,287 +448,6 @@ async def __list_backups(self, backend: BackupStorage) -> List[BackupReturn]: # return await self.__list_backups(backend) -class Backup: - """ - Backup class used to schedule and/or check the status of - a backup process of Weaviate objects. - """ - - def __init__(self, connection: Connection): - """ - Initialize a Classification class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - """ - - self._connection = connection - - def create( - self, - backup_id: str, - backend: str, - include_classes: Union[List[str], str, None] = None, - exclude_classes: Union[List[str], str, None] = None, - wait_for_completion: bool = False, - ) -> dict: - """ - Create a backup of all/per class Weaviate objects. - - Parameters - ---------- - backup_id : str - The identifier name of the backup. - NOTE: Case insensitive. - backend : str - The backend storage where to create the backup. Currently available options are: - "filesystem", "s3", "gcs" and "azure". - NOTE: Case insensitive. - include_classes : Union[List[str], str, None], optional - The class/list of classes to be included in the backup. If not specified all classes - will be included. Either `include_classes` or `exclude_classes` can be set. - By default None. - exclude_classes : Union[List[str], str, None], optional - The class/list of classes to be excluded in the backup. Either `include_classes` or - `exclude_classes` can be set. By default None. - wait_for_completion : bool, optional - Whether to wait until the backup is done. By default False. - - Returns - ------- - dict - Backup creation response. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - TypeError - One of the arguments have a wrong type. - ValueError - 'backend' does not have an accepted value. - """ - - ( - backup_id, - backend, - include_classes, - exclude_classes, - ) = _get_and_validate_create_restore_arguments( - backup_id=backup_id, - backend=backend, - include_classes=include_classes, - exclude_classes=exclude_classes, - wait_for_completion=wait_for_completion, - ) - - payload = { - "id": backup_id, - "include": include_classes, - "exclude": exclude_classes, - } - path = f"/backups/{backend.value}" - - try: - response = self._connection.post( - path=path, - weaviate_object=payload, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Backup creation failed due to connection error." - ) from conn_err - - create_status = _decode_json_response_dict(response, "Backup creation") - assert create_status is not None - if wait_for_completion: - while True: - status: dict = self.get_create_status( - backup_id=backup_id, - backend=backend, - ) - create_status.update(status) - if status["status"] == "SUCCESS": - break - if status["status"] == "FAILED": - raise BackupFailedException(f"Backup failed: {create_status}") - sleep(1) - return create_status - - def get_create_status(self, backup_id: str, backend: str) -> Dict[str, Any]: - """ - Checks if a started classification job has completed. - - Parameters - ---------- - backup_id : str - The identifier name of the backup. - NOTE: Case insensitive. - backend : str - The backend storage where the backup was created. Currently available options are: - "filesystem", "s3", "gcs" and "azure". - NOTE: Case insensitive. - - Returns - ------- - dict - Status of the backup create. - """ - - backup_id, backend = _get_and_validate_get_status( - backup_id=backup_id, - backend=backend, - ) - - path = f"/backups/{backend.value}/{backup_id}" - - try: - response = self._connection.get( - path=path, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Backup creation status failed due to connection error." - ) from conn_err - - typed_response = _decode_json_response_dict(response, "Backup status check") - if typed_response is None: - raise EmptyResponseException() - return typed_response - - def restore( - self, - backup_id: str, - backend: str, - include_classes: Union[List[str], str, None] = None, - exclude_classes: Union[List[str], str, None] = None, - wait_for_completion: bool = False, - ) -> dict: - """ - Restore a backup of all/per class Weaviate objects. - - Parameters - ---------- - backup_id : str - The identifier name of the backup. - NOTE: Case insensitive. - backend : str - The backend storage from where to restore the backup. Currently available options are: - "filesystem", "s3", "gcs" and "azure". - NOTE: Case insensitive. - include_classes : Union[List[str], str, None], optional - The class/list of classes to be included in the backup restore. If not specified all - classes will be included (that were backup-ed). Either `include_classes` or - `exclude_classes` can be set. By default None. - exclude_classes : Union[List[str], str, None], optional - The class/list of classes to be excluded in the backup restore. - Either `include_classes` or `exclude_classes` can be set. By default None. - wait_for_completion : bool, optional - Whether to wait until the backup restore is done. - - Returns - ------- - dict - Backup restore response. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - - ( - backup_id, - backend, - include_classes, - exclude_classes, - ) = _get_and_validate_create_restore_arguments( - backup_id=backup_id, - backend=backend, - include_classes=include_classes, - exclude_classes=exclude_classes, - wait_for_completion=wait_for_completion, - ) - - payload = { - "config": {}, - "include": include_classes, - "exclude": exclude_classes, - } - path = f"/backups/{backend.value}/{backup_id}/restore" - - try: - response = self._connection.post( - path=path, - weaviate_object=payload, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Backup restore failed due to connection error." - ) from conn_err - restore_status = _decode_json_response_dict(response, "Backup restore") - assert restore_status is not None - if wait_for_completion: - while True: - status: dict = self.get_restore_status( - backup_id=backup_id, - backend=backend, - ) - restore_status.update(status) - if status["status"] == "SUCCESS": - break - if status["status"] == "FAILED": - raise BackupFailedException(f"Backup restore failed: {restore_status}") - sleep(1) - return restore_status - - def get_restore_status(self, backup_id: str, backend: str) -> Dict[str, Any]: - """ - Checks if a started classification job has completed. - - Parameters - ---------- - backup_id : str - The identifier name of the backup. - NOTE: Case insensitive. - backend : str - The backend storage where to create the backup. Currently available options are: - "filesystem", "s3", "gcs" and "azure". - NOTE: Case insensitive. - - Returns - ------- - dict - Status of the backup create. - """ - - backup_id, backend = _get_and_validate_get_status( - backup_id=backup_id, - backend=backend, - ) - path = f"/backups/{backend.value}/{backup_id}/restore" - - try: - response = self._connection.get( - path=path, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Backup restore status failed due to connection error." - ) from conn_err - - typed_response = _decode_json_response_dict(response, "Backup restore status check") - if typed_response is None: - raise EmptyResponseException() - return typed_response - - def _get_and_validate_create_restore_arguments( backup_id: str, backend: Union[str, BackupStorage], diff --git a/weaviate/batch/__init__.py b/weaviate/batch/__init__.py deleted file mode 100644 index e34109403..000000000 --- a/weaviate/batch/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Module for uploading objects and references to Weaviate in batches. -""" - -from .crud_batch import Batch, Shard, WeaviateErrorRetryConf - -__all__ = ["Batch", "Shard", "WeaviateErrorRetryConf"] diff --git a/weaviate/batch/crud_batch.py b/weaviate/batch/crud_batch.py deleted file mode 100644 index 89d39282f..000000000 --- a/weaviate/batch/crud_batch.py +++ /dev/null @@ -1,1987 +0,0 @@ -""" -Batch class definitions. -""" - -import datetime -import sys -import threading -import time -import warnings -from collections import deque -from concurrent.futures import ThreadPoolExecutor, as_completed, Future -from dataclasses import dataclass, field -from numbers import Real -from typing import ( - Any, - Callable, - Deque, - Dict, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from requests import ReadTimeout, Response -from requests.exceptions import ConnectionError as RequestsConnectionError -from requests.exceptions import HTTPError as RequestsHTTPError - -from weaviate.connect import Connection -from weaviate.data.replication import ConsistencyLevel -from weaviate.gql.filter import _find_value_type, VALUE_ARRAY_TYPES, WHERE_OPERATORS -from weaviate.types import UUID -from .requests import BatchRequest, ObjectsBatchRequest, ReferenceBatchRequest, BatchResponse -from ..cluster import Cluster -from ..error_msgs import ( - BATCH_REF_DEPRECATION_NEW_V14_CLS_NS_W, - BATCH_REF_DEPRECATION_OLD_V14_CLS_NS_W, - BATCH_EXECUTOR_SHUTDOWN_W, -) -from ..exceptions import UnexpectedStatusCodeException -from ..util import ( - _capitalize_first_letter, - check_batch_result, - _check_positive_num, - _decode_json_response_dict, - _decode_json_response_list, -) -from ..warnings import _Warnings - -BatchRequestType = Union[ObjectsBatchRequest, ReferenceBatchRequest] - - -@dataclass -class Shard: - class_name: str - tenant: Optional[str] = field(default=None) - - def __hash__(self) -> int: - return hash((self.class_name, self.tenant)) - - -@dataclass() -class WeaviateErrorRetryConf: - """Configures how often objects should be retried when Weaviate returns an error and which errors should be included - or excluded. - By default, all errors are retried. - - Parameters - ---------- - number_retries: int - How often a batch that includes objects with errors should be retried. Must be >=1. - errors_to_exclude: Optional[List[str]] - Which errors should NOT be retried. All other errors will be retried. An object will be skipped, when the given - string is part of the weaviate error message. - - Example: errors_to_exclude =["string1", "string2"] will match the error with message "Long error message that - contains string1". - errors_to_include: Optional[List[str]] - Which errors should be retried. All other errors will NOT be retried. An object will be included, when the given - string is part of the weaviate error message. - - Example: errors_to_include =["string1", "string2"] will match the error with message "Long error message that - contains string1". - """ - - number_retries: int = 3 - errors_to_exclude: Optional[List[str]] = None - errors_to_include: Optional[List[str]] = None - - def __post_init__(self) -> None: - if self.errors_to_exclude is not None and self.errors_to_include is not None: - raise ValueError(self.__module__ + " can either include or exclude errors") - - _check_positive_num(self.number_retries, "number_retries", int) - - def check_lists(error_list: Optional[List[str]]) -> None: - if error_list is None: - return - if any(not isinstance(entry, str) for entry in error_list): - raise ValueError("List entries must be strings.") - - check_lists(self.errors_to_exclude) - check_lists(self.errors_to_include) - - if self.errors_to_include is not None and len(self.errors_to_include) == 0: - raise ValueError("errors_to_include has 0 entries and no error will be retried.") - - -class BatchExecutor(ThreadPoolExecutor): - """ - Weaviate Batch Executor to run batch requests in separate thread. - This class implements an additional method `is_shutdown` that us used my the context manager. - """ - - def is_shutdown(self) -> bool: - """ - Check if executor is shutdown. - - Returns - ------- - bool - Whether the BatchExecutor is shutdown. - """ - - return self._shutdown - - -class Batch: - """ - Batch class used to add multiple objects or object references at once into weaviate. - To add data to the Batch use these methods of this class: `add_data_object` and - `add_reference`. This object also stores 2 recommended batch size variables, one for objects - and one for references. The recommended batch size is updated with every batch creation, and - is the number of data objects/references that can be sent/processed by the Weaviate server in - `creation_time` interval (see `configure` or `__call__` method on how to set this value, by - default it is set to 10). The initial value is None/batch_size and is updated with every batch - create methods. The values can be accessed with the getters: `recommended_num_objects` and - `recommended_num_references`. - NOTE: If the UUID of one of the objects already exists then the existing object will be - replaced by the new object. - - This class can be used in 3 ways: - - Case I: - Everything should be done by the user, i.e. the user should add the - objects/object-references and create them whenever the user wants. To create one of the - data type use these methods of this class: `create_objects`, `create_references` and - `flush`. This case has the Batch instance's batch_size set to None (see docs for the - `configure` or `__call__` method). Can be used in a context manager, see below. - - Case II: - Batch auto-creates when full. This can be achieved by setting the Batch instance's - batch_size set to a positive integer (see docs for the `configure` or `__call__` method). - The batch_size in this case corresponds to the sum of added objects and references. - This case does not require the user to create the batch/s, but it can be done. Also to - create non-full batches (last batch/es) that do not meet the requirement to be auto-created - use the `flush` method. Can be used in a context manager, see below. - - Case III: - Similar to Case II but uses dynamic batching, i.e. auto-creates either objects or - references when one of them reached the `recommended_num_objects` or - `recommended_num_references` respectively. See docs for the `configure` or `__call__` - method for how to enable it. - - Context-manager support: Can be use with the `with` statement. When it exists the context- - manager it calls the `flush` method for you. Can be combined with `configure`/`__call__` - method, in order to set it to the desired Case. - - Examples - -------- - Here are examples for each CASE described above. Here `client` is an instance of the - `weaviate.Client`. - - >>> object_1 = '154cbccd-89f4-4b29-9c1b-001a3339d89d' - >>> object_2 = '154cbccd-89f4-4b29-9c1b-001a3339d89c' - >>> object_3 = '254cbccd-89f4-4b29-9c1b-001a3339d89a' - >>> object_4 = '254cbccd-89f4-4b29-9c1b-001a3339d89b' - - For Case I: - - >>> client.batch.shape - (0, 0) - >>> client.batch.add_data_object({}, 'MyClass') - >>> client.batch.add_data_object({}, 'MyClass') - >>> client.batch.add_reference(object_1, 'MyClass', 'myProp', object_2) - >>> client.batch.shape - (2, 1) - >>> client.batch.create_objects() - >>> client.batch.shape - (0, 1) - >>> client.batch.create_references() - >>> client.batch.shape - (0, 0) - >>> client.batch.add_data_object({}, 'MyClass') - >>> client.batch.add_reference(object_3, 'MyClass', 'myProp', object_4) - >>> client.batch.shape - (1, 1) - >>> client.batch.flush() - >>> client.batch.shape - (0, 0) - - Or with a context manager: - - >>> with client.batch as batch: - ... batch.add_data_object({}, 'MyClass') - ... batch.add_reference(object_3, 'MyClass', 'myProp', object_4) - >>> # flush was called - >>> client.batch.shape - (0, 0) - - For Case II: - - >>> client.batch(batch_size=3) - >>> client.batch.shape - (0, 0) - >>> client.batch.add_data_object({}, 'MyClass') - >>> client.batch.add_reference(object_1, 'MyClass', 'myProp', object_2) - >>> client.batch.shape - (1, 1) - >>> client.batch.add_data_object({}, 'MyClass') # sum of data_objects and references reached - >>> client.batch.shape - (0, 0) - - Or with a context manager and `__call__` method: - - >>> with client.batch(batch_size=3) as batch: - ... batch.add_data_object({}, 'MyClass') - ... batch.add_reference(object_3, 'MyClass', 'myProp', object_4) - ... batch.add_data_object({}, 'MyClass') - ... batch.add_reference(object_1, 'MyClass', 'myProp', object_4) - >>> # flush was called - >>> client.batch.shape - (0, 0) - - Or with a context manager and setter: - - >>> client.batch.batch_size = 3 - >>> with client.batch as batch: - ... batch.add_data_object({}, 'MyClass') - ... batch.add_reference(object_3, 'MyClass', 'myProp', object_4) - ... batch.add_data_object({}, 'MyClass') - ... batch.add_reference(object_1, 'MyClass', 'myProp', object_4) - >>> # flush was called - >>> client.batch.shape - (0, 0) - - For Case III: - Same as Case II but you need to configure or enable 'dynamic' batching. - - >>> client.batch.configure(batch_size=3, dynamic=True) # 'batch_size' must be an valid int - - Or: - - >>> client.batch.batch_size = 3 - >>> client.batch.dynamic = True - - See the documentation of the `configure`( or `__call__`) and the setters for more information - on how/why and what you need to configure/set in order to use a particular Case. - """ - - def __init__(self, connection: Connection): - """ - Initialize a Batch class instance. This defaults to manual creation configuration. - See docs for the `configure` or `__call__` method for different types of configurations. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running weaviate instance. - """ - - # set all protected attributes - self._shutdown_background_event: Optional[threading.Event] = None - self._new_dynamic_batching = True - self._connection = connection - self._objects_batch = ObjectsBatchRequest() - self._reference_batch = ReferenceBatchRequest() - # do not keep too many past values, so it is a better estimation of the throughput is computed for 1 second - self._objects_throughput_frame: Deque[float] = deque(maxlen=5) - self._references_throughput_frame: Deque[float] = deque(maxlen=5) - self._future_pool: List[Future[Tuple[Union[Response, None], int]]] = [] - self._reference_batch_queue: List[ReferenceBatchRequest] = [] - self._callback_lock = threading.Lock() - - # user configurable, need to be public should implement a setter/getter - self._callback: Optional[Callable[[BatchResponse], None]] = check_batch_result - self._weaviate_error_retry: Optional[WeaviateErrorRetryConf] = None - self._batch_size: Optional[int] = 50 - self._creation_time = cast(Real, min(self._connection.timeout_config[1] / 10, 2)) - self._timeout_retries = 3 - self._connection_error_retries = 3 - self._batching_type: Optional[str] = "dynamic" - self._recommended_num_objects = self._batch_size - self._recommended_num_references = self._batch_size - - self.__imported_shards: Set[Shard] = set() - - self._num_workers = 1 - self._consistency_level: Optional[ConsistencyLevel] = None - # thread pool executor - self._executor: Optional[BatchExecutor] = None - - def __call__(self, **kwargs: Any) -> "Batch": - """ - WARNING: This method will be deprecated in the next major release. Use `configure` instead. - - Parameters - ---------- - batch_size : Optional[int], optional - The batch size to be use. This value sets the Batch functionality, if `batch_size` is - None then no auto-creation is done (`callback` and `dynamic` are ignored). If it is a - positive number auto-creation is enabled and the value represents: 1) in case `dynamic` - is False -> the number of data in the Batch (sum of objects and references) when to - auto-create; 2) in case `dynamic` is True -> the initial value for both - `recommended_num_objects` and `recommended_num_references`, by default None - creation_time : Real, optional - How long it should take to create a Batch. Used ONLY for computing dynamic batch sizes. By default None - timeout_retries : int, optional - Number of retries to create a Batch that failed with ReadTimeout, by default 3 - weaviate_error_retries: Optional[WeaviateErrorRetryConf], by default None - How often batch-elements with an error originating from weaviate (for example transformer timeouts) should - be retried and which errors should be ignored and/or included. See documentation for WeaviateErrorRetryConf - for details. - connection_error_retries : int, optional - Number of retries to create a Batch that failed with ConnectionError, by default 3 - callback : Optional[Callable[[dict], None]], optional - A callback function on the results of each (objects and references) batch types. - By default `weaviate.util.check_batch_result`. - dynamic : bool, optional - Whether to use dynamic batching or not, by default False - num_workers : int, optional - The maximal number of concurrent threads to run batch import. Only used for non-MANUAL - batching. i.e. is used only with AUTO or DYNAMIC batching. - By default, the multi-threading is disabled. Use with care to not overload your weaviate instance. - - Returns - ------- - Batch - Updated self. - - Raises - ------ - TypeError - If one of the arguments is of a wrong type. - ValueError - If the value of one of the arguments is wrong. - """ - _Warnings.use_of_client_batch_will_be_removed_in_next_major_release() - return self.configure(**kwargs) - - def configure( - self, - batch_size: Optional[int] = 50, - creation_time: Optional[Real] = None, - timeout_retries: int = 3, - connection_error_retries: int = 3, - weaviate_error_retries: Optional[WeaviateErrorRetryConf] = None, - callback: Optional[Callable[[List[dict]], None]] = check_batch_result, - dynamic: bool = True, - num_workers: int = 1, - consistency_level: Optional[ConsistencyLevel] = None, - ) -> "Batch": - """ - Warnings - -------- - - It has default values and if you want to change only one use a setter instead or - provide all the configurations, both the old and new ones. - - This method will return `None` in the next major release. If you are using the returned - `Batch` object then you should start using the `client.batch` object instead. - - Parameters - ---------- - batch_size : Optional[int], optional - The batch size to be use. This value sets the Batch functionality, if `batch_size` is - None then no auto-creation is done (`callback` and `dynamic` are ignored). If it is a - positive number auto-creation is enabled and the value represents: 1) in case `dynamic` - is False -> the number of data in the Batch (sum of objects and references) when to - auto-create; 2) in case `dynamic` is True -> the initial value for both - `recommended_num_objects` and `recommended_num_references`, by default 50 - creation_time : Real, optional - How long it should take to create a Batch. Used ONLY for computing dynamic batch sizes. By default None - timeout_retries : int, optional - Number of retries to create a Batch that failed with ReadTimeout, by default 3 - connection_error_retries : int, optional - Number of retries to create a Batch that failed with ConnectionError, by default 3 - weaviate_error_retries: WeaviateErrorRetryConf, Optional - How often batch-elements with an error originating from weaviate (for example transformer timeouts) should - be retried and which errors should be ignored and/or included. See documentation for WeaviateErrorRetryConf - for details. - callback : Optional[Callable[[dict], None]], optional - A callback function on the results of each (objects and references) batch types. - By default `weaviate.util.check_batch_result` - dynamic : bool, optional - Whether to use dynamic batching or not, by default True - num_workers : int, optional - The maximal number of concurrent threads to run batch import. Only used for non-MANUAL - batching. i.e. is used only with AUTO or DYNAMIC batching. - By default, the multi-threading is disabled. Use with care to not overload your weaviate instance. - - Returns - ------- - Batch - Updated self. - - Raises - ------ - TypeError - If one of the arguments is of a wrong type. - ValueError - If the value of one of the arguments is wrong. - """ - self.consistency_level = consistency_level - if creation_time is not None: - _check_positive_num(creation_time, "creation_time", Real) - self._creation_time = creation_time - else: - self._creation_time = cast(Real, min(self._connection.timeout_config[1] / 10, 2)) - - _check_non_negative(timeout_retries, "timeout_retries", int) - _check_non_negative(connection_error_retries, "connection_error_retries", int) - - self._callback = callback - - self._timeout_retries = timeout_retries - self._connection_error_retries = connection_error_retries - self._weaviate_error_retry = weaviate_error_retries - # set Batch to manual import - if batch_size is None and not dynamic: - self._batch_size = None - self._batching_type = None - return self - - _check_positive_num(batch_size, "batch_size", int) - _check_positive_num(num_workers, "num_workers", int) - _check_bool(dynamic, "dynamic") - - if self._num_workers != num_workers: - self.flush() - self.shutdown() - self._num_workers = num_workers - self.start() - - self._batch_size = batch_size - - if dynamic is False: # set Batch to auto-commit with fixed batch_size - self._batching_type = "fixed" - else: # else set to 'dynamic' - self._batching_type = "dynamic" - self._recommended_num_objects = 50 if batch_size is None else batch_size - self._recommended_num_references = 50 if batch_size is None else batch_size - if self._shutdown_background_event is None: - self._update_recommended_batch_size() - - self._auto_create() - return self - - def _update_recommended_batch_size(self) -> None: - """Create a background thread that periodically checks how congested the batch queue is.""" - self._shutdown_background_event = threading.Event() - - def periodic_check() -> None: - cluster = Cluster(self._connection) - while ( - self._shutdown_background_event is not None - and not self._shutdown_background_event.is_set() - ): - try: - status = cluster.get_nodes_status() - if "stats" not in status[0] or "ratePerSecond" not in status[0]["stats"]: - self._new_dynamic_batching = False - return - rate = status[0]["batchStats"]["ratePerSecond"] - rate_per_worker = rate / self._num_workers - batch_length = status[0]["batchStats"]["queueLength"] - - if batch_length == 0: # scale up if queue is empty - self._recommended_num_objects = self._recommended_num_objects + min( - self._recommended_num_objects * 2, 25 - ) - else: - ratio = batch_length / rate - if ( - 2.1 > ratio > 1.9 - ): # ideal, send exactly as many objects as weaviate can process - self._recommended_num_objects = rate_per_worker # type: ignore - elif ratio <= 1.9: # we can send more - self._recommended_num_objects = min( - self._recommended_num_objects * 1.5, rate_per_worker * 2 / ratio # type: ignore - ) - elif ratio < 10: # too high, scale down - self._recommended_num_objects = rate_per_worker * 2 / ratio # type: ignore - else: # way too high, stop sending new batches - self._recommended_num_objects = 0 - - refresh_time: float = 2 - except (RequestsHTTPError, ReadTimeout): - refresh_time = 0.1 - - time.sleep(refresh_time) - self._recommended_num_objects = 10 # in case some batch needs to be send afterwards - self._shutdown_background_event = None - - demon = threading.Thread( - target=periodic_check, - daemon=True, - name="batchSizeRefresh", - ) - demon.start() - - def add_data_object( - self, - data_object: dict, - class_name: str, - uuid: Optional[UUID] = None, - vector: Optional[Sequence] = None, - tenant: Optional[str] = None, - ) -> str: - """ - Add one object to this batch. - NOTE: If the UUID of one of the objects already exists then the existing object will be - replaced by the new object. - - Parameters - ---------- - data_object : dict - Object to be added as a dict datatype. - class_name : str - The name of the class this object belongs to. - uuid : Optional[UUID], optional - The UUID of the object as an uuid.UUID object or str. It can be a Weaviate beacon or Weaviate href. - If it is None an UUIDv4 will generated, by default None - vector: Sequence or None, optional - The embedding of the object that should be validated. - Can be used when: - - a class does not have a vectorization module. - - The given vector was generated using the _identical_ vectorization module that is configured for the - class. In this case this vector takes precedence. - - Supported types are `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, - by default None. - - Returns - ------- - str - The UUID of the added object. If one was not provided a UUIDv4 will be generated. - - Raises - ------ - TypeError - If an argument passed is not of an appropriate type. - ValueError - If 'uuid' is not of a proper form. - """ - uuid = self._objects_batch.add( - class_name=_capitalize_first_letter(class_name), - data_object=data_object, - uuid=uuid, - vector=vector, - tenant=tenant, - ) - - self.__imported_shards.add(Shard(class_name, tenant)) - - if self._batching_type: - self._auto_create() - - return uuid - - def add_reference( - self, - from_object_uuid: UUID, - from_object_class_name: str, - from_property_name: str, - to_object_uuid: UUID, - to_object_class_name: Optional[str] = None, - tenant: Optional[str] = None, - ) -> None: - """ - Add one reference to this batch. - - Parameters - ---------- - from_object_uuid : UUID - The UUID of the object, as an uuid.UUID object or str, that should reference another object. - It can be a Weaviate beacon or Weaviate href. - from_object_class_name : str - The name of the class that should reference another object. - from_property_name : str - The name of the property that contains the reference. - to_object_uuid : UUID - The UUID of the object, as an uuid.UUID object or str, that is actually referenced. - It can be a Weaviate beacon or Weaviate href. - to_object_class_name : Optional[str], optional - The referenced object class name to which to add the reference (with UUID - `to_object_uuid`), it is included in Weaviate 1.14.0, where all objects are namespaced - by class name. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - tenant: str, optional - Name of the tenant. - - Raises - ------ - TypeError - If arguments are not of type str. - ValueError - If 'uuid' is not valid or cannot be extracted. - """ - - is_server_version_14 = self._connection.server_version >= "1.14" - - if to_object_class_name is None and is_server_version_14: - warnings.warn( - message=BATCH_REF_DEPRECATION_NEW_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if to_object_class_name is not None: - if not is_server_version_14: - warnings.warn( - message=BATCH_REF_DEPRECATION_OLD_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - to_object_class_name = None - if is_server_version_14: - if not isinstance(to_object_class_name, str): - raise TypeError( - "'to_object_class_name' must be of type str or None. " - f"Given type: {type(to_object_class_name)}" - ) - to_object_class_name = _capitalize_first_letter(to_object_class_name) - - self._reference_batch.add( - from_object_class_name=_capitalize_first_letter(from_object_class_name), - from_object_uuid=from_object_uuid, - from_property_name=from_property_name, - to_object_uuid=to_object_uuid, - to_object_class_name=to_object_class_name, - tenant=tenant, - ) - - if self._batching_type: - self._auto_create() - - def _create_data( - self, - data_type: str, - batch_request: BatchRequest, - ) -> Response: - """ - Create data in batches, either Objects or References. This does NOT guarantee - that each batch item (only Objects) is added/created. This can lead to a successful - batch creation but unsuccessful per batch item creation. See the Examples below. - - Parameters - ---------- - data_type : str - The data type of the BatchRequest, used to save time for not checking the type of the - BatchRequest. - batch_request : weaviate.batch.BatchRequest - Contains all the data objects that should be added in one batch. - Note: Should be a sub-class of BatchRequest since BatchRequest - is just an abstract class, e.g. ObjectsBatchRequest, ReferenceBatchRequest - - Returns - ------- - requests.Response - The requests response. - - Raises - ------ - requests.ReadTimeout - If the request time-outed. - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - params: Dict[str, str] = {} - if self._consistency_level is not None: - params["consistency_level"] = self._consistency_level.value - - try: - timeout_count = connection_count = batch_error_count = 0 - while True: - try: - response = self._connection.post( - path="/batch/" + data_type, - weaviate_object=batch_request.get_request_body(), - params=params, - ) - except ReadTimeout as error: - _batch_create_error_handler( - retry=timeout_count, - max_retries=self._timeout_retries, - error=error, - ) - timeout_count += 1 - batch_request = self._batch_retry_after_timeout(data_type, batch_request) - # All elements have been added successfully. The timeout occurred while receiving the answer. - if len(batch_request) == 0: - response = Response() - response.status_code = 200 - response.elapsed = datetime.timedelta( - self._connection.timeout_config[1] + 5 - ) - break - - except RequestsConnectionError as error: - _batch_create_error_handler( - retry=connection_count, - max_retries=self._connection_error_retries, - error=error, - ) - connection_count += 1 - else: - response_json = _decode_json_response_list(response, "batch response") - assert response_json is not None - if ( - self._weaviate_error_retry is not None - and batch_error_count < self._weaviate_error_retry.number_retries - ): - batch_to_retry, response_json_successful = self._retry_on_error( - response_json, data_type - ) - if len(batch_to_retry) > 0: - self._run_callback(response_json_successful) - - batch_error_count += 1 - batch_request = batch_to_retry - continue # run the request again, but only with objects that had errors - - self._run_callback(response_json) - break - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Batch was not added to weaviate.") from conn_err - except ReadTimeout: - message = ( - f"The '{data_type}' creation was cancelled because it took " - f"longer than the configured timeout of {self._connection.timeout_config[1]}s. " - f"Try reducing the batch size (currently {len(batch_request)}) to a lower value. " - "Aim to on average complete batch request within less than 10s" - ) - raise ReadTimeout(message) from None - if response.status_code == 200: - return response - raise UnexpectedStatusCodeException(f"Create {data_type} in batch", response) - - def _run_callback(self, response: BatchResponse) -> None: - if self._callback is None: - return - # We don't know if user-supplied functions are thread-safe - with self._callback_lock: - self._callback(response) - - def _batch_retry_after_timeout( - self, data_type: str, batch_request: BatchRequest - ) -> BatchRequest: - """ - Readds items (objects or references) that were not added due to a timeout. - - Parameters - ---------- - data_type : str - The Batch Request type, can be either 'objects' or 'references'. - batch_request : BatchRequest - The Batch Request that TimeOuted. - - Returns - ------- - BatchRequest - New Batch Request with objects that were not added or not updated. - """ - - if data_type == "objects": - assert isinstance(batch_request, ObjectsBatchRequest) - return self._readd_objects_after_timeout(batch_request) - else: - assert isinstance(batch_request, ReferenceBatchRequest) - return self._readd_references_after_timeout(batch_request) - - def _readd_objects_after_timeout( - self, batch_request: ObjectsBatchRequest - ) -> ObjectsBatchRequest: - """ - Read all objects that were not created or updated because of a TimeOut error. - - Parameters - ---------- - batch_request : ObjectsBatchRequest - The ObjectsBatchRequest from which to check if items where created or updated. - - Returns - ------- - ObjectsBatchRequest - New ObjectsBatchRequest with only the objects that were not created or updated. - """ - - new_batch = ObjectsBatchRequest() - for obj in batch_request.get_request_body()["objects"]: - class_name = obj["class"] - tenant = obj.get("tenant", None) - uuid = obj["id"] - params = {"tenant": tenant} if tenant is not None else None - - response_head = self._connection.head( - path="/objects/" + class_name + "/" + uuid, - params=params, - ) - - if response_head.status_code == 404: - new_batch.add( - class_name=_capitalize_first_letter(class_name), - data_object=obj["properties"], - uuid=uuid, - vector=obj.get("vector", None), - ) - continue - - # object might already exist and needs to be overwritten in case of an update - response = self._connection.get( - path="/objects/" + class_name + "/" + uuid, - params=params, - ) - - obj_weav = _decode_json_response_dict(response, "Re-add objects") - assert obj_weav is not None - if obj_weav["properties"] != obj["properties"] or obj.get( - "vector", None - ) != obj_weav.get("vector", None): - new_batch.add( - class_name=_capitalize_first_letter(class_name), - data_object=obj["properties"], - uuid=uuid, - vector=obj.get("vector", None), - tenant=tenant, - ) - return new_batch - - def _readd_references_after_timeout( - self, batch_request: ReferenceBatchRequest - ) -> ReferenceBatchRequest: - """ - Read all objects that were not created or updated because of a TimeOut error. - - Parameters - ---------- - batch_request : ReferenceBatchRequest - The ReferenceBatchRequest from which to check if items where created or updated. - - Returns - ------- - ReferenceBatchRequest - New ReferenceBatchRequest with only the references that were not created or updated. - """ - - new_batch = ReferenceBatchRequest() - for ref in batch_request.get_request_body(): - new_batch.add( - from_object_class_name=ref["from_object_class_name"], - from_object_uuid=ref["from_object_uuid"], - from_property_name=ref["from_property_name"], - to_object_uuid=ref["to_object_uuid"], - to_object_class_name=ref.get("to_object_class_name", None), - ) - return new_batch - - def create_objects(self) -> list: - """ - Creates multiple Objects at once in Weaviate. This does not guarantee that each batch item - is added/created to the Weaviate server. This can lead to a successful batch creation but - unsuccessful per batch item creation. See the example bellow. - NOTE: If the UUID of one of the objects already exists then the existing object will be - replaced by the new object. - - Examples - -------- - Here `client` is an instance of the `weaviate.Client`. - - Add objects to the object batch. - - >>> client.batch.add_data_object({}, 'NonExistingClass') - >>> client.batch.add_data_object({}, 'ExistingClass') - - Note that 'NonExistingClass' is not present in the client's schema and 'ExistingObject' - is present and has no proprieties. 'client.batch.add_data_object' does not raise an - exception because the objects added meet the required criteria (See the documentation of - the 'weaviate.Batch.add_data_object' method for more information). - - >>> result = client.batch.create_objects(batch) - - Successful batch creation even if one data object is inconsistent with the client's schema. - We can find out more about what objects were successfully created by analyzing the 'result' - variable. - - >>> import json - >>> print(json.dumps(result, indent=4)) - [ - { - "class": "NonExistingClass", - "creationTimeUnix": 1614852753747, - "id": "154cbccd-89f4-4b29-9c1b-001a3339d89a", - "properties": {}, - "deprecations": null, - "result": { - "errors": { - "error": [ - { - "message": "class 'NonExistingClass' not present in schema, - class NonExistingClass not present" - } - ] - } - } - }, - { - "class": "ExistingClass", - "creationTimeUnix": 1614852753746, - "id": "b7b1cfbe-20da-496c-b932-008d35805f26", - "properties": {}, - "vector": [ - -0.05244319, - ... - 0.076136276 - ], - "deprecations": null, - "result": {} - } - ] - - - As it can be noticed the first object from the batch was not added/created, but the batch - was successfully created. The batch creation can be successful even if all the objects were - NOT created. Check the status of the batch objects to find which object and why creation - failed. Alternatively use 'client.data_object.create' for Object creation that throw an - error if data item is inconsistent or creation/addition failed. - - To check the results of batch creation when using the auto-creation Batch, use a 'callback' - (see the docs `configure` or `__call__` method for more information). - - Returns - ------- - list - A list with the status of every object that was created. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - - if len(self._objects_batch) != 0: - _Warnings.manual_batching() - - response = self._create_data( - data_type="objects", - batch_request=self._objects_batch, - ) - self._objects_batch = ObjectsBatchRequest() - - self._objects_throughput_frame.append( - len(self._objects_batch) / response.elapsed.total_seconds() - ) - obj_per_second = sum(self._objects_throughput_frame) / len( - self._objects_throughput_frame - ) - - self._recommended_num_objects = max( - round(obj_per_second * float(self._creation_time)), 1 - ) - - res = _decode_json_response_list(response, "batch add objects") - assert res is not None - return res - return [] - - def create_references(self) -> list: - """ - Creates multiple References at once in Weaviate. - Adding References in batch is faster but it ignores validations like class name - and property name, resulting in a SUCCESSFUL reference creation of a nonexistent object - types and/or a nonexistent properties. If the consistency of the References is wanted - use 'client.data_object.reference.add' to have additional validation against the - weaviate schema. See Examples below. - - Examples - -------- - Here `client` is an instance of the `weaviate.Client`. - - Object that does not exist in weaviate. - - >>> object_1 = '154cbccd-89f4-4b29-9c1b-001a3339d89d' - - Objects that exist in weaviate. - - >>> object_2 = '154cbccd-89f4-4b29-9c1b-001a3339d89c' - >>> object_3 = '254cbccd-89f4-4b29-9c1b-001a3339d89a' - >>> object_4 = '254cbccd-89f4-4b29-9c1b-001a3339d89b' - - >>> client.batch.add_reference(object_1, 'NonExistingClass', 'existsWith', object_2) - >>> client.batch.add_reference(object_3, 'ExistingClass', 'existsWith', object_4) - - Both references were added to the batch request without error because they meet the - required criteria (See the documentation of the 'weaviate.Batch.add_reference' method - for more information). - - >>> result = client.batch.create_references() - - As it can be noticed the reference batch creation is successful (no error thrown). Now we - can inspect the 'result'. - - >>> import json - >>> print(json.dumps(result, indent=4)) - [ - { - "from": "weaviate://localhost/NonExistingClass/ - 154cbccd-89f4-4b29-9c1b-001a3339d89a/existsWith", - "to": "weaviate://localhost/154cbccd-89f4-4b29-9c1b-001a3339d89b", - "result": { - "status": "SUCCESS" - } - }, - { - "from": "weaviate://localhost/ExistingClass/ - 254cbccd-89f4-4b29-9c1b-001a3339d89a/existsWith", - "to": "weaviate://localhost/254cbccd-89f4-4b29-9c1b-001a3339d89b", - "result": { - "status": "SUCCESS" - } - } - ] - - Both references were added successfully but one of them is corrupted (links two objects - of nonexisting class and one of the objects is not yet created). To make use of the - validation, crete each references individually (see the client.data_object.reference.add - method). - - Returns - ------- - list - A list with the status of every reference added. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - - if len(self._reference_batch) != 0: - _Warnings.manual_batching() - - response = self._create_data( - data_type="references", - batch_request=self._reference_batch, - ) - self._reference_batch = ReferenceBatchRequest() - - self._references_throughput_frame.append( - len(self._reference_batch) / response.elapsed.total_seconds() - ) - ref_per_sec = sum(self._references_throughput_frame) / len( - self._references_throughput_frame - ) - - self._recommended_num_references = round(ref_per_sec * float(self._creation_time)) - - res = _decode_json_response_list(response, "Create references") - assert res is not None - return res - return [] - - def _flush_in_thread( - self, - data_type: str, - batch_request: BatchRequest, - ) -> Tuple[Optional[Response], int]: - """ - Flush BatchRequest in current thread/process. - - Parameters - ---------- - data_type : str - The data type of the BatchRequest, used to save time for not checking the type of the - BatchRequest. - batch_request : weaviate.batch.BatchRequest - Contains all the data objects that should be added in one batch. - Note: Should be a sub-class of BatchRequest since BatchRequest - is just an abstract class, e.g. ObjectsBatchRequest, ReferenceBatchRequest - - Returns - ------- - Tuple[requests.Response, int] - The request response and number of items sent with the BatchRequest as tuple. - """ - - if len(batch_request) != 0: - response = self._create_data( - data_type=data_type, - batch_request=batch_request, - ) - return response, len(batch_request) - return None, 0 - - def _send_batch_requests(self, force_wait: bool) -> None: - """ - Send BatchRequest in a separate thread/process. This methods submits a task to create only - the ObjectsBatchRequests to the BatchExecutor and adds the ReferencesBatchRequests to a - queue, then it carries on in the main thread until `num_workers` tasks have been submitted. - When we have reached number of tasks to be equal to `num_workers` it waits for all the - tasks to finish and handles the responses. After all ObjectsBatchRequests have been handled - it created separate tasks for each ReferencesBatchRequests, then it handles their responses - as well. This mechanism of creating References after Objects is constructed in this manner - to eliminate potential error when creating references from a object that does not yet - exists (object that is part of another task). - - Parameters - ---------- - force_wait : bool - Whether to wait on all created tasks even if we do not have `num_workers` tasks created - """ - if self._executor is None: - self.start() - elif self._executor.is_shutdown(): - warnings.warn( - message=BATCH_EXECUTOR_SHUTDOWN_W, - category=RuntimeWarning, - stacklevel=1, - ) - self.start() - - assert self._executor is not None - future = self._executor.submit( - self._flush_in_thread, - data_type="objects", - batch_request=self._objects_batch, - ) - - self._future_pool.append(future) - if len(self._reference_batch) > 0: - self._reference_batch_queue.append(self._reference_batch) - - self._objects_batch = ObjectsBatchRequest() - self._reference_batch = ReferenceBatchRequest() - - if not force_wait and self._num_workers > 1 and len(self._future_pool) < self._num_workers: - return - timeout_occurred = False - for done_future in as_completed(self._future_pool): - response_objects, nr_objects = done_future.result() - - # handle objects response - if response_objects is not None: - self._objects_throughput_frame.append( - nr_objects / response_objects.elapsed.total_seconds() - ) - - else: - timeout_occurred = True - - if timeout_occurred and self._recommended_num_objects is not None: - self._recommended_num_objects = max(self._recommended_num_objects // 2, 1) - elif ( - len(self._objects_throughput_frame) != 0 - and self._recommended_num_objects is not None - and not self._new_dynamic_batching - ): - obj_per_second = ( - sum(self._objects_throughput_frame) / len(self._objects_throughput_frame) * 0.75 - ) - self._recommended_num_objects = max( - min( - round(obj_per_second * float(self._creation_time)), - self._recommended_num_objects + 250, - ), - 1, - ) - - # Create references after all the objects have been created - reference_future_pool = [] - for reference_batch in self._reference_batch_queue: - future = self._executor.submit( - self._flush_in_thread, - data_type="references", - batch_request=reference_batch, - ) - reference_future_pool.append(future) - - timeout_occurred = False - for done_future in as_completed(reference_future_pool): - response_references, nr_references = done_future.result() - - # handle references response - if response_references is not None: - self._references_throughput_frame.append( - nr_references / response_references.elapsed.total_seconds() - ) - else: - timeout_occurred = True - - if timeout_occurred and self._recommended_num_references is not None: - self._recommended_num_references = max(self._recommended_num_references // 2, 1) - elif ( - len(self._references_throughput_frame) != 0 - and self._recommended_num_references is not None - ): - ref_per_sec = sum(self._references_throughput_frame) / len( - self._references_throughput_frame - ) - self._recommended_num_references = min( - round(ref_per_sec * float(self._creation_time)), - self._recommended_num_references * 2, - ) - - self._future_pool = [] - self._reference_batch_queue = [] - return - - def _auto_create(self) -> None: - """ - Auto create both objects and references in the batch. This protected method works with a - fixed batch size and with dynamic batching. For a 'fixed' batching type it auto-creates - when the sum of both objects and references equals batch_size. For dynamic batching it - creates both batch requests when only one is full. - """ - - # greater or equal in case the self._batch_size is changed manually - if self._batching_type == "fixed": - assert self._batch_size is not None - if sum(self.shape) >= self._batch_size: - self._send_batch_requests(force_wait=False) - return - elif self._batching_type == "dynamic": - if ( - self.num_objects() >= self._recommended_num_objects - or self.num_references() >= self._recommended_num_references - ): - while self._recommended_num_objects == 0: - time.sleep(1) # block if weaviate is overloaded - - self._send_batch_requests(force_wait=False) - return - # just in case - raise ValueError(f'Unsupported batching type "{self._batching_type}"') - - def flush(self) -> None: - """ - Flush both objects and references to the Weaviate server and call the callback function - if one is provided. (See the docs for `configure` or `__call__` for how to set one.) - """ - self._send_batch_requests(force_wait=True) - - def delete_objects( - self, - class_name: str, - where: dict, - output: str = "minimal", - dry_run: bool = False, - tenant: Optional[str] = None, - ) -> dict: - """ - Delete objects that match the 'match' in batch. - - Parameters - ---------- - class_name : str - The class name for which to delete objects. - where : dict - The content of the `where` filter used to match objects that should be deleted. - output : str, optional - The control of the verbosity of the output, possible values: - - "minimal" : The result only includes counts. Information about objects is omitted if - the deletes were successful. Only if an error occurred will the object be described. - - "verbose" : The result lists all affected objects with their ID and deletion status, - including both successful and unsuccessful deletes. - By default "minimal" - dry_run : bool, optional - If True, objects will not be deleted yet, but merely listed, by default False - - Examples - -------- - - If we want to delete all the data objects that contain the word 'weather' we can do it like - this: - - >>> result = client.batch.delete_objects( - ... class_name='Dataset', - ... output='verbose', - ... dry_run=False, - ... where={ - ... 'operator': 'Equal', - ... 'path': ['description'], - ... 'valueText': 'weather' - ... } - ... ) - >>> print(json.dumps(result, indent=4)) - { - "dryRun": false, - "match": { - "class": "Dataset", - "where": { - "operands": null, - "operator": "Equal", - "path": [ - "description" - ], - "valueText": "weather" - } - }, - "output": "verbose", - "results": { - "failed": 0, - "limit": 10000, - "matches": 2, - "objects": [ - { - "id": "1eb28f69-c66e-5411-bad4-4e14412b65cd", - "status": "SUCCESS" - }, - { - "id": "da217bdd-4c7c-5568-9576-ebefe17688ba", - "status": "SUCCESS" - } - ], - "successful": 2 - } - } - - Returns - ------- - dict - The result/status of the batch delete. - """ - - if not isinstance(class_name, str): - raise TypeError(f"'class_name' must be of type str. Given type: {type(class_name)}.") - if not isinstance(where, dict): - raise TypeError(f"'where' must be of type dict. Given type: {type(where)}.") - if not isinstance(output, str): - raise TypeError(f"'output' must be of type str. Given type: {type(output)}.") - if not isinstance(dry_run, bool): - raise TypeError(f"'dry_run' must be of type bool. Given type: {type(dry_run)}.") - - params: Dict[str, str] = {} - if self._consistency_level is not None: - params["consistency_level"] = self._consistency_level.value - if tenant is not None: - params["tenant"] = tenant - - payload = { - "match": { - "class": _capitalize_first_letter(class_name), - "where": _clean_delete_objects_where(where), - }, - "output": output, - "dryRun": dry_run, - } - - try: - response = self._connection.delete( - path="/batch/objects", - weaviate_object=payload, - params=params, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Batch delete was not successful.") from conn_err - res = _decode_json_response_dict(response, "Delete in batch") - assert res is not None - return res - - def num_objects(self) -> int: - """ - Get current number of objects in the batch. - - Returns - ------- - int - The number of objects in the batch. - """ - - return len(self._objects_batch) - - def num_references(self) -> int: - """ - Get current number of references in the batch. - - Returns - ------- - int - The number of references in the batch. - """ - - return len(self._reference_batch) - - def pop_object(self, index: int = -1) -> dict: - """ - Remove and return the object at index (default last). - - Parameters - ---------- - index : int, optional - The index of the object to pop, by default -1 (last item). - - Returns - ------- - dict - The popped object. - - Raises - ------- - IndexError - If batch is empty or index is out of range. - """ - - return self._objects_batch.pop(index) - - def pop_reference(self, index: int = -1) -> dict: - """ - Remove and return the reference at index (default last). - - Parameters - ---------- - index : int, optional - The index of the reference to pop, by default -1 (last item). - - Returns - ------- - dict - The popped reference. - - Raises - ------- - IndexError - If batch is empty or index is out of range. - """ - - return self._reference_batch.pop(index) - - def empty_objects(self) -> None: - """ - Remove all the objects from the batch. - """ - - self._objects_batch.empty() - - def empty_references(self) -> None: - """ - Remove all the references from the batch. - """ - - self._reference_batch.empty() - - def is_empty_objects(self) -> bool: - """ - Check if batch contains any objects. - - Returns - ------- - bool - Whether the Batch object list is empty. - """ - - return self._objects_batch.is_empty() - - def is_empty_references(self) -> bool: - """ - Check if batch contains any references. - - Returns - ------- - bool - Whether the Batch reference list is empty. - """ - - return self._reference_batch.is_empty() - - @property - def shape(self) -> Tuple[int, int]: - """ - Get current number of objects and references in the batch. - - Returns - ------- - Tuple[int, int] - The number of objects and references, respectively, in the batch as a tuple, - i.e. returns (number of objects, number of references). - """ - - return (len(self._objects_batch), len(self._reference_batch)) - - @property - def batch_size(self) -> Optional[int]: - """ - Setter and Getter for `batch_size`. - - Parameters - ---------- - value : Optional[int] - Setter ONLY: The new value for the batch_size. If NOT None it will try to auto-create - the existing data if it meets the requirements. If previous value was None then it will - be set to new value and will change the batching type to auto-create with dynamic set - to False. See the documentation for `configure` or `__call__` for more info. - If recommended_num_objects is None then it is initialized with the new value of the - batch_size (same for references). - - Returns - ------- - Optional[int] - Getter ONLY: The current value of the batch_size. It is NOT the current number of - data in the Batch. See the documentation for `configure` or `__call__` for more info. - - Raises - ------ - TypeError - Setter ONLY: If the new value is not of type int. - ValueError - Setter ONLY: If the new value has a non positive value. - """ - - return self._batch_size - - @batch_size.setter - def batch_size(self, value: Optional[int]) -> None: - if value is None: - self._batch_size = None - self._batching_type = None - return - - _check_positive_num(value, "batch_size", int) - self._batch_size = value - if self._batching_type is None: - self._batching_type = "fixed" - if self._recommended_num_objects is None: - self._recommended_num_objects = value - if self._recommended_num_references is None: - self._recommended_num_references = value - self._auto_create() - - @property - def dynamic(self) -> bool: - """ - Setter and Getter for `dynamic`. - - Parameters - ---------- - value : bool - Setter ONLY: En/dis-able the dynamic batching. If batch_size is None the value is not - set, otherwise it will set the dynamic to new value and auto-create if it meets the - requirements. - - Returns - ------- - bool - Getter ONLY: Wether the dynamic batching is enabled. - - Raises - ------ - TypeError - Setter ONLY: If the new value is not of type bool. - """ - - return self._batching_type == "dynamic" - - @dynamic.setter - def dynamic(self, value: bool) -> None: - if self._batching_type is None: - return - - _check_bool(value, "dynamic") - - if value is True: - self._batching_type = "dynamic" - else: - self._batching_type = "fixed" - self._auto_create() - - @property - def consistency_level(self) -> Union[str, None]: - return self._consistency_level.value if self._consistency_level is not None else None - - @consistency_level.setter - def consistency_level(self, x: Optional[Union[ConsistencyLevel, str]]) -> None: - self._consistency_level = ConsistencyLevel(x) if x is not None else None - - @property - def recommended_num_objects(self) -> Optional[int]: - """ - The recommended number of objects per batch. If None then it could not be computed. - - Returns - ------- - Optional[int] - The recommended number of objects per batch. If None then it could not be computed. - """ - - return self._recommended_num_objects - - @property - def recommended_num_references(self) -> Optional[int]: - """ - The recommended number of references per batch. If None then it could not be computed. - - Returns - ------- - Optional[int] - The recommended number of references per batch. If None then it could not be computed. - """ - - return self._recommended_num_references - - def start(self) -> "Batch": - """ - Start the BatchExecutor if it was closed. - - Returns - ------- - Batch - Updated self. - """ - - if self._executor is None or self._executor.is_shutdown(): - self._executor = BatchExecutor(max_workers=self._num_workers) - - if self._batching_type == "dynamic" and ( - self._shutdown_background_event is None or self._shutdown_background_event.is_set() - ): - self._update_recommended_batch_size() - - return self - - def shutdown(self) -> None: - """ - Shutdown the BatchExecutor. - """ - if not (self._executor is None or self._executor.is_shutdown()): - self._executor.shutdown() - - if self._shutdown_background_event is not None: - self._shutdown_background_event.set() - - def __enter__(self) -> "Batch": - return self.start() - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self.flush() - self.shutdown() - - def wait_for_vector_indexing( - self, shards: Optional[List[Shard]] = None, how_many_failures: int = 5 - ) -> None: - """Wait for the all the vectors of the batch imported objects to be indexed. - - Upon network error, it will retry to get the shards' status for `how_many_failures` times - with exponential backoff (2**n seconds with n=0,1,2,...,how_many_failures). - - Parameters - ---------- - shards {Optional[List[Shard]]} -- The shards to check the status of. If None it will - check the status of all the shards of the imported objects in the batch. - how_many_failures {int} -- How many times to try to get the shards' status before - raising an exception. Default 5. - """ - if shards is not None and not isinstance(shards, list): - raise TypeError(f"'shards' must be of type List[Shard]. Given type: {type(shards)}.") - if shards is not None and not isinstance(shards[0], Shard): - raise TypeError(f"'shards' must be of type List[Shard]. Given type: {type(shards)}.") - - def is_ready(how_many: int) -> bool: - try: - return all( - all(self._get_shards_readiness(shard)) - for shard in shards or self.__imported_shards - ) - except RequestsConnectionError as e: - print( - f"Error while getting class shards statuses: {e}, trying again with 2**n={2**how_many}s exponential backoff with n={how_many}" - ) - if how_many_failures == how_many: - raise e - time.sleep(2**how_many) - return is_ready(how_many + 1) - - while not is_ready(0): - print("Waiting for async indexing to finish...") - time.sleep(0.25) - - def _get_shards_readiness(self, shard: Shard) -> List[bool]: - if not isinstance(shard.class_name, str): - raise TypeError( - "'class_name' argument must be of type `str`! " - f"Given type: {type(shard.class_name)}." - ) - - path = f"/schema/{_capitalize_first_letter(shard.class_name)}/shards{'' if shard.tenant is None else f'?tenant={shard.tenant}'}" - - try: - response = self._connection.get(path=path) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Class shards' status could not be retrieved due to connection error." - ) from conn_err - - res = _decode_json_response_list(response, "Get shards' status") - assert res is not None - return [ - (cast(str, shard.get("status")) == "READY") - & (cast(int, shard.get("vectorQueueSize")) == 0) - for shard in res - ] - - @property - def creation_time(self) -> Real: - """ - Setter and Getter for `creation_time`. - - Parameters - ---------- - value : Real - Setter ONLY: Set new value to creation_time. The recommended_num_objects/references - values are updated to this new value. If the batch_size is not None it will auto-create - the batch if the requirements are met. - - Returns - ------- - Real - Getter ONLY: The `creation_time` value. - - Raises - ------ - TypeError - Setter ONLY: If the new value is not of type Real. - ValueError - Setter ONLY: If the new value has a non positive value. - """ - - return self._creation_time - - @creation_time.setter - def creation_time(self, value: Real) -> None: - _check_positive_num(value, "creation_time", Real) - if self._recommended_num_references is not None: - self._recommended_num_references = round( - self._recommended_num_references * float(value) / float(self._creation_time) - ) - if self._recommended_num_objects is not None: - self._recommended_num_objects = round( - self._recommended_num_objects * float(value) / float(self._creation_time) - ) - self._creation_time = value - if self._batching_type: - self._auto_create() - - @property - def timeout_retries(self) -> int: - """ - Setter and Getter for `timeout_retries`. - - Properties - ---------- - value : int - Setter ONLY: The new value for `timeout_retries`. - - Returns - ------- - int - Getter ONLY: The `timeout_retries` value. - - Raises - ------ - TypeError - Setter ONLY: If the new value is not of type int. - ValueError - Setter ONLY: If the new value has a non positive value. - """ - - return self._timeout_retries - - @timeout_retries.setter - def timeout_retries(self, value: int) -> None: - _check_non_negative(value, "timeout_retries", int) - self._timeout_retries = value - - @property - def connection_error_retries(self) -> int: - """ - Setter and Getter for `connection_error_retries`. - - Properties - ---------- - value : int - Setter ONLY: The new value for `connection_error_retries`. - - Returns - ------- - int - Getter ONLY: The `connection_error_retries` value. - - Raises - ------ - TypeError - Setter ONLY: If the new value is not of type int. - ValueError - Setter ONLY: If the new value has a non positive value. - """ - - return self._connection_error_retries - - @connection_error_retries.setter - def connection_error_retries(self, value: int) -> None: - _check_non_negative(value, "connection_error_retries", int) - self._connection_error_retries = value - - def _retry_on_error( - self, response: BatchResponse, data_type: str - ) -> Tuple[BatchRequestType, BatchResponse]: - if data_type == "objects": - new_batch: Union[ObjectsBatchRequest, ReferenceBatchRequest] = ObjectsBatchRequest() - else: - new_batch = ReferenceBatchRequest() - assert self._weaviate_error_retry is not None - successful_responses = new_batch.add_failed_objects_from_response( - response, - self._weaviate_error_retry.errors_to_exclude, - self._weaviate_error_retry.errors_to_include, - ) - return new_batch, successful_responses - - -N = TypeVar("N", bound=Union[int, float, Real]) - - -def _check_non_negative(value: N, arg_name: str, data_type: Type[N]) -> None: - """ - Check if the `value` of the `arg_name` is a non-negative number. - - Parameters - ---------- - value : N (int, float, Real) - The value to check. - arg_name : str - The name of the variable from the original function call. Used for error message. - data_type : Type[N] - The data type to check for. - - Raises - ------ - TypeError - If the `value` is not of type `data_type`. - ValueError - If the `value` has a negative value. - """ - - if not isinstance(value, data_type) or isinstance(value, bool): - raise TypeError(f"'{arg_name}' must be of type {data_type}.") - if value < 0: - raise ValueError(f"'{arg_name}' must be positive, i.e. greater or equal that zero (>=0).") - - -def _check_bool(value: bool, arg_name: str) -> None: - """ - Check if bool. - - Parameters - ---------- - value : bool - The value to check. - arg_name : str - The name of the variable from the original function call. Used for error message. - - Raises - ------ - TypeError - If the `value` is not of type bool. - """ - - if not isinstance(value, bool): - raise TypeError(f"'{arg_name}' must be of type bool.") - - -def _batch_create_error_handler(retry: int, max_retries: int, error: Exception) -> None: - """ - Handle errors that occur in Batch creation. This function is going to re-raise the error if - number of re-tries was reached. - Parameters - ---------- - retry : int - Current number of attempted request calls. - max_retries : int - Maximum number of attempted request calls. - error : Exception - The exception that occurred (to be re-raised if needed). - Raises - ------ - Exception - The caught exception. - """ - - if retry >= max_retries: - raise error - print( - f"[ERROR] Batch {error.__class__.__name__} Exception occurred! Retrying in " - f"{(retry + 1) * 2}s. [{retry + 1}/{max_retries}]", - file=sys.stderr, - flush=True, - ) - time.sleep((retry + 1) * 2) - - -def _clean_delete_objects_where(where: dict) -> dict: - """Converts the Python-defined where filter type into the Weaviate-defined - where filter type used in the Batch REST request endpoint. - - Parameters - ---------- - where : dict - The Python-defined where filter. - - Returns - ------- - dict - The Weaviate-defined where filter. - """ - if "path" in where: - py_value_type = _find_value_type(where) - weaviate_value_type = _convert_value_type(py_value_type) - if "operator" not in where: - raise ValueError( - "Where filter is missing required field `operator`." f" Given: {where}" - ) - if where["operator"] not in WHERE_OPERATORS: - raise ValueError( - f"Operator {where['operator']} is not allowed. " - f"Allowed operators are: {WHERE_OPERATORS}" - ) - operator = where["operator"] - if "Contains" in operator and "Array" not in weaviate_value_type: - raise ValueError( - f"Operator '{operator}' is not supported for value type '{weaviate_value_type}'. Supported value types are: {VALUE_ARRAY_TYPES}" - ) - where[weaviate_value_type] = where.pop(py_value_type) - elif "operands" in where: - where["operands"] = [_clean_delete_objects_where(operand) for operand in where["operands"]] - else: - raise ValueError( - "Where is missing required fields `path` or `operands`." f" Given: {where}" - ) - return where - - -def _convert_value_type(_type: str) -> str: - """Converts the Python-defined where filter type into the Weaviate-defined - where filter type used in the Batch REST request endpoint. - - Parameters - ---------- - _type : str - The Python-defined where filter type. - - Returns - ------- - str - The Weaviate-defined where filter type. - """ - if _type == "valueTextList": - return "valueTextArray" - elif _type == "valueStringList": - return "valueStringArray" - elif _type == "valueIntList": - return "valueIntArray" - elif _type == "valueNumberList": - return "valueNumberArray" - elif _type == "valueBooleanList": - return "valueBooleanList" - elif _type == "valueDateList": - return "valueDateArray" - else: - return _type diff --git a/weaviate/batch/requests.py b/weaviate/batch/requests.py deleted file mode 100644 index a31d8f472..000000000 --- a/weaviate/batch/requests.py +++ /dev/null @@ -1,336 +0,0 @@ -""" -BatchRequest class definitions. -""" - -import copy -from abc import ABC, abstractmethod -from typing import List, Sequence, Optional, Dict, Any, Union -from uuid import uuid4 - -from weaviate.util import get_valid_uuid, get_vector -from weaviate.types import UUID - -BatchResponse = List[Dict[str, Any]] - - -class BatchRequest(ABC): - """ - BatchRequest abstract class used as a interface for batch requests. - """ - - def __init__(self) -> None: - self._items: List[Dict[str, Any]] = [] - - def __len__(self) -> int: - return len(self._items) - - def is_empty(self) -> bool: - """ - Check if BatchRequest is empty. - - Returns - ------- - bool - Whether the BatchRequest is empty. - """ - - return len(self._items) == 0 - - def empty(self) -> None: - """ - Remove all the items from the BatchRequest. - """ - - self._items = [] - - def pop(self, index: int = -1) -> dict: - """ - Remove and return item at index (default last). - - Parameters - ---------- - index : int, optional - The index of the item to pop, by default -1 (last item). - - Returns - ------- - dict - The popped item. - - Raises - ------- - IndexError - If batch is empty or index is out of range. - """ - - return self._items.pop(index) - - @abstractmethod - def add(self, *args, **kwargs): # type: ignore - """Add objects to BatchRequest.""" - - @abstractmethod - def get_request_body(self) -> Union[List[Dict[str, Any]], Dict[str, Any]]: - """Return the request body to be digested by weaviate that contains all batch items.""" - - @abstractmethod - def add_failed_objects_from_response( - self, - response_item: BatchResponse, - errors_to_exclude: Optional[List[str]], - errors_to_include: Optional[List[str]], - ) -> BatchResponse: - """Add failed items from a weaviate response. - - Parameters - ---------- - response_item : BatchResponse - Weaviate response that contains the status for all objects. - errors_to_exclude : Optional[List[str]] - Which errors should NOT be retried. - errors_to_include : Optional[List[str]] - Which errors should be retried. - - Returns - ------ - BatchResponse: Contains responses form all successful object, eg. those that have not been added to this batch. - """ - - @staticmethod - def _skip_objects_retry( - entry: Dict[str, Any], - errors_to_exclude: Optional[List[str]], - errors_to_include: Optional[List[str]], - ) -> bool: - if ( - len(entry["result"]) == 0 - or "errors" not in entry["result"] - or "error" not in entry["result"]["errors"] - or len(entry["result"]["errors"]["error"]) == 0 - ): - return True - - # skip based on error messages - if errors_to_exclude is not None: - for err in entry["result"]["errors"]["error"]: - if any(excl in err["message"] for excl in errors_to_exclude): - return True - return False - elif errors_to_include is not None: - for err in entry["result"]["errors"]["error"]: - if any(incl in err["message"] for incl in errors_to_include): - return False - return True - return False - - -class ReferenceBatchRequest(BatchRequest): - """ - Collect Weaviate-object references to add them in one request to Weaviate. - Caution this request will miss some validations to be faster. - """ - - def add( # pyright: ignore reportIncompatibleMethodOverride - self, - from_object_class_name: str, - from_object_uuid: UUID, - from_property_name: str, - to_object_uuid: UUID, - to_object_class_name: Optional[str] = None, - tenant: Optional[str] = None, - ) -> None: - """ - Add one Weaviate-object reference to this batch. Does NOT validate the consistency of the - reference against the class schema. Checks the arguments' type and UUIDs' format. - - Parameters - ---------- - from_object_class_name : str - The name of the class that should reference another object. - from_object_uuid : str - The UUID or URL of the object that should reference another object. - from_property_name : str - The name of the property that contains the reference. - to_object_uuid : str - The UUID or URL of the object that is actually referenced. - to_object_class_name : Optional[str], optional - The referenced object class name to which to add the reference (with UUID - `to_object_uuid`), it is included in Weaviate 1.14.0, where all objects are namespaced - by class name. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - - Raises - ------ - TypeError - If arguments are not of type str. - ValueError - If 'uuid' is not valid or cannot be extracted. - """ - - if not isinstance(from_object_class_name, str): - raise TypeError("'from_object_class_name' argument must be of type str") - - if not isinstance(from_property_name, str): - raise TypeError("'from_property_name' argument must be of type str") - - if to_object_class_name is not None and not isinstance(to_object_class_name, str): - raise TypeError("'to_object_class_name' argument must be of type str") - - to_object_uuid = get_valid_uuid(to_object_uuid) - from_object_uuid = get_valid_uuid(from_object_uuid) - - if to_object_class_name is not None: - to_beacon = f"weaviate://localhost/{to_object_class_name}/{to_object_uuid}" - else: - to_beacon = f"weaviate://localhost/{to_object_uuid}" - - item = { - "from": "weaviate://localhost/" - + from_object_class_name - + "/" - + from_object_uuid - + "/" - + from_property_name, - "to": to_beacon, - } - - if tenant is not None: - item["tenant"] = tenant - - self._items.append(item) - - def get_request_body(self) -> List[Dict[str, Any]]: - """ - Get request body as a list of dictionaries, where each dictionary - is a Weaviate-object reference. - - Returns - ------- - List[dict] - A list of Weaviate-objects references as dictionaries. - """ - - return self._items - - def add_failed_objects_from_response( # pyright: ignore reportIncompatibleMethodOverride - self, - response: BatchResponse, - errors_to_exclude: Optional[List[str]], - errors_to_include: Optional[List[str]], - ) -> BatchResponse: - successful_responses = [] - - for ref in response: - if self._skip_objects_retry(ref, errors_to_exclude, errors_to_include): - successful_responses.append(ref) - continue - self._items.append({"from": ref["from"], "to": ref["to"]}) - return successful_responses - - -class ObjectsBatchRequest(BatchRequest): - """ - Collect objects for one batch request to weaviate. - Caution this batch will not be validated through weaviate. - """ - - def add( # pyright: ignore reportIncompatibleMethodOverride - self, - data_object: dict, - class_name: str, - uuid: Optional[UUID] = None, - vector: Optional[Sequence] = None, - tenant: Optional[str] = None, - ) -> str: - """ - Add one object to this batch. Does NOT validate the consistency of the object against - the client's schema. Checks the arguments' type and UUIDs' format. - - Parameters - ---------- - class_name : str - The name of the class this object belongs to. - data_object : dict - Object to be added as a dict datatype. - uuid : str or None, optional - UUID of the object as a string, by default None - vector: Sequence or None, optional - The embedding of the object that should be validated. - Can be used when: - - a class does not have a vectorization module. - - The given vector was generated using the _identical_ vectorization module that is configured for the - class. In this case this vector takes precedence. - - Supported types are `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, - by default None. - tenant: str, optional - Tenant of the object - - Returns - ------- - str - The UUID of the added object. If one was not provided a UUIDv3 will be generated. - - Raises - ------ - TypeError - If an argument passed is not of an appropriate type. - ValueError - If 'uuid' is not of a proper form. - """ - - if not isinstance(data_object, dict): - raise TypeError("Object must be of type dict") - if not isinstance(class_name, str): - raise TypeError("Class name must be of type str") - - batch_item = {"class": class_name, "properties": copy.deepcopy(data_object)} - if uuid is not None: - valid_uuid = get_valid_uuid(uuid) - else: - valid_uuid = get_valid_uuid(uuid4()) - batch_item["id"] = valid_uuid - - if vector is not None: - batch_item["vector"] = get_vector(vector) - if tenant is not None: - batch_item["tenant"] = tenant - - self._items.append(batch_item) - - return valid_uuid - - def get_request_body(self) -> Dict[str, Any]: - """ - Get the request body as it is needed for the Weaviate server. - - Returns - ------- - dict - The request body as a dict. - """ - - return {"fields": ["ALL"], "objects": self._items} - - def add_failed_objects_from_response( # pyright: ignore reportIncompatibleMethodOverride - self, - response: BatchResponse, - errors_to_exclude: Optional[List[str]], - errors_to_include: Optional[List[str]], - ) -> BatchResponse: - successful_responses = [] - - for obj in response: - if self._skip_objects_retry(obj, errors_to_exclude, errors_to_include): - successful_responses.append(obj) - continue - self.add( - data_object=obj["properties"], - class_name=obj["class"], - uuid=obj["id"], - vector=obj.get("vector", None), - tenant=obj.get("tenant", None), - ) - return successful_responses diff --git a/weaviate/classification/__init__.py b/weaviate/classification/__init__.py deleted file mode 100644 index 2286ffb11..000000000 --- a/weaviate/classification/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Module for classifying objects within Weaviate. -""" - -__all__ = ["Classification", "ConfigBuilder"] - -from .classification import Classification, ConfigBuilder diff --git a/weaviate/classification/classification.py b/weaviate/classification/classification.py deleted file mode 100644 index 435b1699e..000000000 --- a/weaviate/classification/classification.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -Classification class definition. -""" - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection -from weaviate.util import get_valid_uuid, _decode_json_response_dict -from .config_builder import ConfigBuilder - - -class Classification: - """ - Classification class used to schedule and/or check the status of - a classification process of Weaviate objects. - """ - - def __init__(self, connection: Connection): - """ - Initialize a Classification class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - """ - - self._connection = connection - - def schedule(self) -> ConfigBuilder: - """ - Schedule a Classification of the Objects within Weaviate. - - Returns - ------- - weaviate.classification.config_builder.ConfigBuilder - A ConfigBuilder that should be configured to the desired - classification task - """ - - return ConfigBuilder(self._connection, self) - - def get(self, classification_uuid: str) -> dict: - """ - Polls the current state of the given classification. - - Parameters - ---------- - classification_uuid : str - Identifier of the classification. - - Returns - ------- - dict - A dict containing the Weaviate answer. - - Raises - ------ - ValueError - If not a proper uuid. - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - - classification_uuid = get_valid_uuid(classification_uuid) - - try: - response = self._connection.get( - path="/classifications/" + classification_uuid, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Classification status could not be retrieved." - ) from conn_err - - res = _decode_json_response_dict(response, "Get classification status") - assert res is not None - return res - - def is_complete(self, classification_uuid: str) -> bool: - """ - Checks if a started classification job has completed. - - Parameters - ---------- - classification_uuid : str - Identifier of the classification. - - Returns - ------- - bool - True if given classification has finished, False otherwise. - """ - - return self._check_status(classification_uuid, "completed") - - def is_failed(self, classification_uuid: str) -> bool: - """ - Checks if a started classification job has failed. - - Parameters - ---------- - classification_uuid : str - Identifier of the classification. - - Returns - ------- - bool - True if the classification failed, False otherwise. - """ - - return self._check_status(classification_uuid, "failed") - - def is_running(self, classification_uuid: str) -> bool: - """ - Checks if a started classification job is running. - - Parameters - ---------- - classification_uuid : str - Identifier of the classification. - - Returns - ------- - bool - True if the classification is running, False otherwise. - """ - - return self._check_status(classification_uuid, "running") - - def _check_status(self, classification_uuid: str, status: str) -> bool: - """ - Check for a status of a classification. - - Parameters - ---------- - classification_uuid : str - Identifier of the classification. - status : str - Status to check for. - - Returns - ------- - bool - True if 'status' is satisfied, False otherwise. - """ - - try: - response = self.get(classification_uuid) - except RequestsConnectionError: - return False - if response["status"] == status: - return True - return False diff --git a/weaviate/classification/config_builder.py b/weaviate/classification/config_builder.py deleted file mode 100644 index b095e5267..000000000 --- a/weaviate/classification/config_builder.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -ConfigBuilder class definition. -""" - -import time -from typing import Dict, Any, cast, TYPE_CHECKING - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection -from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.util import _capitalize_first_letter, _decode_json_response_dict - -if TYPE_CHECKING: - from .classification import Classification - - -class ConfigBuilder: - """ - ConfigBuild class that is used to configure a classification process. - """ - - def __init__(self, connection: Connection, classification: "Classification"): - """ - Initialize a ConfigBuilder class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running weaviate instance. - classification : weaviate.classification.Classification - Classification object to be configured using this ConfigBuilder - instance. - """ - - self._connection = connection - self._classification = classification - self._config: Dict[str, Any] = {} - self._wait_for_completion = False - - def with_type(self, classification_type: str) -> "ConfigBuilder": - """ - Set classification type. - - Parameters - ---------- - classification_type : str - Type of the desired classification. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - self._config["type"] = classification_type - return self - - def with_k(self, k: int) -> "ConfigBuilder": - """ - Set k number for the kNN. - - Parameters - ---------- - k : int - Number of objects to use to make a classification guess. - (For kNN) - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - if "settings" not in self._config: - self._config["settings"] = {"k": k} - else: - self._config["settings"]["k"] = k - return self - - def with_class_name(self, class_name: str) -> "ConfigBuilder": - """ - What Object type to classify. - - Parameters - ---------- - class_name : str - Name of the class to be classified. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - self._config["class"] = _capitalize_first_letter(class_name) - return self - - def with_classify_properties(self, classify_properties: list) -> "ConfigBuilder": - """ - Set the classify properties. - - Parameters - ---------- - classify_properties: list - A list of properties to classify. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - self._config["classifyProperties"] = classify_properties - return self - - def with_based_on_properties(self, based_on_properties: list) -> "ConfigBuilder": - """ - Set properties to build the classification on. - - Parameters - ---------- - based_on_properties: list - A list of properties to classify on. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - self._config["basedOnProperties"] = based_on_properties - return self - - def with_source_where_filter(self, where_filter: dict) -> "ConfigBuilder": - """ - Set Source 'where' Filter. - - Parameters - ---------- - where_filter : dict - Filter to use, as a dict. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - if "filters" not in self._config: - self._config["filters"] = {} - self._config["filters"]["sourceWhere"] = where_filter - return self - - def with_training_set_where_filter(self, where_filter: dict) -> "ConfigBuilder": - """ - Set Training set 'where' Filter. - - Parameters - ---------- - where_filter : dict - Filter to use, as a dict. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - if "filters" not in self._config: - self._config["filters"] = {} - self._config["filters"]["trainingSetWhere"] = where_filter - return self - - def with_target_where_filter(self, where_filter: dict) -> "ConfigBuilder": - """ - Set Target 'where' Filter. - - Parameters - ---------- - where_filter : dict - Filter to use, as a dict. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - if "filters" not in self._config: - self._config["filters"] = {} - self._config["filters"]["targetWhere"] = where_filter - return self - - def with_wait_for_completion(self) -> "ConfigBuilder": - """ - Wait for completion. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - self._wait_for_completion = True - return self - - def with_settings(self, settings: dict) -> "ConfigBuilder": - """ - Set settings for the classification. NOTE if you are using 'kNN' - the value 'k' can be set by this method or by 'with_k'. - This method keeps previously set 'settings'. - - Parameters - ---------- - settings: dict - Additional settings to be set/overwritten. - - Returns - ------- - ConfigBuilder - Updated ConfigBuilder. - """ - - if "settings" not in self._config: - self._config["settings"] = settings - else: - for key in settings: - self._config["settings"][key] = settings[key] - return self - - def _validate_config(self) -> None: - """ - Validate the current classification configuration. - - Raises - ------ - ValueError - If a mandatory field is not set. - """ - - required_fields = ["type", "class", "basedOnProperties", "classifyProperties"] - for field in required_fields: - if field not in self._config: - raise ValueError(f"{field} is not set for this classification") - - if "settings" in self._config: - if not isinstance(self._config["settings"], dict): - raise TypeError('"settings" should be of type dict') - - if self._config["type"] == "knn": - if "k" not in self._config.get("settings", []): - raise ValueError("k is not set for this classification") - - def _start(self) -> dict: - """ - Start the classification based on the configuration set. - - Returns - ------- - dict - Classification result. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - Unexpected error. - """ - - try: - response = self._connection.post(path="/classifications", weaviate_object=self._config) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Classification may not started.") from conn_err - if response.status_code == 201: - res = _decode_json_response_dict(response, "Start classification") - assert res is not None - return res - raise UnexpectedStatusCodeException("Start classification", response) - - def do(self) -> dict: - """ - Start the classification. - - Returns - ------- - dict - Classification result. - """ - - self._validate_config() - - response = self._start() - if not self._wait_for_completion: - return response - - # wait for completion - classification_uuid = response["id"] - # print(classification_uuid) - while self._classification.is_running(classification_uuid): - time.sleep(2.0) - return cast(dict, self._classification.get(classification_uuid)) diff --git a/weaviate/client.py b/weaviate/client.py index 742bb7f32..6651f284e 100644 --- a/weaviate/client.py +++ b/weaviate/client.py @@ -3,10 +3,8 @@ """ import asyncio -from typing import Optional, Tuple, Union, Dict, Any +from typing import Optional, Tuple, Union, Any -from httpx import HTTPError as HttpxError -from requests.exceptions import ConnectionError as RequestsConnectionError from typing_extensions import deprecated from weaviate import syncify @@ -14,34 +12,17 @@ from weaviate.backup.sync import _Backup from weaviate.event_loop import _EventLoopSingleton, _EventLoop from .auth import AuthCredentials -from .backup import Backup -from .batch import Batch -from .classification import Classification from .client_base import _WeaviateClientBase -from .cluster import Cluster from .collections.batch.client import _BatchClientWrapper from .collections.cluster import _Cluster, _ClusterAsync from .collections.collections.async_ import _CollectionsAsync from .collections.collections.sync import _Collections -from .config import AdditionalConfig, Config -from .connect import Connection +from .config import AdditionalConfig from .connect.base import ( ConnectionParams, - TIMEOUT_TYPE_RETURN, ) -from .contextionary import Contextionary -from .data import DataObject -from .embedded import EmbeddedOptions, EmbeddedV3 -from .exceptions import ( - UnexpectedStatusCodeError, - WeaviateClosedClientError, - WeaviateConnectionError, -) -from .gql import Query -from .schema import Schema +from .embedded import EmbeddedOptions from .types import NUMBER -from .util import _get_valid_timeout_config, _type_request_response -from .warnings import _Warnings TIMEOUT_TYPE = Union[Tuple[NUMBER, NUMBER], NUMBER] @@ -174,253 +155,26 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None @deprecated( """ -Python client v3 `weaviate.Client(...)` connections and methods are deprecated and will - be removed by 2024-11-30. +Python client v3 `weaviate.Client(...)` has been removed. - Upgrade your code to use Python client v4 `weaviate.WeaviateClient` connections and methods. - - For Python Client v4 usage, see: https://weaviate.io/developers/weaviate/client-libraries/python - - For code migration, see: https://weaviate.io/developers/weaviate/client-libraries/python/v3_v4_migration +Upgrade your code to use Python client v4 `weaviate.WeaviateClient` connections and methods. + - For Python Client v4 usage, see: https://weaviate.io/developers/weaviate/client-libraries/python + - For code migration, see: https://weaviate.io/developers/weaviate/client-libraries/python/v3_v4_migration - If you have to use v3 code, install the v3 client and pin the v3 dependency in your requirements file: `weaviate-client>=3.26.7;<4.0.0`""" +If you have to use v3 code, install the v3 client and pin the v3 dependency in your requirements file: `weaviate-client>=3.26.7;<4.0.0`""" ) class Client: - """ - The v3 Python-native Weaviate Client class that encapsulates Weaviate functionalities in one object. - A Client instance creates all the needed objects to interact with Weaviate, and connects all of - them to the same Weaviate instance. See below the Attributes of the Client instance. For the - per attribute functionality see that attribute's documentation. - - Attributes - ---------- - backup : weaviate.backup.Backup - A Backup object instance connected to the same Weaviate instance as the Client. - batch : weaviate.batch.Batch - A Batch object instance connected to the same Weaviate instance as the Client. - classification : weaviate.classification.Classification - A Classification object instance connected to the same Weaviate instance as the Client. - cluster : weaviate.cluster.Cluster - A Cluster object instance connected to the same Weaviate instance as the Client. - contextionary : weaviate.contextionary.Contextionary - A Contextionary object instance connected to the same Weaviate instance as the Client. - data_object : weaviate.data.DataObject - A DataObject object instance connected to the same Weaviate instance as the Client. - schema : weaviate.schema.Schema - A Schema object instance connected to the same Weaviate instance as the Client. - query : weaviate.gql.Query - A Query object instance connected to the same Weaviate instance as the Client. - """ def __init__( self, - url: Optional[str] = None, - auth_client_secret: Optional[AuthCredentials] = None, - timeout_config: TIMEOUT_TYPE = (10, 60), - proxies: Union[dict, str, None] = None, - trust_env: bool = False, - additional_headers: Optional[dict] = None, - startup_period: Optional[int] = None, - embedded_options: Optional[EmbeddedOptions] = None, - additional_config: Optional[Config] = None, ) -> None: - """Initialize a Client class instance to use when interacting with Weaviate. - - Arguments: - ---------- - url : str or None, optional - The connection string to the REST API of Weaviate. - auth_client_secret : weaviate.AuthCredentials or None, optional - # fmt: off - Authenticate to weaviate by using one of the given authentication modes: - - weaviate.auth.AuthBearerToken to use existing access and (optionally, but recommended) refresh tokens - - weaviate.auth.AuthClientPassword to use username and password for oidc Resource Owner Password flow - - weaviate.auth.AuthClientCredentials to use a client secret for oidc client credential flow - - # fmt: on - timeout_config : tuple(Real, Real) or Real, optional - Set the timeout configuration for all requests to the Weaviate server. It can be a - real number or, a tuple of two real numbers: (connect timeout, read timeout). - If only one real number is passed then both connect and read timeout will be set to - that value, by default (2, 20). - proxies : dict, str or None, optional - Proxies to be used for requests. Are used by both 'requests' and 'aiohttp'. Can be - passed as a dict ('requests' format: - https://docs.python-requests.org/en/stable/user/advanced/#proxies), str (HTTP/HTTPS - protocols are going to use this proxy) or None. - Default None. - trust_env : bool, optional - Whether to read proxies from the ENV variables: (HTTP_PROXY or http_proxy, HTTPS_PROXY - or https_proxy). Default False. - NOTE: 'proxies' has priority over 'trust_env', i.e. if 'proxies' is NOT None, - 'trust_env' is ignored. - additional_headers : dict or None - Additional headers to include in the requests. - Can be used to set OpenAI/HuggingFace keys. OpenAI/HuggingFace key looks like this: - {"X-OpenAI-Api-Key": ""}, {"X-HuggingFace-Api-Key": ""} - by default None - startup_period : int or None - deprecated, has no effect - embedded_options : weaviate.embedded.EmbeddedOptions or None, optional - Create an embedded Weaviate cluster inside the client - - You can pass weaviate.embedded.EmbeddedOptions() with default values - - Take a look at the attributes of weaviate.embedded.EmbeddedOptions to see what is configurable - additional_config: weaviate.Config, optional - Additional and advanced configuration options for weaviate. - - Raises: - ------- - `TypeError` - If arguments are of a wrong data type. - """ - _Warnings.weaviate_v3_client_is_deprecated() + raise ValueError( + """ +Python client v3 `weaviate.Client(...)` has been removed. - config = Config() if additional_config is None else additional_config - url, embedded_db = self.__parse_url_and_embedded_db(url, embedded_options) +Upgrade your code to use Python client v4 `weaviate.WeaviateClient` connections and methods. + - For Python Client v4 usage, see: https://weaviate.io/developers/weaviate/client-libraries/python + - For code migration, see: https://weaviate.io/developers/weaviate/client-libraries/python/v3_v4_migration - self._connection = Connection( - url=url, - auth_client_secret=auth_client_secret, - timeout_config=_get_valid_timeout_config(timeout_config), - proxies=proxies, - trust_env=trust_env, - additional_headers=additional_headers, - startup_period=startup_period, - embedded_db=embedded_db, - grcp_port=config.grpc_port_experimental, - connection_config=config.connection_config, +If you have to use v3 code, install the v3 client and pin the v3 dependency in your requirements file: `weaviate-client>=3.26.7;<4.0.0`""" ) - self.classification = Classification(self._connection) - self.schema = Schema(self._connection) - self.contextionary = Contextionary(self._connection) - self.batch = Batch(self._connection) - self.data_object = DataObject(self._connection) - self.query = Query(self._connection) - self.backup = Backup(self._connection) - self.cluster = Cluster(self._connection) - - def __parse_url_and_embedded_db( - self, url: Optional[str], embedded_options: Optional[EmbeddedOptions] - ) -> Tuple[str, Optional[EmbeddedV3]]: - if embedded_options is None and url is None: - raise TypeError("Either url or embedded options must be present.") - elif embedded_options is not None and url is not None: - raise TypeError( - f"URL is not expected to be set when using embedded_options but URL was {url}" - ) - - if embedded_options is not None: - embedded_db = EmbeddedV3(options=embedded_options) - embedded_db.start() - return f"http://localhost:{embedded_db.options.port}", embedded_db - - if not isinstance(url, str): - raise TypeError(f"URL is expected to be string but is {type(url)}") - return url.strip("/"), None - - @property - def timeout_config(self) -> TIMEOUT_TYPE_RETURN: - """ - Getter/setter for `timeout_config`. - - Parameters - ---------- - timeout_config : tuple(float, float) or float, optional - For Getter only: Set the timeout configuration for all requests to the Weaviate server. - It can be a real number or, a tuple of two real numbers: - (connect timeout, read timeout). - If only one real number is passed then both connect and read timeout will be set to - that value. - - Returns - ------- - Tuple[float, float] - For Getter only: Requests Timeout configuration. - """ - - return self._connection.timeout_config - - @timeout_config.setter - def timeout_config(self, timeout_config: TIMEOUT_TYPE) -> None: - """ - Setter for `timeout_config`. (docstring should be only in the Getter) - """ - - self._connection.timeout_config = _get_valid_timeout_config(timeout_config) - - def __del__(self) -> None: - # in case an exception happens before definition of the client - if hasattr(self, "_connection"): - self._connection.close() - - def is_ready(self) -> bool: - """ - Ping Weaviate's ready state - - Returns: - `bool` - `True` if Weaviate is ready to accept requests, - `False` otherwise. - """ - - try: - response = self._connection.get(path="/.well-known/ready") - if response.status_code == 200: - return True - return False - except ( - HttpxError, - RequestsConnectionError, - UnexpectedStatusCodeError, - WeaviateClosedClientError, - WeaviateConnectionError, - ): - return False - - def is_live(self) -> bool: - """ - Ping Weaviate's live state. - - Returns: - `bool` - `True` if weaviate is live and should not be killed, - `False` otherwise. - """ - - response = self._connection.get(path="/.well-known/live") - if response.status_code == 200: - return True - return False - - def get_meta(self) -> dict: - """ - Get the meta endpoint description of weaviate. - - Returns: - `dict` - The `dict` describing the weaviate configuration. - - Raises: - `weaviate.UnexpectedStatusCodeError` - If Weaviate reports a none OK status. - """ - - return self._connection.get_meta() - - def get_open_id_configuration(self) -> Optional[Dict[str, Any]]: - """ - Get the openid-configuration. - - Returns - `dict` - The configuration or `None` if not configured. - - Raises - `weaviate.UnexpectedStatusCodeError` - If Weaviate reports a none OK status. - """ - - response = self._connection.get(path="/.well-known/openid-configuration") - if response.status_code == 200: - return _type_request_response(response.json()) - if response.status_code == 404: - return None - raise UnexpectedStatusCodeError("Meta endpoint", response) diff --git a/weaviate/client.pyi b/weaviate/client.pyi index 345aa71eb..0a3a75542 100644 --- a/weaviate/client.pyi +++ b/weaviate/client.pyi @@ -2,51 +2,17 @@ Client class definition. """ -import asyncio from typing import Optional, Tuple, Union, Dict, Any -from httpx import HTTPError as HttpxError -from requests.exceptions import ConnectionError as RequestsConnectionError - from weaviate.backup.backup import _BackupAsync from weaviate.backup.sync import _Backup -from weaviate.collections.classes.internal import _GQLEntryReturnType, _RawGQLReturn - -from weaviate.integrations import _Integrations - -from weaviate import syncify -from .auth import AuthCredentials -from .backup import Backup -from .batch import Batch -from .classification import Classification -from .cluster import Cluster +from weaviate.collections.classes.internal import _RawGQLReturn from weaviate.collections.collections.async_ import _CollectionsAsync from weaviate.collections.collections.sync import _Collections from .collections.batch.client import _BatchClientWrapper from .collections.cluster import _Cluster, _ClusterAsync -from .config import AdditionalConfig, Config -from .connect import Connection, ConnectionV4 -from .connect.base import ( - ConnectionParams, - ProtocolParams, - TIMEOUT_TYPE_RETURN, -) -from .connect.v4 import _ExpectedStatusCodes -from .contextionary import Contextionary -from .data import DataObject -from .embedded import EmbeddedV3, EmbeddedV4, EmbeddedOptions -from .exceptions import ( - UnexpectedStatusCodeError, - WeaviateClosedClientError, - WeaviateConnectionError, -) -from .gql import Query -from .schema import Schema -from weaviate.event_loop import _EventLoopSingleton +from .connect import ConnectionV4 from .types import NUMBER -from .util import _decode_json_response_dict, _get_valid_timeout_config, _type_request_response -from .validator import _validate_input, _ValidateArgument -from .warnings import _Warnings TIMEOUT_TYPE = Union[Tuple[NUMBER, NUMBER], NUMBER] @@ -90,30 +56,4 @@ class WeaviateClient(_WeaviateClientInit): def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ... class Client: - _connection: Connection - classification: Classification - schema: Schema - contextionary: Contextionary - batch: Batch - data_object: DataObject - query: Query - backup: Backup - cluster: Cluster - @property - def timeout_config(self) -> TIMEOUT_TYPE_RETURN: ... - def is_ready(self) -> bool: ... - def is_live(self) -> bool: ... - def get_meta(self) -> dict: ... - def get_open_id_configuration(self) -> Optional[Dict[str, Any]]: ... - def __init__( - self, - url: Optional[str] = None, - auth_client_secret: Optional[AuthCredentials] = None, - timeout_config: TIMEOUT_TYPE = (10, 60), - proxies: Union[dict, str, None] = None, - trust_env: bool = False, - additional_headers: Optional[dict] = None, - startup_period: Optional[int] = None, - embedded_options: Optional[EmbeddedOptions] = None, - additional_config: Optional[Config] = None, - ) -> None: ... + def __init__(self) -> None: ... diff --git a/weaviate/cluster/__init__.py b/weaviate/cluster/__init__.py index 7213e40c8..6301acf5c 100644 --- a/weaviate/cluster/__init__.py +++ b/weaviate/cluster/__init__.py @@ -1,7 +1,3 @@ """ Module for interacting with Weaviate cluster information """ - -__all__ = ["Cluster"] - -from .cluster import Cluster diff --git a/weaviate/cluster/cluster.py b/weaviate/cluster/cluster.py deleted file mode 100644 index 28bbf1139..000000000 --- a/weaviate/cluster/cluster.py +++ /dev/null @@ -1,81 +0,0 @@ -""" -Cluster class definition. -""" - -from typing import List, Literal, Optional, cast - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.cluster.types import Node -from weaviate.connect import Connection -from weaviate.exceptions import ( - EmptyResponseException, -) -from ..util import _capitalize_first_letter, _decode_json_response_dict - - -class Cluster: - """ - Cluster class used for cluster information - """ - - def __init__(self, connection: Connection): - """ - Initialize a Cluster class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - """ - - self._connection = connection - - def get_nodes_status( - self, - class_name: Optional[str] = None, - output: Optional[Literal["minimal", "verbose"]] = None, - ) -> List[Node]: - """ - Get the nodes status. - - Parameters - ---------- - class_name : Optional[str] - Get the status for the given class. If not given all classes will be included. - output : Optional[str] - Set the desired output verbosity level. Can be [minimal | verbose], defaults to minimal. - - Returns - ------- - list - List of nodes and their respective status. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - weaviate.EmptyResponseException - If the response is empty. - """ - path = "/nodes" - if class_name is not None: - path += "/" + _capitalize_first_letter(class_name) - if output is not None: - path += f"?output={output}" - - try: - response = self._connection.get(path=path) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Get nodes status failed due to connection error" - ) from conn_err - - response_typed = _decode_json_response_dict(response, "Nodes status") - assert response_typed is not None - nodes = response_typed.get("nodes") - if nodes is None or nodes == []: - raise EmptyResponseException("Nodes status response returned empty") - return cast(List[Node], nodes) diff --git a/weaviate/collections/aggregations/aggregate.py b/weaviate/collections/aggregations/aggregate.py index 017a5c8c3..8076b96d7 100644 --- a/weaviate/collections/aggregations/aggregate.py +++ b/weaviate/collections/aggregations/aggregate.py @@ -61,7 +61,6 @@ def __init__( def _query(self) -> AggregateBuilder: return AggregateBuilder( self.__name, - self._connection, # type: ignore # not being used since we query manually in _do ) def _to_aggregate_result( diff --git a/weaviate/connect/__init__.py b/weaviate/connect/__init__.py index 5a7fa8f32..25997531d 100644 --- a/weaviate/connect/__init__.py +++ b/weaviate/connect/__init__.py @@ -4,11 +4,9 @@ """ from .base import ConnectionParams, ProtocolParams -from .v3 import Connection from .v4 import ConnectionV4 __all__ = [ - "Connection", "ConnectionParams", "ConnectionV4", "ProtocolParams", diff --git a/weaviate/connect/v3.py b/weaviate/connect/v3.py deleted file mode 100644 index 04bbfd08e..000000000 --- a/weaviate/connect/v3.py +++ /dev/null @@ -1,676 +0,0 @@ -""" -Connection class definition. -""" - -from __future__ import annotations - -import socket -import time -from threading import Thread, Event -from typing import Any, Dict, Optional, Tuple, Union, cast -from urllib.parse import urlparse - -import requests -from authlib.integrations.requests_client import OAuth2Session # type: ignore -from requests.adapters import HTTPAdapter -from requests.exceptions import ConnectionError as RequestsConnectionError, ReadTimeout -from requests.exceptions import HTTPError as RequestsHTTPError -from requests.exceptions import JSONDecodeError - -from weaviate import __version__ as client_version -from weaviate.auth import AuthCredentials, AuthClientCredentials, AuthApiKey -from weaviate.config import ConnectionConfig -from weaviate.connect.authentication import _Auth -from weaviate.embedded import EmbeddedDB -from weaviate.exceptions import ( - AuthenticationFailedException, - WeaviateStartUpError, -) -from weaviate.types import NUMBER -from weaviate.util import ( - _check_positive_num, - is_weaviate_domain, - is_weaviate_too_old, - is_weaviate_client_too_old, - PYPI_PACKAGE_URL, - _decode_json_response_dict, -) -from weaviate.warnings import _Warnings - -from .base import _ConnectionBase, _get_proxies - -import grpc # type: ignore -from weaviate.proto.v1 import weaviate_pb2_grpc - - -JSONPayload = Union[dict, list] -Session = Union[requests.sessions.Session, OAuth2Session] -TIMEOUT_TYPE_RETURN = Tuple[NUMBER, NUMBER] -INIT_CHECK_TIMEOUT = 0.5 - - -class Connection(_ConnectionBase): - """ - Connection class used to communicate to a weaviate instance. - """ - - def __init__( - self, - url: str, - auth_client_secret: Optional[AuthCredentials], - timeout_config: TIMEOUT_TYPE_RETURN, - proxies: Union[dict, str, None], - trust_env: bool, - additional_headers: Optional[Dict[str, Any]], - startup_period: Optional[int], - connection_config: ConnectionConfig, - embedded_db: Optional[EmbeddedDB] = None, - grcp_port: Optional[int] = None, - ): - """ - Initialize a Connection class instance. - - Parameters - ---------- - url : str - URL to a running weaviate instance. - auth_client_secret : weaviate.auth.AuthCredentials, optional - Credentials to authenticate with a weaviate instance. The credentials are not saved within the client and - authentication is done via authentication tokens. - timeout_config : tuple(float, float) or float, optional - Set the timeout configuration for all requests to the Weaviate server. It can be a - float or, a tuple of two floats: (connect timeout, read timeout). - If only one float is passed then both connect and read timeout will be set to - that value. - proxies : dict, str or None, optional - Proxies to be used for requests. Are used by both 'requests' and 'aiohttp'. Can be - passed as a dict ('requests' format: - https://docs.python-requests.org/en/stable/user/advanced/#proxies), str (HTTP/HTTPS - protocols are going to use this proxy) or None. - trust_env : bool, optional - Whether to read proxies from the ENV variables: (HTTP_PROXY or http_proxy, HTTPS_PROXY - or https_proxy). - NOTE: 'proxies' has priority over 'trust_env', i.e. if 'proxies' is NOT None, - 'trust_env' is ignored. - additional_headers : Dict[str, Any] or None - Additional headers to include in the requests, used to set OpenAI key. OpenAI key looks - like this: {'X-OpenAI-Api-Key': 'KEY'}. - startup_period : int or None - How long the client will wait for weaviate to start before raising a RequestsConnectionError. - If None the client will not wait at all. - - Raises - ------ - ValueError - If no authentication credentials provided but the Weaviate server has an OpenID - configured. - """ - - self._api_version_path = "/v1" - self.url = url # e.g. http://localhost:80 - self.timeout_config: TIMEOUT_TYPE_RETURN = timeout_config - self.embedded_db = embedded_db - - self._grpc_stub: Optional[weaviate_pb2_grpc.WeaviateStub] = None - - # create GRPC channel. If weaviate does not support GRPC, fallback to GraphQL is used. - if grcp_port is not None: - parsed_url = urlparse(self.url) - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - s.settimeout(1.0) # we're only pinging the port, 1s is plenty - s.connect((parsed_url.hostname, grcp_port)) - s.shutdown(2) - s.close() - channel = grpc.insecure_channel(f"{parsed_url.hostname}:{grcp_port}") - self._grpc_stub = weaviate_pb2_grpc.WeaviateStub(channel) - except ( - ConnectionRefusedError, - TimeoutError, - socket.timeout, - ): # self._grpc_stub stays None - s.close() - - self._headers = {"content-type": "application/json"} - if additional_headers is not None: - if not isinstance(additional_headers, dict): - raise TypeError( - f"'additional_headers' must be of type dict or None. Given type: {type(additional_headers)}." - ) - self.__additional_headers = additional_headers - for key, value in additional_headers.items(): - self._headers[key.lower()] = value - - self._proxies = _get_proxies(proxies, trust_env) - - # auth secrets can contain more information than a header (refresh tokens and lifetime) and therefore take - # precedent over headers - if "authorization" in self._headers and auth_client_secret is not None: - _Warnings.auth_header_and_auth_secret() - self._headers.pop("authorization") - - # if there are API keys included add them right away to headers - if auth_client_secret is not None and isinstance(auth_client_secret, AuthApiKey): - self._headers["authorization"] = "Bearer " + auth_client_secret.api_key - - self._session: Session - self._shutdown_background_event: Optional[Event] = None - - if startup_period is not None: - _check_positive_num(startup_period, "startup_period", int, include_zero=False) - self.wait_for_weaviate(startup_period) - - self._create_sessions(auth_client_secret) - self._add_adapter_to_session(connection_config) - - self._server_version = self.get_meta()["version"] - if self._server_version < "1.14": - _Warnings.weaviate_server_older_than_1_14(self._server_version) - if is_weaviate_too_old(self._server_version): - _Warnings.weaviate_too_old_vs_latest(self._server_version) - - try: - pkg_info = requests.get(PYPI_PACKAGE_URL, timeout=INIT_CHECK_TIMEOUT).json() - pkg_info = pkg_info.get("info", {}) - latest_version = pkg_info.get("version", "unknown version") - if is_weaviate_client_too_old(client_version, latest_version): - _Warnings.weaviate_client_too_old_vs_latest(client_version, latest_version) - except requests.exceptions.RequestException: - pass # ignore any errors related to requests, it is a best-effort warning - - if embedded_db is not None: - self.wait_for_weaviate(10) - - def _create_sessions(self, auth_client_secret: Optional[AuthCredentials]) -> None: - """Creates a async httpx session and a sync request session. - - Either through authlib.oauth2 if authentication is enabled or a normal request session otherwise. - - Raises - ------ - ValueError - If no authentication credentials provided but the Weaviate server has OpenID configured. - """ - # API keys are separate from OIDC and do not need any config from weaviate - if auth_client_secret is not None and isinstance(auth_client_secret, AuthApiKey): - self._session = requests.Session() - return - - if "authorization" in self._headers and auth_client_secret is None: - self._session = requests.Session() - return - - oidc_url = self.url + self._api_version_path + "/.well-known/openid-configuration" - response = requests.get( - oidc_url, - headers=self._headers, - timeout=self._timeout_config, - proxies=self._proxies, - ) - if response.status_code == 200: - # Some setups are behind proxies that return some default page - for example a login - for all requests. - # If the response is not json, we assume that this is the case and try unauthenticated access. Any auth - # header provided by the user is unaffected. - try: - resp = response.json() - except JSONDecodeError: - _Warnings.auth_cannot_parse_oidc_config(oidc_url) - self._session = requests.Session() - return - - if auth_client_secret is not None and not isinstance(auth_client_secret, AuthApiKey): - _auth = _Auth( - session_type=OAuth2Session, - oidc_config=resp, - credentials=auth_client_secret, - connection=self, - ) - self._session = _auth.get_auth_session() - - if isinstance(auth_client_secret, AuthClientCredentials): - # credentials should only be saved for client credentials, otherwise use refresh token - self._create_background_token_refresh(_auth) - else: - self._create_background_token_refresh() - else: - msg = f""""No login credentials provided. The weaviate instance at {self.url} requires login credentials. - - For more information, see: https://weaviate.io/developers/weaviate/client-libraries/python#authentication""" - - if is_weaviate_domain(self.url): - msg += """ - - You can instantiate the client with login credentials for Weaviate Cloud using - - client = weaviate.Client( - url=YOUR_WEAVIATE_URL, - auth_client_secret=weaviate.AuthApiKey( - api_key = YOUR_WCD_API_KEY, - )) - """ - raise AuthenticationFailedException(msg) - elif response.status_code == 404 and auth_client_secret is not None: - _Warnings.auth_with_anon_weaviate() - self._session = requests.Session() - else: - self._session = requests.Session() - - def get_current_bearer_token(self) -> str: - if "authorization" in self._headers: - return self._headers["authorization"] - elif isinstance(self._session, OAuth2Session): - return f"Bearer {self._session.token['access_token']}" - - return "" - - def get_proxies(self) -> dict: - return self._proxies - - def _add_adapter_to_session(self, connection_config: ConnectionConfig) -> None: - adapter = HTTPAdapter( - pool_connections=connection_config.session_pool_connections, - pool_maxsize=connection_config.session_pool_maxsize, - ) - self._session.mount("http://", adapter) - self._session.mount("https://", adapter) - - def _create_background_token_refresh(self, _auth: Optional[_Auth] = None) -> None: - """Create a background thread that periodically refreshes access and refresh tokens. - - While the underlying library refreshes tokens, it does not have an internal cronjob that checks every - X-seconds if a token has expired. If there is no activity for longer than the refresh tokens lifetime, it will - expire. Therefore, refresh manually shortly before expiration time is up.""" - assert isinstance(self._session, OAuth2Session) - if "refresh_token" not in self._session.token and _auth is None: - return - - expires_in: int = self._session.token.get( - "expires_in", 60 - ) # use 1minute as token lifetime if not supplied - self._shutdown_background_event = Event() - - def periodic_refresh_token( - refresh_time: int, _auth: Optional[_Auth[OAuth2Session]] - ) -> None: - time.sleep(max(refresh_time - 30, 1)) - while ( - self._shutdown_background_event is not None - and not self._shutdown_background_event.is_set() - ): - # use refresh token when available - try: - if "refresh_token" in cast(OAuth2Session, self._session).token: - assert isinstance(self._session, OAuth2Session) - self._session.token = self._session.refresh_token( - self._session.metadata["token_endpoint"] - ) - refresh_time = ( - int(self._session.token.get("expires_in")) - 30 # pyright: ignore - ) - else: - # client credentials usually does not contain a refresh token => get a new token using the - # saved credentials - assert _auth is not None - new_session = _auth.get_auth_session() - self._session.token = new_session.fetch_token() # type: ignore - except (RequestsHTTPError, ReadTimeout) as exc: - # retry again after one second, might be an unstable connection - refresh_time = 1 - _Warnings.token_refresh_failed(exc) - - time.sleep(max(refresh_time, 1)) - - demon = Thread( - target=periodic_refresh_token, - args=(expires_in, _auth), - daemon=True, - name="TokenRefresh", - ) - demon.start() - - def close(self) -> None: - """Shutdown connection class gracefully.""" - # in case an exception happens before definition of these members - if ( - hasattr(self, "_shutdown_background_event") - and self._shutdown_background_event is not None - ): - self._shutdown_background_event.set() - if hasattr(self, "_session"): - self._session.close() - - def delete( - self, - path: str, - weaviate_object: Optional[JSONPayload] = None, - params: Optional[Dict[str, Any]] = None, - ) -> requests.Response: - """ - Make a DELETE request to the Weaviate server instance. - - Parameters - ---------- - path : str - Sub-path to the Weaviate resources. Must be a valid Weaviate sub-path. - e.g. '/meta' or '/objects', without version. - weaviate_object : dict, optional - Object is used as payload for DELETE request. By default None. - params : dict, optional - Additional request parameters, by default None - - Returns - ------- - requests.Response - The response, if request was successful. - - Raises - ------ - requests.ConnectionError - If the DELETE request could not be made. - """ - if self.embedded_db is not None: - self.embedded_db.ensure_running() - request_url = self.url + self._api_version_path + path - - return self._session.delete( - url=request_url, - json=weaviate_object, - headers=self._headers, - timeout=self._timeout_config, - proxies=self._proxies, - params=params, - ) - - def patch( - self, - path: str, - weaviate_object: JSONPayload, - params: Optional[Dict[str, Any]] = None, - ) -> requests.Response: - """ - Make a PATCH request to the Weaviate server instance. - - Parameters - ---------- - path : str - Sub-path to the Weaviate resources. Must be a valid Weaviate sub-path. - e.g. '/meta' or '/objects', without version. - weaviate_object : dict - Object is used as payload for PATCH request. - params : dict, optional - Additional request parameters, by default None - Returns - ------- - requests.Response - The response, if request was successful. - - Raises - ------ - requests.ConnectionError - If the PATCH request could not be made. - """ - if self.embedded_db is not None: - self.embedded_db.ensure_running() - request_url = self.url + self._api_version_path + path - - return self._session.patch( - url=request_url, - json=weaviate_object, - headers=self._headers, - timeout=self._timeout_config, - proxies=self._proxies, - params=params, - ) - - def post( - self, - path: str, - weaviate_object: JSONPayload, - params: Optional[Dict[str, Any]] = None, - ) -> requests.Response: - """ - Make a POST request to the Weaviate server instance. - - Parameters - ---------- - path : str - Sub-path to the Weaviate resources. Must be a valid Weaviate sub-path. - e.g. '/meta' or '/objects', without version. - weaviate_object : dict - Object is used as payload for POST request. - params : dict, optional - Additional request parameters, by default None - external_url: Is an external (non-weaviate) url called - - Returns - ------- - requests.Response - The response, if request was successful. - - Raises - ------ - requests.ConnectionError - If the POST request could not be made. - """ - if self.embedded_db is not None: - self.embedded_db.ensure_running() - request_url = self.url + self._api_version_path + path - - return self._session.post( - url=request_url, - json=weaviate_object, - headers=self._headers, - timeout=self._timeout_config, - proxies=self._proxies, - params=params, - ) - - def put( - self, - path: str, - weaviate_object: JSONPayload, - params: Optional[Dict[str, Any]] = None, - ) -> requests.Response: - """ - Make a PUT request to the Weaviate server instance. - - Parameters - ---------- - path : str - Sub-path to the Weaviate resources. Must be a valid Weaviate sub-path. - e.g. '/meta' or '/objects', without version. - weaviate_object : dict - Object is used as payload for PUT request. - params : dict, optional - Additional request parameters, by default None - Returns - ------- - requests.Response - The response, if request was successful. - - Raises - ------ - requests.ConnectionError - If the PUT request could not be made. - """ - if self.embedded_db is not None: - self.embedded_db.ensure_running() - request_url = self.url + self._api_version_path + path - - return self._session.put( - url=request_url, - json=weaviate_object, - headers=self._headers, - timeout=self._timeout_config, - proxies=self._proxies, - params=params, - ) - - def get( - self, path: str, params: Optional[Dict[str, Any]] = None, external_url: bool = False - ) -> requests.Response: - """Make a GET request. - - Parameters - ---------- - path : str - Sub-path to the Weaviate resources. Must be a valid Weaviate sub-path. - e.g. '/meta' or '/objects', without version. - params : dict, optional - Additional request parameters, by default None - external_url: Is an external (non-weaviate) url called - - Returns - ------- - requests.Response - The response if request was successful. - - Raises - ------ - requests.ConnectionError - If the GET request could not be made. - """ - if self.embedded_db is not None: - self.embedded_db.ensure_running() - if params is None: - params = {} - - if external_url: - request_url = path - else: - request_url = self.url + self._api_version_path + path - - return self._session.get( - url=request_url, - headers=self._headers, - timeout=self._timeout_config, - params=params, - proxies=self._proxies, - ) - - def head( - self, - path: str, - params: Optional[Dict[str, Any]] = None, - ) -> requests.Response: - """ - Make a HEAD request to the server. - - Parameters - ---------- - path : str - Sub-path to the resources. Must be a valid sub-path. - e.g. '/meta' or '/objects', without version. - params : dict, optional - Additional request parameters, by default None - - Returns - ------- - requests.Response - The response to the request. - - Raises - ------ - requests.ConnectionError - If the HEAD request could not be made. - """ - if self.embedded_db is not None: - self.embedded_db.ensure_running() - request_url = self.url + self._api_version_path + path - - return self._session.head( - url=request_url, - headers=self._headers, - timeout=self._timeout_config, - proxies=self._proxies, - params=params, - ) - - @property - def timeout_config(self) -> TIMEOUT_TYPE_RETURN: # pyright: ignore - """ - Getter/setter for `timeout_config`. - - Parameters - ---------- - timeout_config : tuple(float, float), optional - For Setter only: Set the timeout configuration for all requests to the Weaviate server. - It can be a float or, a tuple of two floats: - (connect timeout, read timeout). - If only one float is passed then both connect and read timeout will be set to - that value. - - Returns - ------- - Tuple[float, float] - For Getter only: Requests Timeout configuration. - """ - - return self._timeout_config - - @timeout_config.setter - def timeout_config(self, timeout_config: TIMEOUT_TYPE_RETURN) -> None: # pyright: ignore - """ - Setter for `timeout_config`. (docstring should be only in the Getter) - """ - - self._timeout_config = timeout_config - - @property - def proxies(self) -> dict: - return self._proxies - - def wait_for_weaviate(self, startup_period: int) -> None: - """ - Waits until weaviate is ready or the time limit given in 'startup_period' has passed. - - Parameters - ---------- - startup_period : int - Describes how long the client will wait for weaviate to start in seconds. - - Raises - ------ - WeaviateStartUpError - If weaviate takes longer than the time limit to respond. - """ - - ready_url = self.url + self._api_version_path + "/.well-known/ready" - for _i in range(startup_period): - try: - requests.get( - ready_url, headers=self._headers, timeout=INIT_CHECK_TIMEOUT - ).raise_for_status() - return - except (RequestsHTTPError, RequestsConnectionError, ReadTimeout): - time.sleep(1) - - try: - requests.get( - ready_url, headers=self._headers, timeout=INIT_CHECK_TIMEOUT - ).raise_for_status() - return - except (RequestsHTTPError, RequestsConnectionError, ReadTimeout) as error: - raise WeaviateStartUpError( - f"Weaviate did not start up in {startup_period} seconds. Either the Weaviate URL {self.url} is wrong or Weaviate did not start up in the interval given in 'startup_period'." - ) from error - - @property - def grpc_stub(self) -> Optional[weaviate_pb2_grpc.WeaviateStub]: - return self._grpc_stub - - @property - def server_version(self) -> str: - """ - Version of the weaviate instance. - """ - return self._server_version - - def get_meta(self) -> Dict[str, str]: - """ - Returns the meta endpoint. - """ - response = self.get(path="/meta") - res = _decode_json_response_dict(response, "Meta endpoint") - assert res is not None - return res diff --git a/weaviate/contextionary/__init__.py b/weaviate/contextionary/__init__.py deleted file mode 100644 index f6847ef1f..000000000 --- a/weaviate/contextionary/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Contextionary module used to interact with Weaviate's contextionary module. -""" - -__all__ = ["Contextionary"] - -from .crud_contextionary import Contextionary diff --git a/weaviate/contextionary/crud_contextionary.py b/weaviate/contextionary/crud_contextionary.py deleted file mode 100644 index b76fd62f4..000000000 --- a/weaviate/contextionary/crud_contextionary.py +++ /dev/null @@ -1,158 +0,0 @@ -""" -Contextionary class definition. -""" - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection -from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.util import _decode_json_response_dict - - -class Contextionary: - """ - Contextionary class used to add extend the Weaviate contextionary module - or to get vector/s of a specific concept. - """ - - def __init__(self, connection: Connection): - """ - Initialize a Contextionary class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - """ - - self._connection = connection - - def extend(self, concept: str, definition: str, weight: float = 1.0) -> None: - """ - Extend the text2vec-contextionary with new concepts - - Parameters - ---------- - concept : str - The new concept that should be added that is not in the Weaviate - or needs to be updated, e.g. an abbreviation. - definition : str - The definition of the new concept. - weight : float, optional - The weight of the new definition compared to the old one, - must be in-between the interval [0.0; 1.0], by default 1.0 - - Examples - -------- - >>> client.contextionary.extend( - ... concept = 'palantir', - ... definition = 'spherical stone objects used for communication in Middle-earth' - ... ) - - - Raises - ------ - TypeError - If an argument is not of an appropriate type. - ValueError - If 'weight' is outside the interval [0.0; 1.0]. - requests.ConnectionError - If text2vec-contextionary could not be extended. - weaviate.UnexpectedStatusCodeException - If the network connection to weaviate fails. - """ - - if not isinstance(concept, str): - raise TypeError("Concept must be string") - if not isinstance(definition, str): - raise TypeError("Definition must be string") - if not isinstance(weight, float): - raise TypeError("Weight must be float") - - if weight > 1.0 or weight < 0.0: - raise ValueError("Weight out of limits 0.0 <= weight <= 1.0") - - extension = {"concept": concept, "definition": definition, "weight": weight} - - try: - response = self._connection.post( - path="/modules/text2vec-contextionary/extensions", - weaviate_object=extension, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "text2vec-contextionary could not be extended." - ) from conn_err - if response.status_code == 200: - # Successfully extended - return - raise UnexpectedStatusCodeException("Extend text2vec-contextionary", response) - - def get_concept_vector(self, concept: str) -> dict: - """ - Retrieves the vector representation of the given concept. - - Parameters - ---------- - concept : str - Concept for which the vector should be retrieved. - May be camelCase for word combinations. - - Examples - -------- - >>> client.contextionary.get_concept_vector('king') - { - "individualWords": [ - { - "info": { - "nearestNeighbors": [ - { - "word": "king" - }, - { - "distance": 5.7498446, - "word": "kings" - }, - ..., - { - "distance": 6.1396513, - "word": "queen" - } - ], - "vector": [ - -0.68988, - ..., - -0.561865 - ] - }, - "present": true, - "word": "king" - } - ] - } - - Returns - ------- - dict - A dictionary containing info and the vector/s of the concept. - The vector might be empty if the text2vec-contextionary does not contain it. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - - path = "/modules/text2vec-contextionary/concepts/" + concept - try: - response = self._connection.get(path=path) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "text2vec-contextionary vector was not retrieved." - ) from conn_err - else: - res = _decode_json_response_dict(response, "text2vec-contextionary vector") - assert res is not None - return res diff --git a/weaviate/data/__init__.py b/weaviate/data/__init__.py deleted file mode 100644 index 020d21ca9..000000000 --- a/weaviate/data/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -Data module used to create, read, update and delete object and references. -""" - -__all__ = ["DataObject", "ConsistencyLevel"] - -from .crud_data import DataObject -from .replication import ConsistencyLevel diff --git a/weaviate/data/crud_data.py b/weaviate/data/crud_data.py deleted file mode 100644 index 74f4fd84d..000000000 --- a/weaviate/data/crud_data.py +++ /dev/null @@ -1,1006 +0,0 @@ -""" -DataObject class definition. -""" - -import uuid as uuid_lib -import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection -from weaviate.data.references import Reference -from weaviate.data.replication import ConsistencyLevel -from weaviate.error_msgs import DATA_DEPRECATION_NEW_V14_CLS_NS_W, DATA_DEPRECATION_OLD_V14_CLS_NS_W -from weaviate.exceptions import ( - ObjectAlreadyExistsException, - UnexpectedStatusCodeException, -) -from weaviate.util import ( - _get_dict_from_object, - get_vector, - get_valid_uuid, - _capitalize_first_letter, - _check_positive_num, -) -from weaviate.types import UUID - - -class DataObject: - """ - DataObject class used to manipulate object to/from Weaviate. - - Attributes - ---------- - reference : weaviate.data.references.Reference - A Reference object to create objects cross-references. - """ - - def __init__(self, connection: Connection): - """ - Initialize a DataObject class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - """ - - self._connection = connection - self.reference = Reference(self._connection) - - def create( - self, - data_object: Union[dict, str], - class_name: str, - uuid: Union[str, uuid_lib.UUID, None] = None, - vector: Optional[Sequence] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> str: - """ - Takes a dict describing the object and adds it to Weaviate. - - Parameters - ---------- - data_object : dict or str - Object to be added. - If type is str it should be either a URL or a file. - class_name : str - Class name associated with the object given. - uuid : str, uuid.UUID or None, optional - Object will be created under this uuid if it is provided. - Otherwise, Weaviate will generate a uuid for this object, - by default None. - vector: Sequence or None, optional - Embedding for the object. - Can be used when: - - a class does not have a vectorization module. - - The given vector was generated using the _identical_ vectorization module that is configured for the - class. In this case this vector takes precedence. - - Supported types are `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, - by default None. - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - tenant: Optional[str], optional - The name of the tenant for which this operation is being performed. - - Examples - -------- - Schema contains a class Author with only 'name' and 'age' primitive property. - - >>> client.data_object.create( - ... data_object = {'name': 'Neil Gaiman', 'age': 60}, - ... class_name = 'Author', - ... ) - '46091506-e3a0-41a4-9597-10e3064d8e2d' - >>> client.data_object.create( - ... data_object = {'name': 'Andrzej Sapkowski', 'age': 72}, - ... class_name = 'Author', - ... uuid = 'e067f671-1202-42c6-848b-ff4d1eb804ab' - ... ) - 'e067f671-1202-42c6-848b-ff4d1eb804ab' - - Returns - ------- - str - Returns the UUID of the created object if successful. - - Raises - ------ - TypeError - If argument is of wrong type. - ValueError - If argument contains an invalid value. - weaviate.ObjectAlreadyExistsException - If an object with the given uuid already exists within Weaviate. - weaviate.UnexpectedStatusCodeException - If creating the object in Weaviate failed for a different reason, - more information is given in the exception. - requests.ConnectionError - If the network connection to Weaviate fails. - """ - - if not isinstance(class_name, str): - raise TypeError(f"Expected class_name of type str but was: {type(class_name)}") - loaded_data_object = _get_dict_from_object(data_object) - - weaviate_obj = { - "class": _capitalize_first_letter(class_name), - "properties": loaded_data_object, - } - if uuid is not None: - weaviate_obj["id"] = get_valid_uuid(uuid) - - if vector is not None: - weaviate_obj["vector"] = get_vector(vector) - - path = "/objects" - params = {} - if consistency_level is not None: - params["consistency_level"] = ConsistencyLevel(consistency_level).value - if tenant is not None: - weaviate_obj["tenant"] = tenant - try: - response = self._connection.post(path=path, weaviate_object=weaviate_obj, params=params) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Object was not added to Weaviate.") from conn_err - if response.status_code == 200: - return str(response.json()["id"]) - - object_does_already_exist = False - try: - if "already exists" in response.json()["error"][0]["message"]: - object_does_already_exist = True - except KeyError: - pass - if object_does_already_exist: - raise ObjectAlreadyExistsException(str(uuid)) - raise UnexpectedStatusCodeException("Creating object", response) - - def update( - self, - data_object: Union[dict, str], - class_name: str, - uuid: Union[str, uuid_lib.UUID], - vector: Optional[Sequence] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> None: - """ - Update an already existing object in Weaviate with the given data object. - Overwrites only the specified fields, the unspecified ones remain unchanged. - - Parameters - ---------- - data_object : dict or str - The object states the fields that should be updated. - Fields not specified in the 'data_object' remain unchanged. - Fields that are None will not be changed. - If type is str it should be either an URL or a file. - class_name : str - The class name of the object. - uuid : str or uuid.UUID - The ID of the object that should be changed. - vector: Sequence or None, optional - Embedding for the object. - Can be used when: - - a class does not have a vectorization module. - - The given vector was generated using the _identical_ vectorization module that is configured for the - class. In this case this vector takes precedence. - - Supported types are `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, - by default None. - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - tenant: Optional[str], optional - The name of the tenant for which this operation is being performed. - - Examples - -------- - >>> author_id = client.data_object.create( - ... data_object = {'name': 'Philip Pullman', 'age': 64}, - ... class_name = 'Author' - ... ) - >>> client.data_object.get(author_id) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617111215172, - "id": "bec2bca7-264f-452a-a5bb-427eb4add068", - "lastUpdateTimeUnix": 1617111215172, - "properties": { - "age": 64, - "name": "Philip Pullman" - }, - "vectorWeights": null - } - >>> client.data_object.update( - ... data_object = {'age': 74}, - ... class_name = 'Author', - ... uuid = author_id - ... ) - >>> client.data_object.get(author_id) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617111215172, - "id": "bec2bca7-264f-452a-a5bb-427eb4add068", - "lastUpdateTimeUnix": 1617111215172, - "properties": { - "age": 74, - "name": "Philip Pullman" - }, - "vectorWeights": null - } - - Raises - ------ - TypeError - If argument is of wrong type. - ValueError - If argument contains an invalid value. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a none successful status. - """ - params = {} - if consistency_level is not None: - params["consistency_level"] = ConsistencyLevel(consistency_level).value - weaviate_obj, path = self._create_object_for_update(data_object, class_name, uuid, vector) - if tenant is not None: - weaviate_obj["tenant"] = tenant - - try: - response = self._connection.patch( - path=path, - weaviate_object=weaviate_obj, - params=params, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Object was not updated.") from conn_err - if response.status_code == 204: - # Successful merge - return - raise UnexpectedStatusCodeException("Update of the object not successful", response) - - def replace( - self, - data_object: Union[dict, str], - class_name: str, - uuid: Union[str, uuid_lib.UUID], - vector: Optional[Sequence] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> None: - """ - Replace an already existing object with the given data object. - This method replaces the whole object. - - Parameters - ---------- - data_object : dict or str - Describes the new values. It may be an URL or path to a json - or a python dict describing the new values. - class_name : str - Name of the class of the object that should be updated. - uuid : str or uuid.UUID - The UUID of the object that should be changed. - vector: Sequence or None, optional - Embedding for the object. - Can be used when: - - a class does not have a vectorization module. - - The given vector was generated using the _identical_ vectorization module that is configured for the - class. In this case this vector takes precedence. - - Supported types are `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, - by default None. - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - tenant: Optional[str], optional - The name of the tenant for which this operation is being performed. - - Examples - -------- - >>> author_id = client.data_object.create( - ... data_object = {'name': 'H. Lovecraft', 'age': 46}, - ... class_name = 'Author' - ... ) - >>> client.data_object.get(author_id) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617112817487, - "id": "d842a0f4-ad8c-40eb-80b4-bfefc7b1b530", - "lastUpdateTimeUnix": 1617112817487, - "properties": { - "age": 46, - "name": "H. Lovecraft" - }, - "vectorWeights": null - } - >>> client.data_object.replace( - ... data_object = {'name': 'H.P. Lovecraft'}, - ... class_name = 'Author', - ... uuid = author_id - ... ) - >>> client.data_object.get(author_id) - { - "additional": {}, - "class": "Author", - "id": "d842a0f4-ad8c-40eb-80b4-bfefc7b1b530", - "lastUpdateTimeUnix": 1617112838668, - "properties": { - "name": "H.P. Lovecraft" - }, - "vectorWeights": null - } - - Raises - ------ - TypeError - If argument is of wrong type. - ValueError - If argument contains an invalid value. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a none OK status. - """ - params = {} - if consistency_level is not None: - params["consistency_level"] = ConsistencyLevel(consistency_level).value - weaviate_obj, path = self._create_object_for_update(data_object, class_name, uuid, vector) - if tenant is not None: - weaviate_obj["tenant"] = tenant - try: - response = self._connection.put(path=path, weaviate_object=weaviate_obj, params=params) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Object was not replaced.") from conn_err - if response.status_code == 200: - # Successful update - return - raise UnexpectedStatusCodeException("Replace object", response) - - def _create_object_for_update( - self, - data_object: Union[dict, str], - class_name: str, - uuid: Union[str, uuid_lib.UUID], - vector: Optional[Sequence] = None, - ) -> Tuple[Dict[str, Any], str]: - if not isinstance(class_name, str): - raise TypeError("Class must be type str") - - uuid = get_valid_uuid(uuid) - - object_dict = _get_dict_from_object(data_object) - - weaviate_obj = { - "id": uuid, - "properties": object_dict, - "class": _capitalize_first_letter(class_name), - } - - if vector is not None: - weaviate_obj["vector"] = get_vector(vector) - - is_server_version_14 = self._connection.server_version >= "1.14" - - if is_server_version_14: - path = f"/objects/{_capitalize_first_letter(class_name)}/{uuid}" - else: - path = f"/objects/{uuid}" - return weaviate_obj, path - - def get_by_id( - self, - uuid: Union[str, uuid_lib.UUID], - additional_properties: Optional[List[str]] = None, - with_vector: bool = False, - class_name: Optional[str] = None, - node_name: Optional[str] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> Optional[dict]: - """ - Get an object as dict. - - Parameters - ---------- - uuid : str or uuid.UUID - The identifier of the object that should be retrieved. - additional_properties : list of str, optional - List of additional properties that should be included in the request, - by default None - with_vector: bool - If True the `vector` property will be returned too, - by default False. - class_name : Optional[str], optional - The class name of the object with UUID `uuid`. Introduced in Weaviate version v1.14.0. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - tenant: str, optional - The name of the tenant for which this operation is being performed. - - Examples - -------- - >>> client.data_object.get_by_id( - ... uuid="d842a0f4-ad8c-40eb-80b4-bfefc7b1b530", - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617112817487, - "id": "d842a0f4-ad8c-40eb-80b4-bfefc7b1b530", - "lastUpdateTimeUnix": 1617112817487, - "properties": { - "age": 46, - "name": "H.P. Lovecraft" - }, - "vectorWeights": null - } - - Returns - ------- - dict or None - dict in case the object exists. - None in case the object does not exist. - - Raises - ------ - TypeError - If argument is of wrong type. - ValueError - If argument contains an invalid value. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a none OK status. - """ - - return self.get( - uuid=uuid, - additional_properties=additional_properties, - with_vector=with_vector, - class_name=class_name, - node_name=node_name, - consistency_level=consistency_level, - tenant=tenant, - ) - - def get( - self, - uuid: Union[str, uuid_lib.UUID, None] = None, - additional_properties: Optional[List[str]] = None, - with_vector: bool = False, - class_name: Optional[str] = None, - node_name: Optional[str] = None, - consistency_level: Optional[ConsistencyLevel] = None, - limit: Optional[int] = None, - after: Optional[UUID] = None, - offset: Optional[int] = None, - sort: Optional[Dict[str, Union[str, bool, List[bool], List[str]]]] = None, - tenant: Optional[str] = None, - ) -> Optional[Dict[str, Any]]: - """ - Gets objects from Weaviate, the maximum number of objects returned is 100. - If 'uuid' is None, all objects are returned. If 'uuid' is specified the result is the same - as for `get_by_uuid` method. - - Parameters - ---------- - uuid : str, uuid.UUID or None, optional - The identifier of the object that should be retrieved. - additional_properties : list of str, optional - list of additional properties that should be included in the request, - by default None - with_vector : bool - If True the `vector` property will be returned too, - by default False - class_name: Optional[str], optional - The class name of the object with UUID `uuid`. Introduced in Weaviate version v1.14.0. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - a request before it is considered successful. Mutually exclusive with node_name param. - node_name : Optional[str], optional - The name of the target node which should fulfill the request. Mutually exclusive with - consistency_level param. - limit: Optional[int], optional - The maximum number of data objects to return. - by default None, which uses the Weaviate default of 100 entries - after: Optional[UUID], optional - Can be used to extract all elements by giving the last ID from the previous "page". Requires limit to be set - but cannot be combined with any other filters or search. Part of the Cursor API. - offset: Optional[int], optional - The offset of objects returned, i.e. the starting index of the returned objects. Should be - used in conjunction with the 'limit' parameter. - sort: Optional[Dict] - A dictionary for sorting objects. - sort['properties']: str, List[str] - By which properties the returned objects should be sorted. When more than one property is given, the objects are sorted in order of the list. - The order of the sorting can be given by using 'sort['order_asc']'. - sort['order_asc']: bool, List[bool] - The order the properties given in 'sort['properties']' should be returned in. When a single boolean is used, all properties are sorted in the same order. - If a list is used, it needs to have the same length as 'sort'. Each properties order is then decided individually. - If 'sort['order_asc']' is True, the properties are sorted in ascending order. If it is False, they are sorted in descending order. - if 'sort['order_asc']' is not given, all properties are sorted in ascending order. - tenant: Optional[str], optional - The name of the tenant for which this operation is being performed. - - Returns - ------- - list of dicts - A list of all objects. If no objects where found the list is empty. - - Raises - ------ - TypeError - If argument is of wrong type. - ValueError - If argument contains an invalid value. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a none OK status. - """ - is_server_version_14 = self._connection.server_version >= "1.14" - - if class_name is None and is_server_version_14 and uuid is not None: - warnings.warn( - message=DATA_DEPRECATION_NEW_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if class_name is not None and uuid is not None: - if not is_server_version_14: - warnings.warn( - message=DATA_DEPRECATION_OLD_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if not isinstance(class_name, str): - raise TypeError(f"'class_name' must be of type str. Given type: {type(class_name)}") - - params = _get_params(additional_properties, with_vector) - - if class_name and is_server_version_14: - if uuid is not None: - path = f"/objects/{_capitalize_first_letter(class_name)}" - else: - path = "/objects" - params["class"] = _capitalize_first_letter(class_name) - else: - path = "/objects" - - if uuid is not None: - path += "/" + get_valid_uuid(uuid) - - if consistency_level is not None: - params["consistency_level"] = ConsistencyLevel(consistency_level).value - - if tenant is not None: - params["tenant"] = tenant - - if node_name is not None: - params["node_name"] = node_name - - if limit is not None: - _check_positive_num(limit, "limit", int, include_zero=False) - params["limit"] = limit - - if after is not None: - params["after"] = get_valid_uuid(after) - - if offset is not None: - _check_positive_num(offset, "offset", int, include_zero=True) - params["offset"] = offset - - if sort is not None: - if "properties" not in sort: - raise ValueError("The sort clause is missing the required field: 'properties'.") - if "order_asc" not in sort: - sort["order_asc"] = True - if not isinstance(sort, Dict): - raise TypeError(f"'sort' must be of type dict. Given type: {type(sort)}.") - if isinstance(sort["properties"], str): - sort["properties"] = [sort["properties"]] - elif not isinstance(sort["properties"], list) or not all( - isinstance(x, str) for x in sort["properties"] - ): - raise TypeError( - f"'sort['properties']' must be of type str or list[str]. Given type: {type(sort['properties'])}." - ) - if len(sort["properties"]) == 0: - raise ValueError("'sort['properties']' cannot be an empty list.") - - if isinstance(sort["order_asc"], bool): - sort["order_asc"] = [sort["order_asc"]] * len(sort["properties"]) - elif not isinstance(sort["order_asc"], list) or not all( - isinstance(x, bool) for x in sort["order_asc"] - ): - raise TypeError( - f"'sort['order_asc']' must be of type boolean or list[bool]. Given type: {type(sort['order_asc'])}." - ) - if len(sort["properties"]) != len(sort["order_asc"]): # type: ignore - raise ValueError( - f"'sort['order_asc']' must be the same length as 'sort['properties']' or a boolean (not in a list). Current length is sort['properties']:{len(sort['properties'])} and sort['order_asc']:{len(sort['order_asc'])}." # type: ignore - ) - if len(sort["order_asc"]) == 0: # type: ignore - raise ValueError("'sort['order_asc']' cannot be an empty list.") - - params["sort"] = ",".join(sort["properties"]) # type: ignore - order = ["asc" if x else "desc" for x in sort["order_asc"]] # type: ignore - params["order"] = ",".join(order) - - try: - response = self._connection.get( - path=path, - params=params, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Could not get object/s.") from conn_err - if response.status_code == 200: - return cast(Dict[str, Any], response.json()) - if response.status_code == 404: - return None - raise UnexpectedStatusCodeException("Get object/s", response) - - def delete( - self, - uuid: Union[str, uuid_lib.UUID], - class_name: Optional[str] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> None: - """ - Delete an existing object from Weaviate. - - Parameters - ---------- - uuid : str or uuid.UUID - The ID of the object that should be deleted. - class_name : Optional[str], optional - The class name of the object with UUID `uuid`. Introduced in Weaviate version v1.14.0. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - tenant: str, optional - The name of the tenant for which this operation is being performed. - - Examples - -------- - >>> client.data_object.get( - ... uuid="d842a0f4-ad8c-40eb-80b4-bfefc7b1b530", - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617112817487, - "id": "d842a0f4-ad8c-40eb-80b4-bfefc7b1b530", - "lastUpdateTimeUnix": 1617112817487, - "properties": { - "age": 46, - "name": "H.P. Lovecraft" - }, - "vectorWeights": null - } - >>> client.data_object.delete( - ... uuid="d842a0f4-ad8c-40eb-80b4-bfefc7b1b530", - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - >>> client.data_object.get( - ... uuid="d842a0f4-ad8c-40eb-80b4-bfefc7b1b530", - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - None - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a none OK status. - TypeError - If parameter has the wrong type. - ValueError - If uuid is not properly formed. - """ - - uuid = get_valid_uuid(uuid) - - is_server_version_14 = self._connection.server_version >= "1.14" - - if class_name is None and is_server_version_14: - warnings.warn( - message=DATA_DEPRECATION_NEW_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if class_name is not None: - if not is_server_version_14: - warnings.warn( - message=DATA_DEPRECATION_OLD_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if not isinstance(class_name, str): - raise TypeError(f"'class_name' must be of type str. Given type: {type(class_name)}") - - if class_name and is_server_version_14: - path = f"/objects/{_capitalize_first_letter(class_name)}/{uuid}" - else: - path = f"/objects/{uuid}" - - params = {} - if consistency_level is not None: - params = {"consistency_level": ConsistencyLevel(consistency_level).value} - if tenant is not None: - params["tenant"] = tenant - try: - response = self._connection.delete( - path=path, - params=params, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Object could not be deleted.") from conn_err - if response.status_code == 204: - # Successfully deleted - return - raise UnexpectedStatusCodeException("Delete object", response) - - def exists( - self, - uuid: Union[str, uuid_lib.UUID], - class_name: Optional[str] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> bool: - """ - Check if the object exist in Weaviate. - - Parameters - ---------- - uuid : str or uuid.UUID - The UUID of the object that may or may not exist within Weaviate. - class_name : Optional[str], optional - The class name of the object with UUID `uuid`. Introduced in Weaviate version 1.14.0. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < 1.14.0, - by default None - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - tenant: Optional[str], optional - The name of the tenant for which this operation is being performed. - - Examples - -------- - >>> client.data_object.exists( - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab', - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - False - >>> client.data_object.create( - ... data_object = {'name': 'Andrzej Sapkowski', 'age': 72}, - ... class_name = 'Author', - ... uuid = 'e067f671-1202-42c6-848b-ff4d1eb804ab' - ... ) - >>> client.data_object.exists( - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab', - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - True - - Returns - ------- - bool - True if object exists, False otherwise. - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a none OK status. - TypeError - If parameter has the wrong type. - ValueError - If uuid is not properly formed. - """ - - is_server_version_14 = self._connection.server_version >= "1.14" - - if class_name is None and is_server_version_14: - warnings.warn( - message=DATA_DEPRECATION_NEW_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if class_name is not None: - if not is_server_version_14: - warnings.warn( - message=DATA_DEPRECATION_OLD_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if not isinstance(class_name, str): - raise TypeError(f"'class_name' must be of type str. Given type: {type(class_name)}") - - if class_name and is_server_version_14: - path = f"/objects/{_capitalize_first_letter(class_name)}/{get_valid_uuid(uuid)}" - else: - path = f"/objects/{get_valid_uuid(uuid)}" - params = {} - if consistency_level is not None: - params = {"consistency_level": ConsistencyLevel(consistency_level).value} - if tenant is not None: - params["tenant"] = tenant - - try: - response = self._connection.head( - path=path, - params=params, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Could not check if object exist.") from conn_err - - if response.status_code == 204: - return True - if response.status_code == 404: - return False - raise UnexpectedStatusCodeException("Object exists", response) - - def validate( - self, - data_object: Union[dict, str], - class_name: str, - uuid: Union[str, uuid_lib.UUID, None] = None, - vector: Optional[Sequence] = None, - ) -> dict: - """ - Validate an object against Weaviate. - - Parameters - ---------- - data_object : dict or str - Object to be validated. - If type is str it should be either an URL or a file. - class_name : str - Name of the class of the object that should be validated. - uuid : str, uuid.UUID or None, optional - The UUID of the object that should be validated against Weaviate. - by default None. - vector: Sequence or None, optional - The embedding of the object that should be validated. - Can be used when: - - a class does not have a vectorization module. - - The given vector was generated using the _identical_ vectorization module that is configured for the - class. In this case this vector takes precedence. - - Supported types are `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, - by default None. - - Examples - -------- - Assume we have a Author class only 'name' property, NO 'age'. - - >>> client1.data_object.validate( - ... data_object = {'name': 'H. Lovecraft'}, - ... class_name = 'Author' - ... ) - {'error': None, 'valid': True} - >>> client1.data_object.validate( - ... data_object = {'name': 'H. Lovecraft', 'age': 46}, - ... class_name = 'Author' - ... ) - { - "error": [ - { - "message": "invalid object: no such prop with name 'age' found in class 'Author' - in the schema. Check your schema files for which properties in this class are - available" - } - ], - "valid": false - } - - Returns - ------- - dict - Validation result. E.g. {"valid": bool, "error": None or list} - - Raises - ------ - TypeError - If argument is of wrong type. - ValueError - If argument contains an invalid value. - weaviate.UnexpectedStatusCodeException - If validating the object against Weaviate failed with a different reason. - requests.ConnectionError - If the network connection to Weaviate fails. - """ - - loaded_data_object = _get_dict_from_object(data_object) - if not isinstance(class_name, str): - raise TypeError(f"Expected class_name of type `str` but was: {type(class_name)}") - - weaviate_obj = { - "class": _capitalize_first_letter(class_name), - "properties": loaded_data_object, - } - - if uuid is not None: - weaviate_obj["id"] = get_valid_uuid(uuid) - - if vector is not None: - weaviate_obj["vector"] = get_vector(vector) - - path = "/objects/validate" - try: - response = self._connection.post(path=path, weaviate_object=weaviate_obj) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Object was not validated against Weaviate." - ) from conn_err - - result: dict = {"error": None} - - if response.status_code == 200: - result["valid"] = True - return result - if response.status_code == 422: - result["valid"] = False - result["error"] = response.json()["error"] - return result - raise UnexpectedStatusCodeException("Validate object", response) - - -def _get_params(additional_properties: Optional[List[str]], with_vector: bool) -> dict: - """ - Get underscore properties in the format accepted by Weaviate. - - Parameters - ---------- - additional_properties : list of str or None - A list of additional properties or None. - with_vector: bool - If True the `vector` property will be returned too. - - Returns - ------- - dict - A dictionary including Weaviate-accepted additional properties - and/or `vector` property. - - Raises - ------ - TypeError - If 'additional_properties' is not of type list. - """ - - params = {} - if additional_properties: - if not isinstance(additional_properties, list): - raise TypeError( - "Additional properties must be of type list " - f"but are {type(additional_properties)}" - ) - params["include"] = ",".join(additional_properties) - - if with_vector: - if "include" in params: - params["include"] = params["include"] + ",vector" - else: - params["include"] = "vector" - return params diff --git a/weaviate/data/references/__init__.py b/weaviate/data/references/__init__.py deleted file mode 100644 index 240d4f09c..000000000 --- a/weaviate/data/references/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Module for adding, deleting and updating references in-between objects. -""" - -__all__ = ["Reference"] - -from .crud_references import Reference diff --git a/weaviate/data/references/crud_references.py b/weaviate/data/references/crud_references.py deleted file mode 100644 index 4b7692af8..000000000 --- a/weaviate/data/references/crud_references.py +++ /dev/null @@ -1,680 +0,0 @@ -""" -Reference class definition. -""" - -import warnings -from typing import Union, Optional, List - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection -from weaviate.data.replication import ConsistencyLevel -from weaviate.error_msgs import ( - REF_DEPRECATION_NEW_V14_CLS_NS_W, - REF_DEPRECATION_OLD_V14_FROM_CLS_NS_W, - REF_DEPRECATION_OLD_V14_TO_CLS_NS_W, -) -from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.util import ( - get_valid_uuid, - _capitalize_first_letter, -) - - -class Reference: - """ - Reference class used to manipulate references within objects. - """ - - def __init__(self, connection: Connection): - """ - Initialize a Reference class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running weaviate instance. - """ - - self._connection = connection - - def delete( - self, - from_uuid: str, - from_property_name: str, - to_uuid: str, - from_class_name: Optional[str] = None, - to_class_name: Optional[str] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> None: - """ - Remove a reference to another object. Equal to removing one direction of an edge from the - graph. - - Parameters - ---------- - from_uuid : str - The ID of the object that references another object. - from_property_name : str - The property from which the reference should be deleted. - to_uuid : str - The UUID of the referenced object. - from_class_name : Optional[str], optional - The class name of the object for which to delete the reference (with UUID `from_uuid`), - it is included in Weaviate 1.14.0, where all objects are namespaced by class name. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - to_class_name : Optional[str], optional - The referenced object class name to which to delete the reference (with UUID `to_uuid`), - it is included in Weaviate 1.14.0, where all objects are namespaced by class name. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - tenant: Optional[str], optional - The name of the tenant for which this operation is being performed. - - Examples - -------- - Assume we have two classes, Author and Book. - - >>> # Create the objects first - >>> client.data_object.create( - ... data_object={'name': 'Ray Bradbury'}, - ... class_name='Author', - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab' - ... ) - >>> client.data_object.create( - ... data_object={'title': 'The Martian Chronicles'}, - ... class_name='Book', - ... uuid='a9c1b714-4f8a-4b01-a930-38b046d69d2d' - ... ) - >>> # Add the cross references - >>> ## Author -> Book - >>> client.data_object.reference.add( - ... from_uuid='e067f671-1202-42c6-848b-ff4d1eb804ab', - ... from_property_name='wroteBooks', - ... to_uuid='a9c1b714-4f8a-4b01-a930-38b046d69d2d', - ... from_class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... to_class_name='Book', # ONLY with Weaviate >= 1.14.0 - ... ) - >>> client.data_object.get( - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab', - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617177700595, - "id": "e067f671-1202-42c6-848b-ff4d1eb804ab", - "lastUpdateTimeUnix": 1617177700595, - "properties": { - "name": "Ray Bradbury", - "wroteBooks": [ - { - "beacon": "weaviate://localhost/Book/a9c1b714-4f8a-4b01-a930-38b046d69d2d", - "href": "/v1/objects/Book/a9c1b714-4f8a-4b01-a930-38b046d69d2d" - } - ] - }, - "vectorWeights": null - } - >>> # delete the reference - >>> client.data_object.reference.delete( - ... from_uuid='e067f671-1202-42c6-848b-ff4d1eb804ab', - ... from_property_name='wroteBooks', - ... to_uuid='a9c1b714-4f8a-4b01-a930-38b046d69d2d', - ... from_class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... to_class_name='Book', # ONLY with Weaviate >= 1.14.0 - ... ) - >>> >>> client.data_object.get( - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab', - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617177700595, - "id": "e067f671-1202-42c6-848b-ff4d1eb804ab", - "lastUpdateTimeUnix": 1617177864970, - "properties": { - "name": "Ray Bradbury", - "wroteBooks": [] - }, - "vectorWeights": null - } - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - TypeError - If parameter has the wrong type. - ValueError - If uuid is not properly formed. - """ - - is_server_version_14 = self._connection.server_version >= "1.14" - params = {} - if consistency_level is not None: - params["consistency_level"] = ConsistencyLevel(consistency_level).value - if tenant is not None: - params["tenant"] = tenant - - if (from_class_name is None or to_class_name is None) and is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_NEW_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if from_class_name is not None: - if not is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_OLD_V14_FROM_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - _validate_string_arguments( - argument=from_class_name, - argument_name="from_class_name", - ) - if to_class_name is not None: - if not is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_OLD_V14_TO_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - _validate_string_arguments( - argument=to_class_name, - argument_name="to_class_name", - ) - - # Validate and create Beacon - from_uuid = get_valid_uuid(from_uuid) - to_uuid = get_valid_uuid(to_uuid) - _validate_string_arguments( - argument=from_property_name, - argument_name="from_property_name", - ) - - if to_class_name and is_server_version_14: - beacon = _get_beacon( - to_uuid=to_uuid, - class_name=_capitalize_first_letter(to_class_name), - ) - else: - beacon = _get_beacon( - to_uuid=to_uuid, - ) - - if from_class_name and is_server_version_14: - _class_name = _capitalize_first_letter(from_class_name) - path = f"/objects/{_class_name}/{from_uuid}/references/{from_property_name}" - else: - path = f"/objects/{from_uuid}/references/{from_property_name}" - - try: - response = self._connection.delete(path=path, weaviate_object=beacon, params=params) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Reference was not deleted.") from conn_err - if response.status_code == 204: - return - raise UnexpectedStatusCodeException("Delete property reference to object", response) - - def update( - self, - from_uuid: str, - from_property_name: str, - to_uuids: Union[List[str], str], - from_class_name: Optional[str] = None, - to_class_names: Union[List[str], str, None] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> None: - """ - Allows to update all references in that property with a new set of references. - All old references will be deleted. - - Parameters - ---------- - from_uuid : str - The object that should have the reference as part of its properties. - Should be in the form of an UUID or in form of an URL. - E.g. - 'http://localhost:8080/v1/objects/Book/fc7eb129-f138-457f-b727-1b29db191a67' - or - 'fc7eb129-f138-457f-b727-1b29db191a67' - from_property_name : str - The name of the property within the object. - to_uuids : list or str - The UUIDs of the objects that should be referenced. - Should be a list of str in the form of an UUID or str in form of an URL. - E.g. - ['http://localhost:8080/v1/objects/Book/fc7eb129-f138-457f-b727-1b29db191a67', ...] - or - ['fc7eb129-f138-457f-b727-1b29db191a67', ...] - If `str` it is converted internally into a list of str. - from_class_name : Optional[str], optional - The class name of the object for which to delete the reference (with UUID `from_uuid`), - it is included in Weaviate 1.14.0, where all objects are namespaced by class name. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - to_class_names : Union[list, str, None], optional - The referenced objects class name to which to delete the reference (with UUID - `to_uuid`), it is included in Weaviate 1.14.0, where all objects are namespaced by - class name. It can be a single class name (assumes all `to_uuids` are of the same - class) or a list of class names where for each UUID in `to_uuids` we have a class name. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - tenant: Optional[str] - The name of the tenant for which this operation is being performed. - - Examples - -------- - You have data object 1 with reference property `wroteBooks` and currently has one reference - to data object 7. Now you say, I want to update the references of data object 1.wroteBooks - to this list 3,4,9. After the update, the data object 1.wroteBooks is now 3,4,9, but no - longer contains 7. - - >>> client.data_object.get( - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab' - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617177700595, - "id": "e067f671-1202-42c6-848b-ff4d1eb804ab", - "lastUpdateTimeUnix": 1617177700595, - "properties": { - "name": "Ray Bradbury", - "wroteBooks": [ - { - "beacon": "weaviate://localhost/Book/a9c1b714-4f8a-4b01-a930-38b046d69d2d", - "href": "/v1/objects/Book/a9c1b714-4f8a-4b01-a930-38b046d69d2d" - } - ] - }, - "vectorWeights": null - } - Currently there is only one `Book` reference. - Update all the references of the Author for property name `wroteBooks`. - >>> client.data_object.reference.update( - ... from_uuid = 'e067f671-1202-42c6-848b-ff4d1eb804ab', - ... from_property_name = 'wroteBooks', - ... to_uuids = [ - ... '8429f68f-860a-49ea-a50b-1f8789515882', - ... '3e2e6795-298b-47e9-a2cb-3d8a77a24d8a' - ... ], - ... from_class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... to_class_name='Book', # ONLY with Weaviate >= 1.14.0 - ... ) - >>> client.data_object.get( - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab' - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617181292677, - "id": "e067f671-1202-42c6-848b-ff4d1eb804ab", - "lastUpdateTimeUnix": 1617181409405, - "properties": { - "name": "Ray Bradbury", - "wroteBooks": [ - { - "beacon": "weaviate://localhost/Book/8429f68f-860a-49ea-a50b-1f8789515882", - "href": "/v1/objects/Book/8429f68f-860a-49ea-a50b-1f8789515882" - }, - { - "beacon": "weaviate://localhost/Book/3e2e6795-298b-47e9-a2cb-3d8a77a24d8a", - "href": "/v1/objects/Book/3e2e6795-298b-47e9-a2cb-3d8a77a24d8a" - } - ] - }, - "vectorWeights": null - } - All the previous references were removed and now we have only those specified in the - `update` method. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - TypeError - If the parameters are of the wrong type. - ValueError - If the parameters are of the wrong value. - """ - - is_server_version_14 = self._connection.server_version >= "1.14" - params = {} - if consistency_level is not None: - params["consistency_level"] = ConsistencyLevel(consistency_level).value - if tenant is not None: - params["tenant"] = tenant - - if (from_class_name is None or to_class_names is None) and is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_NEW_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if from_class_name is not None: - if not is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_OLD_V14_FROM_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - _validate_string_arguments( - argument=from_class_name, - argument_name="from_class_name", - ) - if to_class_names is not None: - if not is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_OLD_V14_TO_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - - if not isinstance(to_class_names, list): - _validate_string_arguments( - argument=to_class_names, - argument_name="to_class_names", - ) - else: - for to_class_name in to_class_names: - if not isinstance(to_class_name, str): - raise TypeError( - "'to_class_names' must be of type str of List[str]. " - f"Found element of type: {type(to_class_name)}" - ) - if len(to_class_names) == 0: - to_class_names = None - - if not isinstance(to_uuids, list): - to_uuids = [to_uuids] - if isinstance(to_class_names, str): - to_class_names = [to_class_names] * len(to_uuids) - if to_class_names is not None and len(to_uuids) != len(to_class_names): - raise ValueError( - "'to_class_names' and 'to_uuids' have different lengths, they must match." - ) - - # Validate and create Beacon - from_uuid = get_valid_uuid(from_uuid) - _validate_string_arguments( - argument=from_property_name, - argument_name="from_property_name", - ) - beacons = [] - - if to_class_names and is_server_version_14: - for to_uuid, to_class_name in zip(to_uuids, to_class_names): - to_uuid = get_valid_uuid(to_uuid) - beacon = _get_beacon( - to_uuid=to_uuid, - class_name=_capitalize_first_letter(to_class_name), - ) - beacons.append(beacon) - else: - for to_uuid in to_uuids: - to_uuid = get_valid_uuid(to_uuid) - beacon = _get_beacon( - to_uuid=to_uuid, - ) - beacons.append(beacon) - - if from_class_name and is_server_version_14: - _class_name = _capitalize_first_letter(from_class_name) - path = f"/objects/{_class_name}/{from_uuid}/references/{from_property_name}" - else: - path = f"/objects/{from_uuid}/references/{from_property_name}" - - try: - response = self._connection.put( - path=path, - weaviate_object=beacons, - params=params, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Reference was not updated.") from conn_err - if response.status_code == 200: - return - raise UnexpectedStatusCodeException("Update property reference to object", response) - - def add( - self, - from_uuid: str, - from_property_name: str, - to_uuid: str, - from_class_name: Optional[str] = None, - to_class_name: Optional[str] = None, - consistency_level: Optional[ConsistencyLevel] = None, - tenant: Optional[str] = None, - ) -> None: - """ - Allows to link an object to an object uni-directionally. - - Parameters - ---------- - from_uuid : str - The ID of the object that should have the reference as part - of its properties. Should be a plane UUID or an URL. - E.g. - 'http://localhost:8080/v1/objects/Book/fc7eb129-f138-457f-b727-1b29db191a67' - or - 'fc7eb129-f138-457f-b727-1b29db191a67' - from_property_name : str - The name of the property within the object. - to_uuid : str - The UUID of the object that should be referenced. - Should be a plane UUID or an URL. - E.g. - 'http://localhost:8080/v1/objects/Book/fc7eb129-f138-457f-b727-1b29db191a67' - or - 'fc7eb129-f138-457f-b727-1b29db191a67' - from_class_name : Optional[str], optional - The class name of the object for which to delete the reference (with UUID `from_uuid`), - it is included in Weaviate 1.14.0, where all objects are namespaced by class name. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - to_class_name : Optional[str], optional - The referenced object class name to which to delete the reference (with UUID `to_uuid`), - it is included in Weaviate 1.14.0, where all objects are namespaced by class name. - STRONGLY recommended to set it with Weaviate >= 1.14.0. It will be required in future - versions of Weaviate Server and Clients. Use None value ONLY for Weaviate < v1.14.0, - by default None - consistency_level : Optional[ConsistencyLevel], optional - Can be one of 'ALL', 'ONE', or 'QUORUM'. Determines how many replicas must acknowledge - tenant: Optional[str] - The name of the tenant for which this operation is being performed. - - Examples - -------- - Assume we have two classes, Author and Book. - - >>> # Create the objects first - >>> client.data_object.create( - ... data_object={'name': 'Ray Bradbury'}, - ... class_name='Author', - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab' - ... ) - >>> client.data_object.create( - ... data_object={'title': 'The Martian Chronicles'}, - ... class_name='Book', - ... uuid='a9c1b714-4f8a-4b01-a930-38b046d69d2d' - ... ) - >>> # Add the cross references - >>> ## Author -> Book - >>> client.data_object.reference.add( - ... from_uuid='e067f671-1202-42c6-848b-ff4d1eb804ab', - ... from_property_name='wroteBooks', - ... to_uuid='a9c1b714-4f8a-4b01-a930-38b046d69d2d', - ... from_class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... to_class_name='Book', # ONLY with Weaviate >= 1.14.0 - ... ) - >>> client.data_object.get( - ... uuid='e067f671-1202-42c6-848b-ff4d1eb804ab', - ... class_name='Author', # ONLY with Weaviate >= 1.14.0 - ... ) - { - "additional": {}, - "class": "Author", - "creationTimeUnix": 1617177700595, - "id": "e067f671-1202-42c6-848b-ff4d1eb804ab", - "lastUpdateTimeUnix": 1617177700595, - "properties": { - "name": "Ray Bradbury", - "wroteBooks": [ - { - "beacon": "weaviate://localhost/Book/a9c1b714-4f8a-4b01-a930-38b046d69d2d", - "href": "/v1/objects/Book/a9c1b714-4f8a-4b01-a930-38b046d69d2d" - } - ] - }, - "vectorWeights": null - } - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - TypeError - If the parameters are of the wrong type. - ValueError - If the parameters are of the wrong value. - """ - - is_server_version_14 = self._connection.server_version >= "1.14" - params = {} - if consistency_level is not None: - params["consistency_level"] = ConsistencyLevel(consistency_level).value - if tenant is not None: - params["tenant"] = tenant - - if (from_class_name is None or to_class_name is None) and is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_NEW_V14_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - if from_class_name is not None: - if not is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_OLD_V14_FROM_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - _validate_string_arguments( - argument=from_class_name, - argument_name="from_class_name", - ) - if to_class_name is not None: - if not is_server_version_14: - warnings.warn( - message=REF_DEPRECATION_OLD_V14_TO_CLS_NS_W, - category=DeprecationWarning, - stacklevel=1, - ) - _validate_string_arguments( - argument=to_class_name, - argument_name="to_class_name", - ) - - # Validate and create Beacon - from_uuid = get_valid_uuid(from_uuid) - to_uuid = get_valid_uuid(to_uuid) - _validate_string_arguments( - argument=from_property_name, - argument_name="from_property_name", - ) - - if to_class_name and is_server_version_14: - beacon = _get_beacon( - to_uuid=to_uuid, - class_name=_capitalize_first_letter(to_class_name), - ) - else: - beacon = _get_beacon( - to_uuid=to_uuid, - ) - - if from_class_name and is_server_version_14: - _class_name = _capitalize_first_letter(from_class_name) - path = f"/objects/{_class_name}/{from_uuid}/references/{from_property_name}" - else: - path = f"/objects/{from_uuid}/references/{from_property_name}" - - try: - response = self._connection.post( - path=path, - weaviate_object=beacon, - params=params, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Reference was not added.") from conn_err - if response.status_code == 200: - return - raise UnexpectedStatusCodeException("Add property reference to object", response) - - -def _get_beacon(to_uuid: str, class_name: Optional[str] = None) -> dict: - """ - Get a weaviate-style beacon. - - Parameters - ---------- - to_uuid : str - The UUID to create beacon for. - class_name : Optional[str], optional - The class name of the `to_uuid` object. Used with Weaviate >= 1.14.0. - For Weaviate < 1.14.0 use None value. - - Returns - ------- - dict - Weaviate-style beacon as a dict. - """ - - if class_name is None: - return {"beacon": f"weaviate://localhost/{to_uuid}"} - return {"beacon": f"weaviate://localhost/{class_name}/{to_uuid}"} - - -def _validate_string_arguments(argument: str, argument_name: str) -> None: - """ - Validate string arguments. - - Parameters - ---------- - argument : str - Argument value to be validated. - argument_name : str - Argument name to be included in error message. - - Raises - ------ - TypeError - If 'argument' is not of type str. - """ - - if not isinstance(argument, str): - raise TypeError(f"'{argument_name}' must be of type 'str'. Given type: {type(argument)}") diff --git a/weaviate/data/replication/__init__.py b/weaviate/data/replication/__init__.py deleted file mode 100644 index 5a52bcc03..000000000 --- a/weaviate/data/replication/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Module for managing and facilitating class replication. -""" - -__all__ = ["ConsistencyLevel"] - -from .replication import ConsistencyLevel diff --git a/weaviate/data/replication/replication.py b/weaviate/data/replication/replication.py deleted file mode 100644 index 0d5e2708c..000000000 --- a/weaviate/data/replication/replication.py +++ /dev/null @@ -1,7 +0,0 @@ -from enum import Enum - - -class ConsistencyLevel(str, Enum): - ALL = "ALL" - ONE = "ONE" - QUORUM = "QUORUM" diff --git a/weaviate/gql/__init__.py b/weaviate/gql/__init__.py index 5db7a211d..1b7dfe827 100644 --- a/weaviate/gql/__init__.py +++ b/weaviate/gql/__init__.py @@ -1,8 +1,3 @@ """ GraphQL module used to create `get` and/or `aggregate` GraphQL requests from Weaviate. """ - -__all__ = ["AdditionalProperties", "LinkTo", "Query"] - -from .get import AdditionalProperties, LinkTo -from .query import Query diff --git a/weaviate/gql/aggregate.py b/weaviate/gql/aggregate.py index d86f4a699..2e78b0556 100644 --- a/weaviate/gql/aggregate.py +++ b/weaviate/gql/aggregate.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from typing import List, Optional -from weaviate.connect import Connection from weaviate.util import _capitalize_first_letter, file_encoder_b64, _sanitize_str from .filter import ( Where, @@ -66,7 +65,7 @@ class AggregateBuilder(GraphQL): AggregateBuilder class used to aggregate Weaviate objects. """ - def __init__(self, class_name: str, connection: Connection): + def __init__(self, class_name: str): """ Initialize a AggregateBuilder class instance. @@ -74,11 +73,7 @@ def __init__(self, class_name: str, connection: Connection): ---------- class_name : str Class name of the objects to be aggregated. - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. """ - - super().__init__(connection) self._class_name: str = _capitalize_first_letter(class_name) self._object_limit: Optional[int] = None self._with_meta_count: bool = False @@ -178,62 +173,6 @@ def with_where(self, content: dict) -> "AggregateBuilder": content : dict The where filter to include in the aggregate query. See examples below. - Examples - -------- - The `content` prototype is like this: - - >>> content = { - ... 'operator': '', - ... 'operands': [ - ... { - ... 'path': [path], - ... 'operator': '' - ... '': - ... }, - ... { - ... 'path': [], - ... 'operator': '', - ... '': - ... } - ... ] - ... } - - This is a complete `where` filter but it does not have to be like this all the time. - - Single operand: - - >>> content = { - ... 'path': ["wordCount"], # Path to the property that should be used - ... 'operator': 'GreaterThan', # operator - ... 'valueInt': 1000 # value (which is always = to the type of the path property) - ... } - - Or - - >>> content = { - ... 'path': ["id"], - ... 'operator': 'Equal', - ... 'valueString': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf" - ... } - - Multiple operands: - - >>> content = { - ... 'operator': 'And', - ... 'operands': [ - ... { - ... 'path': ["wordCount"], - ... 'operator': 'GreaterThan', - ... 'valueInt': 1000 - ... }, - ... { - ... 'path': ["wordCount"], - ... 'operator': 'LessThan', - ... 'valueInt': 1500 - ... } - ... ] - ... } - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -291,59 +230,6 @@ def with_near_text(self, content: dict) -> "AggregateBuilder": content : dict The content of the `nearText` filter to set. See examples below. - Examples - -------- - Content full prototype: - - >>> content = { - ... 'concepts': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... 'moveAwayFrom': { # Optional - ... 'concepts': , - ... 'force': - ... }, - ... 'moveTo': { # Optional - ... 'concepts': , - ... 'force': - ... }, - ... 'autocorrect': , # Optional - ... } - - Full content: - - >>> content = { - ... 'concepts': ["fashion"], - ... 'certainty': 0.7, # or 'distance' instead - ... 'moveAwayFrom': { - ... 'concepts': ["finance"], - ... 'force': 0.45 - ... }, - ... 'moveTo': { - ... 'concepts': ["haute couture"], - ... 'force': 0.85 - ... }, - ... 'autocorrect': True - ... } - - Partial content: - - >>> content = { - ... 'concepts': ["fashion"], - ... 'certainty': 0.7, # or 'distance' instead - ... 'moveTo': { - ... 'concepts': ["haute couture"], - ... 'force': 0.85 - ... } - ... } - - Minimal content: - - >>> content = { - ... 'concepts': "fashion" - ... } - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -372,45 +258,6 @@ def with_near_vector(self, content: dict) -> "AggregateBuilder": content : dict The content of the `nearVector` filter to set. See examples below. - Examples - -------- - Content full prototype: - - >>> content = { - ... 'vector' : , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - NOTE: Supported types for 'vector' are `list`, 'numpy.ndarray`, `torch.Tensor` - and `tf.Tensor`. - - Full content: - - >>> content = { - ... 'vector' : [.1, .2, .3, .5], - ... 'certainty': 0.75, # or 'distance' instead - ... } - - Minimal content: - - >>> content = { - ... 'vector' : [.1, .2, .3, .5] - ... } - - Or - - >>> content = { - ... 'vector' : torch.tensor([.1, .2, .3, .5]) - ... } - - Or - - >>> content = { - ... 'vector' : torch.tensor([[.1, .2, .3, .5]]) # it is going to be squeezed. - ... } - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -439,28 +286,6 @@ def with_near_object(self, content: dict) -> "AggregateBuilder": content : dict The content of the `nearObject` filter to set. See examples below. - Examples - -------- - Content prototype: - - >>> content = { - ... 'id': , # OR 'beacon' - ... 'beacon': , # OR 'id' - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'id': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> # alternatively - >>> { - ... 'beacon': "weaviate://localhost/Book/e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf" - ... 'certainty': 0.7 # or 'distance' instead - ... } - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -472,13 +297,11 @@ def with_near_object(self, content: dict) -> "AggregateBuilder": If another 'near' filter was already set. """ - is_server_version_14 = self._connection.server_version >= "1.14" - if self._near is not None: raise AttributeError("Cannot use multiple 'near' filters.") if self._hybrid is not None: raise AttributeError("Cannot use 'near' and 'hybrid' filters simultaneously.") - self._near = NearObject(content, is_server_version_14) + self._near = NearObject(content, True) self._uses_filter = True return self @@ -497,82 +320,6 @@ def with_near_image(self, content: dict, encode: bool = True) -> "AggregateBuild string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). By default True. - Examples - -------- - Content prototype: - - >>> content = { - ... 'image': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'image': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'image': "my_image_path.png", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Image')\\ - .with_fields('description')\\ - ... .with_near_image(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_image_file = open("my_image_path.png", "br") - >>> content = { - ... 'image': my_image_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Image')\\ - .with_fields('description')\\ - ... .with_near_image(content, encode=True) # <- encode MUST be set to True - >>> my_image_file.close() - - With `encoded` False: - - >>> from weaviate.util import image_encoder_b64, image_decoder_b64 - >>> encoded_image = image_encoder_b64("my_image_path.png") - >>> content = { - ... 'image': encoded_image, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Image')\\ - .with_fields('description')\\ - ... .with_near_image(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import image_encoder_b64, image_decoder_b64 - >>> with open("my_image_path.png", "br") as my_image_file: - ... encoded_image = image_encoder_b64(my_image_file) - >>> content = { - ... 'image': encoded_image, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Image')\\ - .with_fields('description')\\ - ... .with_near_image(content, encode=False) # <- encode MUST be set to False - - Encode Image yourself: - - >>> import base64 - >>> with open("my_image_path.png", "br") as my_image_file: - ... encoded_image = base64.b64encode(my_image_file.read()).decode("utf-8") - >>> content = { - ... 'image': encoded_image, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Image')\\ - .with_fields('description')\\ - ... .with_near_image(content, encode=False) # <- encode MUST be set to False - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -612,82 +359,6 @@ def with_near_audio(self, content: dict, encode: bool = True) -> "AggregateBuild string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). By default True. - Examples - -------- - Content prototype: - - >>> content = { - ... 'audio': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'audio': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'audio': "my_audio_path.wav", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Audio')\\ - .with_fields('description')\\ - ... .with_near_audio(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_audio_file = open("my_audio_path.wav", "br") - >>> content = { - ... 'audio': my_audio_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Audio')\\ - .with_fields('description')\\ - ... .with_near_audio(content, encode=True) # <- encode MUST be set to True - >>> my_audio_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_audio = file_encoder_b64("my_audio_path.wav") - >>> content = { - ... 'audio': encoded_audio, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Audio')\\ - .with_fields('description')\\ - ... .with_near_audio(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64 - >>> with open("my_audio_path.wav", "br") as my_audio_file: - ... encoded_audio = file_encoder_b64(my_audio_file) - >>> content = { - ... 'audio': encoded_audio, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Audio')\\ - .with_fields('description')\\ - ... .with_near_audio(content, encode=False) # <- encode MUST be set to False - - Encode Audio yourself: - - >>> import base64 - >>> with open("my_audio_path.wav", "br") as my_audio_file: - ... encoded_audio = base64.b64encode(my_audio_file.read()).decode("utf-8") - >>> content = { - ... 'audio': encoded_audio, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Audio')\\ - .with_fields('description')\\ - ... .with_near_audio(content, encode=False) # <- encode MUST be set to False - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -728,82 +399,6 @@ def with_near_video(self, content: dict, encode: bool = True) -> "AggregateBuild string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). By default True. - Examples - -------- - Content prototype: - - >>> content = { - ... 'video': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'video': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'video': "my_video_path.avi", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Video')\\ - .with_fields('description')\\ - ... .with_near_video(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_video_file = open("my_video_path.avi", "br") - >>> content = { - ... 'video': my_video_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Video')\\ - .with_fields('description')\\ - ... .with_near_video(content, encode=True) # <- encode MUST be set to True - >>> my_video_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_video = file_encoder_b64("my_video_path.avi") - >>> content = { - ... 'video': encoded_video, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Video')\\ - .with_fields('description')\\ - ... .with_near_video(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64, video_decoder_b64 - >>> with open("my_video_path.avi", "br") as my_video_file: - ... encoded_video = file_encoder_b64(my_video_file) - >>> content = { - ... 'video': encoded_video, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Video')\\ - .with_fields('description')\\ - ... .with_near_video(content, encode=False) # <- encode MUST be set to False - - Encode Video yourself: - - >>> import base64 - >>> with open("my_video_path.avi", "br") as my_video_file: - ... encoded_video = base64.b64encode(my_video_file.read()).decode("utf-8") - >>> content = { - ... 'video': encoded_video, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Video')\\ - .with_fields('description')\\ - ... .with_near_video(content, encode=False) # <- encode MUST be set to False - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -844,82 +439,6 @@ def with_near_depth(self, content: dict, encode: bool = True) -> "AggregateBuild string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). By default True. - Examples - -------- - Content prototype: - - >>> content = { - ... 'depth': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'depth': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'depth': "my_depth_path.png", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Depth')\\ - .with_fields('description')\\ - ... .with_near_depth(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_depth_file = open("my_depth_path.png", "br") - >>> content = { - ... 'depth': my_depth_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Depth')\\ - .with_fields('description')\\ - ... .with_near_depth(content, encode=True) # <- encode MUST be set to True - >>> my_depth_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_depth = file_encoder_b64("my_depth_path.png") - >>> content = { - ... 'depth': encoded_depth, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Depth')\\ - .with_fields('description')\\ - ... .with_near_depth(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64 - >>> with open("my_depth_path.png", "br") as my_depth_file: - ... encoded_depth = file_encoder_b64(my_depth_file) - >>> content = { - ... 'depth': encoded_depth, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Depth')\\ - .with_fields('description')\\ - ... .with_near_depth(content, encode=False) # <- encode MUST be set to False - - Encode Depth yourself: - - >>> import base64 - >>> with open("my_depth_path.png", "br") as my_depth_file: - ... encoded_depth = base64.b64encode(my_depth_file.read()).decode("utf-8") - >>> content = { - ... 'depth': encoded_depth, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Depth')\\ - .with_fields('description')\\ - ... .with_near_depth(content, encode=False) # <- encode MUST be set to False - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -960,81 +479,6 @@ def with_near_thermal(self, content: dict, encode: bool = True) -> "AggregateBui string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). By default True. - Examples - -------- - Content prototype: - - >>> content = { - ... 'thermal': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'thermal': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'thermal': "my_thermal_path.png", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Thermal', 'description')\\ - ... .with_near_thermal(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_thermal_file = open("my_thermal_path.png", "br") - >>> content = { - ... 'thermal': my_thermal_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Thermal')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=True) # <- encode MUST be set to True - >>> my_thermal_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_thermal = file_encoder_b64("my_thermal_path.png") - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Thermal')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64 - >>> with open("my_thermal_path.png", "br") as my_thermal_file: - ... encoded_thermal = file_encoder_b64(my_thermal_file) - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Thermal')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - Encode Thermal yourself: - - >>> import base64 - >>> with open("my_thermal_path.png", "br") as my_thermal_file: - ... encoded_thermal = base64.b64encode(my_thermal_file.read()).decode("utf-8") - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('Thermal')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - Returns ------- weaviate.gql.aggregate.AggregateBuilder @@ -1075,81 +519,6 @@ def with_near_imu(self, content: dict, encode: bool = True) -> "AggregateBuilder string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). By default True. - Examples - -------- - Content prototype: - - >>> content = { - ... 'thermal': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'thermal': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'thermal': "my_thermal_path.png", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('IMU')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_thermal_file = open("my_thermal_path.png", "br") - >>> content = { - ... 'thermal': my_thermal_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('IMU')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=True) # <- encode MUST be set to True - >>> my_thermal_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_thermal = file_encoder_b64("my_thermal_path.png") - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('IMU')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64 - >>> with open("my_thermal_path.png", "br") as my_thermal_file: - ... encoded_thermal = file_encoder_b64(my_thermal_file) - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('IMU')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - Encode IMU yourself: - - >>> import base64 - >>> with open("my_thermal_path.png", "br") as my_thermal_file: - ... encoded_thermal = base64.b64encode(my_thermal_file.read()).decode("utf-8") - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.aggregate('IMU')\\ - ... .with_fields('description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False Returns ------- diff --git a/weaviate/gql/filter.py b/weaviate/gql/filter.py index 2ce3cfa18..41881fcd2 100644 --- a/weaviate/gql/filter.py +++ b/weaviate/gql/filter.py @@ -10,11 +10,8 @@ from json import dumps from typing import Any, Tuple, Union -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection from weaviate.error_msgs import FILTER_BEACON_V14_CLS_NS_W -from weaviate.util import get_vector, _sanitize_str, _decode_json_response_dict +from weaviate.util import get_vector, _sanitize_str VALUE_LIST_TYPES = { "valueStringList", @@ -78,18 +75,6 @@ class GraphQL(ABC): A base abstract class for GraphQL commands, such as Get, Aggregate. """ - def __init__(self, connection: Connection): - """ - Initialize a GraphQL abstract class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running weaviate instance. - """ - - self._connection = connection - @abstractmethod def build(self) -> str: """ @@ -102,32 +87,6 @@ def build(self) -> str: The query. """ - def do(self) -> dict: - """ - Builds and runs the query. - - Returns - ------- - dict - The response of the query. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - query = self.build() - try: - response = self._connection.post(path="/graphql", weaviate_object={"query": query}) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Query was not successful.") from conn_err - - res = _decode_json_response_dict(response, "Query was not successful") - assert res is not None - return res - class Filter(ABC): """ diff --git a/weaviate/gql/get.py b/weaviate/gql/get.py deleted file mode 100644 index bb4d5db46..000000000 --- a/weaviate/gql/get.py +++ /dev/null @@ -1,2094 +0,0 @@ -""" -GraphQL `Get` command. -""" - -from dataclasses import dataclass, Field, fields -from enum import Enum -from json import dumps -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union - -import grpc # type: ignore - -from weaviate import util -from weaviate.connect import Connection -from weaviate.data.replication import ConsistencyLevel -from weaviate.exceptions import AdditionalPropertiesException -from weaviate.gql.filter import ( - Where, - NearText, - NearVector, - GraphQL, - NearObject, - Filter, - Ask, - NearImage, - NearVideo, - NearAudio, - NearThermal, - NearDepth, - NearIMU, - MediaType, - Sort, -) -from weaviate.proto.v1 import search_get_pb2 -from weaviate.str_enum import BaseEnum -from weaviate.types import UUID -from weaviate.util import ( - image_encoder_b64, - _capitalize_first_letter, - get_valid_uuid, - file_encoder_b64, -) -from weaviate.warnings import _Warnings - - -@dataclass -class BM25: - query: str - properties: Optional[List[str]] - - def __str__(self) -> str: - ret = f"query: {util._sanitize_str(self.query)}" - if self.properties is not None and len(self.properties) > 0: - props = '","'.join(self.properties) - ret += f', properties: ["{props}"]' - return "bm25:{" + ret + "}" - - -class HybridFusion(str, BaseEnum): - RANKED = "rankedFusion" - RELATIVE_SCORE = "relativeScoreFusion" - - -@dataclass -class Hybrid: - query: str - alpha: Optional[float] - vector: Optional[List[float]] - properties: Optional[List[str]] - fusion_type: Optional[HybridFusion] - - def __str__(self) -> str: - ret = f"query: {util._sanitize_str(self.query)}" - if self.vector is not None: - ret += f", vector: {self.vector}" - if self.alpha is not None: - ret += f", alpha: {self.alpha}" - if self.properties is not None and len(self.properties) > 0: - props = '","'.join(self.properties) - ret += f', properties: ["{props}"]' - if self.fusion_type is not None: - if isinstance(self.fusion_type, Enum): - ret += f", fusionType: {self.fusion_type.value}" - else: - ret += f", fusionType: {self.fusion_type}" - - return "hybrid:{" + ret + "}" - - -@dataclass -class GroupBy: - path: List[str] - groups: int - objects_per_group: int - - def __str__(self) -> str: - props = '","'.join(self.path) - return f'groupBy:{{path:["{props}"], groups:{self.groups}, objectsPerGroup:{self.objects_per_group}}}' - - -@dataclass -class LinkTo: - link_on: str - linked_class: str - properties: Sequence[Union[str, "LinkTo"]] - - def __str__(self) -> str: - props = " ".join(str(x) for x in self.properties) - return self.link_on + "{... on " + self.linked_class + "{" + props + "}}" - - -PROPERTIES = Union[Sequence[Union[str, LinkTo]], str] - - -@dataclass -class AdditionalProperties: - uuid: bool = False - vector: bool = False - creationTimeUnix: bool = False - lastUpdateTimeUnix: bool = False - distance: bool = False - certainty: bool = False - score: bool = False - explainScore: bool = False - - def __str__(self) -> str: - additional_props: List[str] = [] - cls_fields: Tuple[Field, ...] = fields(self.__class__) - for field in cls_fields: - if issubclass(field.type, bool): # type: ignore - enabled: bool = getattr(self, field.name) - if enabled: - name = field.name - if field.name == "uuid": # id is reserved python name - name = "id" - additional_props.append(name) - if len(additional_props) > 0: - return " _additional{" + " ".join(additional_props) + "} " - else: - return "" - - -class GetBuilder(GraphQL): - """ - GetBuilder class used to create GraphQL queries. - """ - - def __init__(self, class_name: str, properties: Optional[PROPERTIES], connection: Connection): - """ - Initialize a GetBuilder class instance. - - Parameters - ---------- - class_name : str - Class name of the objects to interact with. - properties : str or list of str - Properties of the objects to interact with. - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - - Raises - ------ - TypeError - If argument/s is/are of wrong type. - """ - - super().__init__(connection) - - if not isinstance(class_name, str): - raise TypeError(f"class name must be of type str but was {type(class_name)}") - if properties is None: - properties = [] - if isinstance(properties, str): - properties = [properties] - if not isinstance(properties, list): - raise TypeError( - "properties must be of type str, " f"list of str or None but was {type(properties)}" - ) - - self._properties: Sequence[Union[str, LinkTo]] = [] - for prop in properties: - if not isinstance(prop, str) and not isinstance(prop, LinkTo): - raise TypeError("All the `properties` must be of type `str` or Reference!") - self._properties.append(prop) - - self._class_name: str = _capitalize_first_letter(class_name) - self._additional: dict = {"__one_level": set()} - # '__one_level' refers to the additional properties that are just a single word, not a dict - # thus '__one_level', only one level of complexity - self._additional_dataclass: Optional[AdditionalProperties] = None - self._where: Optional[Where] = None # To store the where filter if it is added - self._limit: Optional[int] = None # To store the limit filter if it is added - self._offset: Optional[str] = None # To store the offset filter if it is added - self._after: Optional[str] = None # To store the offset filter if it is added - self._near_clause: Optional[Filter] = ( - None # To store the `near`/`ask` clause if it is added - ) - self._contains_filter = False # true if any filter is added - self._sort: Optional[Sort] = None - self._bm25: Optional[BM25] = None - self._hybrid: Optional[Hybrid] = None - self._group_by: Optional[GroupBy] = None - self._alias: Optional[str] = None - self._tenant: Optional[str] = None - self._autocut: Optional[int] = None - self._consistency_level: Optional[str] = None - - def with_autocut(self, autocut: int) -> "GetBuilder": - """Cuts off irrelevant results based on "jumps" in scores.""" - if not isinstance(autocut, int): - raise TypeError("autocut must be of type int") - - self._autocut = autocut - self._contains_filter = True - return self - - def with_tenant(self, tenant: str) -> "GetBuilder": - """Sets a tenant for the query.""" - if not isinstance(tenant, str): - raise TypeError("tenant must be of type str") - - self._tenant = tenant - self._contains_filter = True - return self - - def with_after(self, after_uuid: UUID) -> "GetBuilder": - """Can be used to extract all elements by giving the last ID from the previous "page". - - Requires limit to be set but cannot be combined with any other filters or search. Part of the Cursor API. - """ - if not isinstance(after_uuid, UUID.__args__): # type: ignore # __args__ is workaround for python 3.8 - raise TypeError("after_uuid must be of type UUID (str or uuid.UUID)") - - self._after = f'after: "{get_valid_uuid(after_uuid)}"' - self._contains_filter = True - return self - - def with_where(self, content: dict) -> "GetBuilder": - """ - Set `where` filter. - - Parameters - ---------- - content : dict - The content of the `where` filter to set. See examples below. - - Examples - -------- - The `content` prototype is like this: - - >>> content = { - ... 'operator': '', - ... 'operands': [ - ... { - ... 'path': [path], - ... 'operator': '' - ... '': - ... }, - ... { - ... 'path': [], - ... 'operator': '', - ... '': - ... } - ... ] - ... } - - This is a complete `where` filter but it does not have to be like this all the time. - - Single operand: - - >>> content = { - ... 'path': ["wordCount"], # Path to the property that should be used - ... 'operator': 'GreaterThan', # operator - ... 'valueInt': 1000 # value (which is always = to the type of the path property) - ... } - - Or - - >>> content = { - ... 'path': ["id"], - ... 'operator': 'Equal', - ... 'valueString': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf" - ... } - - Multiple operands: - - >>> content = { - ... 'operator': 'And', - ... 'operands': [ - ... { - ... 'path': ["wordCount"], - ... 'operator': 'GreaterThan', - ... 'valueInt': 1000 - ... }, - ... { - ... 'path': ["wordCount"], - ... 'operator': 'LessThan', - ... 'valueInt': 1500 - ... } - ... ] - ... } - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - """ - - self._where = Where(content) - self._contains_filter = True - return self - - @property - def name(self) -> str: - return self._alias if self._alias else self._class_name - - def with_near_text(self, content: dict) -> "GetBuilder": - """ - Set `nearText` filter. This filter can be used with text modules (text2vec). - E.g.: text2vec-contextionary, text2vec-transformers. - NOTE: The 'autocorrect' field is enabled only with the `text-spellcheck` Weaviate module. - - Parameters - ---------- - content : dict - The content of the `nearText` filter to set. See examples below. - - Examples - -------- - Content full prototype: - - >>> content = { - ... 'concepts': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... 'moveAwayFrom': { # Optional - ... 'concepts': , - ... 'force': - ... }, - ... 'moveTo': { # Optional - ... 'concepts': , - ... 'force': - ... }, - ... 'autocorrect': , # Optional - ... } - - Full content: - - >>> content = { - ... 'concepts': ["fashion"], - ... 'certainty': 0.7, # or 'distance' - ... 'moveAwayFrom': { - ... 'concepts': ["finance"], - ... 'force': 0.45 - ... }, - ... 'moveTo': { - ... 'concepts': ["haute couture"], - ... 'force': 0.85 - ... }, - ... 'autocorrect': True - ... } - - Partial content: - - >>> content = { - ... 'concepts': ["fashion"], - ... 'certainty': 0.7, # or 'distance' - ... 'moveTo': { - ... 'concepts': ["haute couture"], - ... 'force': 0.85 - ... } - ... } - - Minimal content: - - >>> content = { - ... 'concepts': "fashion" - ... } - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - self._near_clause = NearText(content) - self._contains_filter = True - return self - - def with_near_vector(self, content: dict) -> "GetBuilder": - """ - Set `nearVector` filter. - - Parameters - ---------- - content : dict - The content of the `nearVector` filter to set. See examples below. - - Examples - -------- - Content full prototype: - - >>> content = { - ... 'vector' : , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - NOTE: Supported types for 'vector' are `list`, 'numpy.ndarray`, `torch.Tensor` - and `tf.Tensor`. - - Full content: - - >>> content = { - ... 'vector' : [.1, .2, .3, .5], - ... 'certainty': 0.75, # or 'distance' - ... } - - Minimal content: - - >>> content = { - ... 'vector' : [.1, .2, .3, .5] - ... } - - Or - - >>> content = { - ... 'vector' : torch.tensor([.1, .2, .3, .5]) - ... } - - Or - - >>> content = { - ... 'vector' : torch.tensor([[.1, .2, .3, .5]]) # it is going to be squeezed. - ... } - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - self._near_clause = NearVector(content) - self._contains_filter = True - return self - - def with_near_object(self, content: dict) -> "GetBuilder": - """ - Set `nearObject` filter. - - Parameters - ---------- - content : dict - The content of the `nearObject` filter to set. See examples below. - - Examples - -------- - Content prototype: - - >>> { - ... 'id': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - >>> # alternatively - >>> { - ... 'beacon': "weaviate://localhost/ClassName/e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf" - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - is_server_version_14 = self._connection.server_version >= "1.14" - - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - self._near_clause = NearObject(content, is_server_version_14) - self._contains_filter = True - return self - - def with_near_image(self, content: dict, encode: bool = True) -> "GetBuilder": - """ - Set `nearImage` filter. - - Parameters - ---------- - content : dict - The content of the `nearImage` filter to set. See examples below. - encode : bool, optional - Whether to encode the `content["image"]` to base64 and convert to string. If True, the - `content["image"]` can be an image path or a file opened in binary read mode. If False, - the `content["image"]` MUST be a base64 encoded string (NOT bytes, i.e. NOT binary - string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). - By default True. - - Examples - -------- - Content prototype: - - >>> content = { - ... 'image': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'image': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'image': "my_image_path.png", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Image', 'description')\\ - ... .with_near_image(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_image_file = open("my_image_path.png", "br") - >>> content = { - ... 'image': my_image_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Image', 'description')\\ - ... .with_near_image(content, encode=True) # <- encode MUST be set to True - >>> my_image_file.close() - - With `encoded` False: - - >>> from weaviate.util import image_encoder_b64, image_decoder_b64 - >>> encoded_image = image_encoder_b64("my_image_path.png") - >>> content = { - ... 'image': encoded_image, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Image', 'description')\\ - ... .with_near_image(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import image_encoder_b64, image_decoder_b64 - >>> with open("my_image_path.png", "br") as my_image_file: - ... encoded_image = image_encoder_b64(my_image_file) - >>> content = { - ... 'image': encoded_image, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Image', 'description')\\ - ... .with_near_image(content, encode=False) # <- encode MUST be set to False - - Encode Image yourself: - - >>> import base64 - >>> with open("my_image_path.png", "br") as my_image_file: - ... encoded_image = base64.b64encode(my_image_file.read()).decode("utf-8") - >>> content = { - ... 'image': encoded_image, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Image', 'description')\\ - ... .with_near_image(content, encode=False) # <- encode MUST be set to False - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - if encode: - content["image"] = image_encoder_b64(content["image"]) - self._near_clause = NearImage(content) - self._contains_filter = True - return self - - def with_near_audio(self, content: dict, encode: bool = True) -> "GetBuilder": - """ - Set `nearAudio` filter. - - Parameters - ---------- - content : dict - The content of the `nearAudio` filter to set. See examples below. - encode : bool, optional - Whether to encode the `content["audio"]` to base64 and convert to string. If True, the - `content["audio"]` can be an audio path or a file opened in binary read mode. If False, - the `content["audio"]` MUST be a base64 encoded string (NOT bytes, i.e. NOT binary - string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). - By default True. - - Examples - -------- - Content prototype: - - >>> content = { - ... 'audio': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'audio': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'audio': "my_audio_path.wav", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Audio', 'description')\\ - ... .with_near_audio(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_audio_file = open("my_audio_path.wav", "br") - >>> content = { - ... 'audio': my_audio_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Audio', 'description')\\ - ... .with_near_audio(content, encode=True) # <- encode MUST be set to True - >>> my_audio_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_audio = file_encoder_b64("my_audio_path.wav") - >>> content = { - ... 'audio': encoded_audio, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Audio', 'description')\\ - ... .with_near_audio(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64 - >>> with open("my_audio_path.wav", "br") as my_audio_file: - ... encoded_audio = file_encoder_b64(my_audio_file) - >>> content = { - ... 'audio': encoded_audio, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Audio', 'description')\\ - ... .with_near_audio(content, encode=False) # <- encode MUST be set to False - - Encode Audio yourself: - - >>> import base64 - >>> with open("my_audio_path.wav", "br") as my_audio_file: - ... encoded_audio = base64.b64encode(my_audio_file.read()).decode("utf-8") - >>> content = { - ... 'audio': encoded_audio, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Audio', 'description')\\ - ... .with_near_audio(content, encode=False) # <- encode MUST be set to False - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - self._media_type = MediaType.AUDIO - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - if encode: - content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) - self._near_clause = NearAudio(content) - self._contains_filter = True - return self - - def with_near_video(self, content: dict, encode: bool = True) -> "GetBuilder": - """ - Set `nearVideo` filter. - - Parameters - ---------- - content : dict - The content of the `nearVideo` filter to set. See examples below. - encode : bool, optional - Whether to encode the `content["video"]` to base64 and convert to string. If True, the - `content["video"]` can be an video path or a file opened in binary read mode. If False, - the `content["video"]` MUST be a base64 encoded string (NOT bytes, i.e. NOT binary - string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). - By default True. - - Examples - -------- - Content prototype: - - >>> content = { - ... 'video': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'video': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'video': "my_video_path.avi", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Video', 'description')\\ - ... .with_near_video(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_video_file = open("my_video_path.avi", "br") - >>> content = { - ... 'video': my_video_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Video', 'description')\\ - ... .with_near_video(content, encode=True) # <- encode MUST be set to True - >>> my_video_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_video = file_encoder_b64("my_video_path.avi") - >>> content = { - ... 'video': encoded_video, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Video', 'description')\\ - ... .with_near_video(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64, video_decoder_b64 - >>> with open("my_video_path.avi", "br") as my_video_file: - ... encoded_video = file_encoder_b64(my_video_file) - >>> content = { - ... 'video': encoded_video, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Video', 'description')\\ - ... .with_near_video(content, encode=False) # <- encode MUST be set to False - - Encode Video yourself: - - >>> import base64 - >>> with open("my_video_path.avi", "br") as my_video_file: - ... encoded_video = base64.b64encode(my_video_file.read()).decode("utf-8") - >>> content = { - ... 'video': encoded_video, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Video', 'description')\\ - ... .with_near_video(content, encode=False) # <- encode MUST be set to False - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - self._media_type = MediaType.VIDEO - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - if encode: - content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) - self._near_clause = NearVideo(content) - self._contains_filter = True - return self - - def with_near_depth(self, content: dict, encode: bool = True) -> "GetBuilder": - """ - Set `nearDepth` filter. - - Parameters - ---------- - content : dict - The content of the `nearDepth` filter to set. See examples below. - encode : bool, optional - Whether to encode the `content["depth"]` to base64 and convert to string. If True, the - `content["depth"]` can be an depth path or a file opened in binary read mode. If False, - the `content["depth"]` MUST be a base64 encoded string (NOT bytes, i.e. NOT binary - string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). - By default True. - - Examples - -------- - Content prototype: - - >>> content = { - ... 'depth': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'depth': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'depth': "my_depth_path.png", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Depth', 'description')\\ - ... .with_near_depth(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_depth_file = open("my_depth_path.png", "br") - >>> content = { - ... 'depth': my_depth_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Depth', 'description')\\ - ... .with_near_depth(content, encode=True) # <- encode MUST be set to True - >>> my_depth_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_depth = file_encoder_b64("my_depth_path.png") - >>> content = { - ... 'depth': encoded_depth, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Depth', 'description')\\ - ... .with_near_depth(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64 - >>> with open("my_depth_path.png", "br") as my_depth_file: - ... encoded_depth = file_encoder_b64(my_depth_file) - >>> content = { - ... 'depth': encoded_depth, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Depth', 'description')\\ - ... .with_near_depth(content, encode=False) # <- encode MUST be set to False - - Encode Depth yourself: - - >>> import base64 - >>> with open("my_depth_path.png", "br") as my_depth_file: - ... encoded_depth = base64.b64encode(my_depth_file.read()).decode("utf-8") - >>> content = { - ... 'depth': encoded_depth, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Depth', 'description')\\ - ... .with_near_depth(content, encode=False) # <- encode MUST be set to False - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - self._media_type = MediaType.DEPTH - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - if encode: - content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) - self._near_clause = NearDepth(content) - self._contains_filter = True - return self - - def with_near_thermal(self, content: dict, encode: bool = True) -> "GetBuilder": - """ - Set `nearThermal` filter. - - Parameters - ---------- - content : dict - The content of the `nearThermal` filter to set. See examples below. - encode : bool, optional - Whether to encode the `content["thermal"]` to base64 and convert to string. If True, the - `content["thermal"]` can be an thermal path or a file opened in binary read mode. If False, - the `content["thermal"]` MUST be a base64 encoded string (NOT bytes, i.e. NOT binary - string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). - By default True. - - Examples - -------- - Content prototype: - - >>> content = { - ... 'thermal': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'thermal': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'thermal': "my_thermal_path.png", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Thermal', 'description')\\ - ... .with_near_thermal(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_thermal_file = open("my_thermal_path.png", "br") - >>> content = { - ... 'thermal': my_thermal_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Thermal', 'description')\\ - ... .with_near_thermal(content, encode=True) # <- encode MUST be set to True - >>> my_thermal_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_thermal = file_encoder_b64("my_thermal_path.png") - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Thermal', 'description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64 - >>> with open("my_thermal_path.png", "br") as my_thermal_file: - ... encoded_thermal = file_encoder_b64(my_thermal_file) - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Thermal', 'description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - Encode Thermal yourself: - - >>> import base64 - >>> with open("my_thermal_path.png", "br") as my_thermal_file: - ... encoded_thermal = base64.b64encode(my_thermal_file.read()).decode("utf-8") - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('Thermal', 'description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - self._media_type = MediaType.THERMAL - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - if encode: - content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) - self._near_clause = NearThermal(content) - self._contains_filter = True - return self - - def with_near_imu(self, content: dict, encode: bool = True) -> "GetBuilder": - """ - Set `nearIMU` filter. - - Parameters - ---------- - content : dict - The content of the `nearIMU` filter to set. See examples below. - encode : bool, optional - Whether to encode the `content["thermal"]` to base64 and convert to string. If True, the - `content["thermal"]` can be an thermal path or a file opened in binary read mode. If False, - the `content["thermal"]` MUST be a base64 encoded string (NOT bytes, i.e. NOT binary - string that looks like this: b'BASE64ENCODED' but simple 'BASE64ENCODED'). - By default True. - - Examples - -------- - Content prototype: - - >>> content = { - ... 'thermal': , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... } - - >>> { - ... 'thermal': "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf", - ... 'certainty': 0.7 # or 'distance' - ... } - - With `encoded` True: - - >>> content = { - ... 'thermal': "my_thermal_path.png", - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('IMU', 'description')\\ - ... .with_near_thermal(content, encode=True) # <- encode MUST be set to True - - OR - - >>> my_thermal_file = open("my_thermal_path.png", "br") - >>> content = { - ... 'thermal': my_thermal_file, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('IMU', 'description')\\ - ... .with_near_thermal(content, encode=True) # <- encode MUST be set to True - >>> my_thermal_file.close() - - With `encoded` False: - - >>> from weaviate.util import file_encoder_b64 - >>> encoded_thermal = file_encoder_b64("my_thermal_path.png") - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('IMU', 'description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - OR - - >>> from weaviate.util import file_encoder_b64 - >>> with open("my_thermal_path.png", "br") as my_thermal_file: - ... encoded_thermal = file_encoder_b64(my_thermal_file) - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('IMU', 'description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - Encode IMU yourself: - - >>> import base64 - >>> with open("my_thermal_path.png", "br") as my_thermal_file: - ... encoded_thermal = base64.b64encode(my_thermal_file.read()).decode("utf-8") - >>> content = { - ... 'thermal': encoded_thermal, - ... 'certainty': 0.7 # or 'distance' instead - ... } - >>> query = client.query.get('IMU', 'description')\\ - ... .with_near_thermal(content, encode=False) # <- encode MUST be set to False - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - AttributeError - If another 'near' filter was already set. - """ - - self._media_type = MediaType.IMU - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - if encode: - content[self._media_type.value] = file_encoder_b64(content[self._media_type.value]) - self._near_clause = NearIMU(content) - self._contains_filter = True - return self - - def with_limit(self, limit: int) -> "GetBuilder": - """ - The limit of objects returned. - - Parameters - ---------- - limit : int - The max number of objects returned. - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - ValueError - If 'limit' is non-positive. - """ - - if limit < 1: - raise ValueError("limit cannot be non-positive (limit >=1).") - - self._limit = limit - self._contains_filter = True - return self - - def with_offset(self, offset: int) -> "GetBuilder": - """ - The offset of objects returned, i.e. the starting index of the returned objects should be - used in conjunction with the `with_limit` method. - - Parameters - ---------- - offset : int - The offset used for the returned objects. - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - ValueError - If 'offset' is non-positive. - """ - - if offset < 0: - raise ValueError("offset cannot be non-positive (offset >=0).") - - self._offset = f"offset: {offset} " - self._contains_filter = True - return self - - def with_ask(self, content: dict) -> "GetBuilder": - """ - Ask a question for which weaviate will retrieve the answer from your data. - This filter can be used only with QnA module: qna-transformers. - NOTE: The 'autocorrect' field is enabled only with the `text-spellcheck` Weaviate module. - - Parameters - ---------- - content : dict - The content of the `ask` filter to set. See examples below. - - Examples - -------- - Content full prototype: - - >>> content = { - ... 'question' : , - ... # certainty ONLY with `cosine` distance specified in the schema - ... 'certainty': , # Optional, either 'certainty' OR 'distance' - ... 'distance': , # Optional, either 'certainty' OR 'distance' - ... 'properties': # Optional - ... 'autocorrect': , # Optional - ... } - - Full content: - - >>> content = { - ... 'question' : "What is the NLP?", - ... 'certainty': 0.7, # or 'distance' - ... 'properties': ['body'] # search the answer in these properties only. - ... 'autocorrect': True - ... } - - Minimal content: - - >>> content = { - ... 'question' : "What is the NLP?" - ... } - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - """ - - if self._near_clause is not None: - raise AttributeError( - "Cannot use multiple 'near' filters, or a 'near' filter along" - " with a 'ask' filter!" - ) - self._near_clause = Ask(content) - self._contains_filter = True - return self - - def with_additional( - self, - properties: Union[ - List, str, Dict[str, Union[List[str], str]], Tuple[dict, dict], AdditionalProperties - ], - ) -> "GetBuilder": - """ - Add additional properties (i.e. properties from `_additional` clause). See Examples below. - If the the 'properties' is of data type `str` or `list` of `str` then the method is - idempotent, if it is of type `dict` or `tuple` then the exiting property is going to be - replaced. To set the setting of one of the additional property use the `tuple` data type - where `properties` look like this (clause: dict, settings: dict) where the 'settings' are - the properties inside the '(...)' of the clause. See Examples for more information. - - Parameters - ---------- - properties : str, list of str, dict[str, str], dict[str, list of str] or tuple[dict, dict] - The additional properties to include in the query. Can be property name as `str`, - a list of property names, a dictionary (clause without settings) where the value is a - `str` or list of `str`, or a `tuple` of 2 elements: - (clause: Dict[str, str or list[str]], settings: Dict[str, Any]) - where the 'clause' is the property and all its sub-properties and the 'settings' is the - setting of the property, i.e. everything that is inside the `(...)` right after the - property name. See Examples below. - - Examples - -------- - - >>> # single additional property with this GraphQL query - >>> ''' - ... { - ... Get { - ... Article { - ... title - ... author - ... _additional { - ... id - ... } - ... } - ... } - ... } - ... ''' - >>> client.query\\ - ... .get('Article', ['title', 'author'])\\ - ... .with_additional('id']) # argument as `str` - - >>> # multiple additional property with this GraphQL query - >>> ''' - ... { - ... Get { - ... Article { - ... title - ... author - ... _additional { - ... id - ... certainty - ... } - ... } - ... } - ... } - ... ''' - >>> client.query\\ - ... .get('Article', ['title', 'author'])\\ - ... .with_additional(['id', 'certainty']) # argument as `List[str]` - - >>> # additional properties as clause with this GraphQL query - >>> ''' - ... { - ... Get { - ... Article { - ... title - ... author - ... _additional { - ... classification { - ... basedOn - ... classifiedFields - ... completed - ... id - ... scope - ... } - ... } - ... } - ... } - ... } - ... ''' - >>> client.query\\ - ... .get('Article', ['title', 'author'])\\ - ... .with_additional( - ... { - ... 'classification' : ['basedOn', 'classifiedFields', 'completed', 'id'] - ... } - ... ) # argument as `dict[str, List[str]]` - >>> # or with this GraphQL query - >>> ''' - ... { - ... Get { - ... Article { - ... title - ... author - ... _additional { - ... classification { - ... completed - ... } - ... } - ... } - ... } - ... } - ... ''' - >>> client.query\\ - ... .get('Article', ['title', 'author'])\\ - ... .with_additional( - ... { - ... 'classification' : 'completed' - ... } - ... ) # argument as `Dict[str, str]` - - Consider the following GraphQL clause: - - >>> ''' - ... { - ... Get { - ... Article { - ... title - ... author - ... _additional { - ... token ( - ... properties: ["content"] - ... limit: 10 - ... certainty: 0.8 - ... ) { - ... certainty - ... endPosition - ... entity - ... property - ... startPosition - ... word - ... } - ... } - ... } - ... } - ... } - ... ''' - - Then the python translation of this is the following: - - >>> clause = { - ... 'token': [ # if only one, can be passes as `str` - ... 'certainty', - ... 'endPosition', - ... 'entity', - ... 'property', - ... 'startPosition', - ... 'word', - ... ] - ... } - >>> settings = { - ... 'properties': ["content"], # is required - ... 'limit': 10, # optional, int - ... 'certainty': 0.8 # optional, float - ... } - >>> client.query\\ - ... .get('Article', ['title', 'author'])\\ - ... .with_additional( - ... (clause, settings) - ... ) # argument as `Tuple[Dict[str, List[str]], Dict[str, Any]]` - - If the desired clause does not match any example above, then the clause can always be - converted to string before passing it to the `.with_additional()` method. - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - - Raises - ------ - TypeError - If one of the property is not of a correct data type. - """ - if isinstance(properties, AdditionalProperties): - if len(self._additional) > 1 or len(self._additional["__one_level"]) > 0: - raise AdditionalPropertiesException( - str(self._additional), str(self._additional_dataclass) - ) - self._additional_dataclass = properties - return self - elif self._additional_dataclass is not None: - raise AdditionalPropertiesException( - str(self._additional), str(self._additional_dataclass) - ) - - if isinstance(properties, str): - self._additional["__one_level"].add(properties) - return self - - if isinstance(properties, list): - for prop in properties: - if not isinstance(prop, str): - raise TypeError( - "If type of 'properties' is `list` then all items must be of type `str`!" - ) - self._additional["__one_level"].add(prop) - return self - - if isinstance(properties, tuple): - self._tuple_to_dict(properties) - return self - - if not isinstance(properties, dict): - raise TypeError( - "The 'properties' argument must be either of type `str`, `list`, `dict` or " - f"`tuple`! Given: {type(properties)}" - ) - - # only `dict` type here - for key, values in properties.items(): - if not isinstance(key, str): - raise TypeError( - "If type of 'properties' is `dict` then all keys must be of type `str`!" - ) - self._additional[key] = set() - if isinstance(values, str): - self._additional[key].add(values) - continue - if not isinstance(values, list): - raise TypeError( - "If type of 'properties' is `dict` then all the values must be either of type " - f"`str` or `list` of `str`! Given: {type(values)}!" - ) - if len(values) == 0: - raise ValueError( - "If type of 'properties' is `dict` and a value is of type `list` then at least" - " one element should be present!" - ) - for value in values: - if not isinstance(value, str): - raise TypeError( - "If type of 'properties' is `dict` and a value is of type `list` then all " - "items must be of type `str`!" - ) - self._additional[key].add(value) - return self - - def with_sort(self, content: Union[list, dict]) -> "GetBuilder": - """ - Sort objects based on specific field/s. Multiple sort fields can be used, the objects are - going to be sorted according to order of the sort configs passed. This method can be called - multiple times and it does not overwrite the last entry but appends it to the previous - ones, see examples below. - - Parameters - ---------- - content : Union[list, dict] - The content of the Sort filter. Can be a single Sort configuration or a list of - configurations. - - Examples - -------- - The `content` should have this form: - - >>> content = { - ... 'path': ['name'] # Path to the property that should be used - ... 'order': 'asc' # Sort order, possible values: asc, desc - ... } - >>> client.query.get('Author', ['name', 'address'])\\ - ... .with_sort(content) - - Or a list of sort configurations: - - >>> content = [ - ... { - ... 'path': ['name'] # Path to the property that should be used - ... 'order': 'asc' # Sort order, possible values: asc, desc - ... }, - ... 'path': ['address'] # Path to the property that should be used - ... 'order': 'desc' # Sort order, possible values: asc, desc - ... } - ... ] - - If we have a list we can add it in 2 ways. - Pass the list: - - >>> client.query.get('Author', ['name', 'address'])\\ - ... .with_sort(content) - - Or one configuration at a time: - - >>> client.query.get('Author', ['name', 'address'])\\ - ... .with_sort(content[0]) - ... .with_sort(content[1]) - - It is possible to call this method multiple times with lists only too. - - - Returns - ------- - weaviate.gql.get.GetBuilder - The updated GetBuilder. - """ - - if self._sort is None: - self._sort = Sort(content=content) - self._contains_filter = True - else: - self._sort.add(content=content) - return self - - def with_bm25(self, query: str, properties: Optional[List[str]] = None) -> "GetBuilder": - """Add BM25 query to search the inverted index for keywords. - - Parameters - ---------- - query: str - The query to search for. - properties: Optional[List[str]] - Which properties should be searched. If 'None' or empty all properties will be searched. By default, None - """ - self._bm25 = BM25(query, properties) - self._contains_filter = True - - return self - - def with_hybrid( - self, - query: str, - alpha: Optional[float] = None, - vector: Optional[List[float]] = None, - properties: Optional[List[str]] = None, - fusion_type: Optional[HybridFusion] = None, - ) -> "GetBuilder": - """Get objects using bm25 and vector, then combine the results using a reciprocal ranking algorithm. - - Parameters - ---------- - query: str - The query to search for. - alpha: Optional[float] - Factor determining how BM25 and vector search are weighted. If 'None' the weaviate default of 0.75 is used. - By default, None - alpha = 0 -> bm25, alpha=1 -> vector search - vector: Optional[List[float]] - Vector that is searched for. If 'None', weaviate will use the configured text-to-vector module to create a - vector from the "query" field. - By default, None - properties: Optional[List[str]]: - Which properties should be searched by BM25. Does not have any effect for vector search. If None or empty - all properties are searched. - fusion_type: Optional[HybridFusionType]: - Which fusion type should be used to merge keyword and vector search. - """ - self._hybrid = Hybrid(query, alpha, vector, properties, fusion_type) - self._contains_filter = True - return self - - def with_group_by( - self, properties: List[str], groups: int, objects_per_group: int - ) -> "GetBuilder": - """Retrieve groups of objects from Weaviate. - - Note that the return values must be set using .with_additional() to see the output. - - Parameters - ---------- - properties: List[str] - Properties to group by - groups: int - Maximum number of groups - objects_per_group: int - Maximum number of objects per group - - """ - self._group_by = GroupBy(properties, groups, objects_per_group) - self._contains_filter = True - return self - - def with_generate( - self, - single_prompt: Optional[str] = None, - grouped_task: Optional[str] = None, - grouped_properties: Optional[List[str]] = None, - ) -> "GetBuilder": - """Generate responses using the OpenAI generative search. - - At least one of the two parameters must be set. - - Parameters - ---------- - grouped_task: Optional[str] - The task to generate a grouped response. - grouped_properties: Optional[List[str]]: - The properties whose contents are added to the prompts. If None or empty, - all text properties contents are added. - single_prompt: Optional[str] - The prompt to generate a single response. - """ - if single_prompt is None and grouped_task is None: - raise TypeError( - "Either parameter grouped_result_task or single_result_prompt must be not None." - ) - if (single_prompt is not None and not isinstance(single_prompt, str)) or ( - grouped_task is not None and not isinstance(grouped_task, str) - ): - raise TypeError("prompts and tasks must be of type str.") - - if self._connection.server_version < "1.17.3": - _Warnings.weaviate_too_old_for_openai(self._connection.server_version) - - results: List[str] = ["error"] - task_and_prompt = "" - if single_prompt is not None: - results.append("singleResult") - task_and_prompt += f"singleResult:{{prompt:{util._sanitize_str(single_prompt)}}}" - if grouped_task is not None or ( - grouped_properties is not None and len(grouped_properties) > 0 - ): - results.append("groupedResult") - args = [] - if grouped_task is not None: - args.append(f"task:{util._sanitize_str(grouped_task)}") - if grouped_properties is not None and len(grouped_properties) > 0: - props = '","'.join(grouped_properties) - args.append(f'properties:["{props}"]') - task_and_prompt += f'groupedResult:{{{",".join(args)}}}' - - self._additional["__one_level"].add(f'generate({task_and_prompt}){{{" ".join(results)}}}') - - return self - - def with_alias( - self, - alias: str, - ) -> "GetBuilder": - """Gives an alias for the query. Needs to be used if 'multi_get' requests the same 'class_name' twice. - - Parameters - ---------- - alias: str - The alias for the query. - """ - - self._alias = alias - return self - - def with_consistency_level(self, consistency_level: ConsistencyLevel) -> "GetBuilder": - """Set the consistency level for the request.""" - - self._consistency_level = f"consistencyLevel: {consistency_level.value} " - self._contains_filter = True - return self - - def build(self, wrap_get: bool = True) -> str: - """ - Build query filter as a string. - - Parameters - ---------- - wrap_get: bool - A boolean to decide wether {Get{...}} is placed around the query. Useful for multi_get. - - Returns - ------- - str - The GraphQL query as a string. - """ - if wrap_get: - query = "{Get{" - else: - query = "" - - if self._alias is not None: - query += self._alias + ": " - query += self._class_name - if self._contains_filter: - query += "(" - if self._where is not None: - query += str(self._where) - if self._limit is not None: - query += f"limit: {self._limit} " - if self._offset is not None: - query += self._offset - if self._near_clause is not None: - query += str(self._near_clause) - if self._sort is not None: - query += str(self._sort) - if self._bm25 is not None: - query += str(self._bm25) - if self._hybrid is not None: - query += str(self._hybrid) - if self._group_by is not None: - query += str(self._group_by) - if self._after is not None: - query += self._after - if self._consistency_level is not None: - query += self._consistency_level - if self._tenant is not None: - query += f'tenant: "{self._tenant}"' - if self._autocut is not None: - query += f"autocut: {self._autocut}" - - query += ")" - - additional_props = self._additional_to_str() - - if not (additional_props or self._properties): - raise AttributeError( - "No 'properties' or 'additional properties' specified to be returned. " - "At least one should be included." - ) - - properties = " ".join(str(x) for x in self._properties) + self._additional_to_str() - query += "{" + properties + "}" - if wrap_get: - query += "}}" - return query - - def do(self) -> dict: - """ - Builds and runs the query. - - Returns - ------- - dict - The response of the query. - - Raises - ------ - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - grpc_enabled = ( # only implemented for some scenarios - self._connection.grpc_stub is not None - and ( - self._near_clause is None - or isinstance(self._near_clause, NearVector) - or isinstance(self._near_clause, NearObject) - ) - and len(self._additional) == 1 - and ( - len(self._additional["__one_level"]) == 0 or "id" in self._additional["__one_level"] - ) - and self._offset is None - and self._sort is None - and self._where is None - and self._after is None - and all( - "..." not in prop and "_additional" not in prop - for prop in self._properties - if isinstance(prop, str) - ) # no ref props as strings - ) - if grpc_enabled: - metadata: Union[Tuple, Tuple[Tuple[Literal["authorization"], str]]] = () - access_token = self._connection.get_current_bearer_token() - if len(access_token) > 0: - metadata = (("authorization", access_token),) - - try: - res, _ = self._connection.grpc_stub.Search.with_call( # type: ignore - search_get_pb2.SearchRequest( - collection=self._class_name, - limit=self._limit, - near_vector=( - search_get_pb2.NearVector( - vector=self._near_clause.content["vector"], - certainty=self._near_clause.content.get("certainty", None), - distance=self._near_clause.content.get("distance", None), - ) - if self._near_clause is not None - and isinstance(self._near_clause, NearVector) - else None - ), - near_object=( - search_get_pb2.NearObject( - id=self._near_clause.content["id"], - certainty=self._near_clause.content.get("certainty", None), - distance=self._near_clause.content.get("distance", None), - ) - if self._near_clause is not None - and isinstance(self._near_clause, NearObject) - else None - ), - properties=self._convert_references_to_grpc(self._properties), - metadata=( - search_get_pb2.MetadataRequest( - uuid=self._additional_dataclass.uuid, - vector=self._additional_dataclass.vector, - creation_time_unix=self._additional_dataclass.creationTimeUnix, - last_update_time_unix=self._additional_dataclass.lastUpdateTimeUnix, - distance=self._additional_dataclass.distance, - explain_score=self._additional_dataclass.explainScore, - score=self._additional_dataclass.score, - ) - if self._additional_dataclass is not None - else None - ), - bm25_search=( - search_get_pb2.BM25( - properties=self._bm25.properties, query=self._bm25.query - ) - if self._bm25 is not None - else None - ), - hybrid_search=( - search_get_pb2.Hybrid( - properties=self._hybrid.properties, - query=self._hybrid.query, - alpha=self._hybrid.alpha, - vector=self._hybrid.vector, - ) - if self._hybrid is not None - else None - ), - ), - metadata=metadata, - ) - - objects = [] - for result in res.results: - obj = self._convert_references_to_grpc_result(result.properties) - additional = self._extract_additional_properties(result.metadata) - if len(additional) > 0: - obj["_additional"] = additional - objects.append(obj) - - results: Union[Dict[str, Dict[str, Dict[str, List]]], Dict[str, List]] = { - "data": {"Get": {self._class_name: objects}} - } - - except grpc.RpcError as e: - results = {"errors": [e.details()]} # pyright: ignore - return results - else: - return super().do() - - def _extract_additional_properties( - self, props: "search_get_pb2.MetadataResult" - ) -> Dict[str, str]: - additional_props: Dict[str, Any] = {} - if self._additional_dataclass is None: - return additional_props - - if self._additional_dataclass.uuid: - additional_props["id"] = props.id - if self._additional_dataclass.vector: - additional_props["vector"] = ( - [float(num) for num in props.vector] if len(props.vector) > 0 else None - ) - if self._additional_dataclass.distance: - additional_props["distance"] = props.distance if props.distance_present else None - if self._additional_dataclass.certainty: - additional_props["certainty"] = props.certainty if props.certainty_present else None - if self._additional_dataclass.creationTimeUnix: - additional_props["creationTimeUnix"] = ( - str(props.creation_time_unix) if props.creation_time_unix_present else None - ) - if self._additional_dataclass.lastUpdateTimeUnix: - additional_props["lastUpdateTimeUnix"] = ( - str(props.last_update_time_unix) if props.last_update_time_unix_present else None - ) - if self._additional_dataclass.score: - additional_props["score"] = props.score if props.score_present else None - if self._additional_dataclass.explainScore: - additional_props["explainScore"] = ( - props.explain_score if props.explain_score_present else None - ) - return additional_props - - def _convert_references_to_grpc_result( - self, properties: "search_get_pb2.PropertiesResult" - ) -> Dict: - result: Dict[str, Any] = {} - for name, non_ref_prop in properties.non_ref_properties.items(): - result[name] = non_ref_prop - - for ref_prop in properties.ref_props: - result[ref_prop.prop_name] = [ - self._convert_references_to_grpc_result(prop) for prop in ref_prop.properties - ] - - return result - - def _convert_references_to_grpc( - self, properties: Sequence[Union[LinkTo, str]] - ) -> "search_get_pb2.PropertiesRequest": - return search_get_pb2.PropertiesRequest( - non_ref_properties=[prop for prop in properties if isinstance(prop, str)], - ref_properties=[ - search_get_pb2.RefPropertiesRequest( - target_collection=prop.linked_class, - reference_property=prop.link_on, - properties=self._convert_references_to_grpc(prop.properties), - ) - for prop in properties - if isinstance(prop, LinkTo) - ], - ) - - def _additional_to_str(self) -> str: - """ - Convert `self._additional` attribute to a `str`. - - Returns - ------- - str - The converted self._additional. - """ - if self._additional_dataclass is not None: - return str(self._additional_dataclass) - - str_to_return = " _additional {" - - has_values = False - for one_level in sorted(self._additional["__one_level"]): - has_values = True - str_to_return += one_level + " " - - for key, values in sorted(self._additional.items(), key=lambda key_value: key_value[0]): - if key == "__one_level": - continue - has_values = True - str_to_return += key + " {" - for value in sorted(values): - str_to_return += value + " " - str_to_return += "} " - - if has_values is False: - return "" - return str_to_return + "}" - - def _tuple_to_dict(self, tuple_value: tuple) -> None: - """ - Convert the tuple data type argument to a dictionary. - - Parameters - ---------- - tuple_value : tuple - The tuple value as (clause: , settings: ). - - Raises - ------ - ValueError - If 'tuple_value' does not have exactly 2 elements. - TypeError - If the configuration of the 'tuple_value' is not correct. - """ - - if len(tuple_value) != 2: - raise ValueError( - "If type of 'properties' is `tuple` then it should have length 2: " - "(clause: , settings: )" - ) - - clause, settings = tuple_value - if not isinstance(clause, dict) or not isinstance(settings, dict): - raise TypeError( - "If type of 'properties' is `tuple` then it should have this data type: " - "(, )" - ) - if len(clause) != 1: - raise ValueError( - "If type of 'properties' is `tuple` then the 'clause' (first element) should " - f"have only one key. Given: {len(clause)}" - ) - if len(settings) == 0: - raise ValueError( - "If type of 'properties' is `tuple` then the 'settings' (second element) should " - f"have at least one key. Given: {len(settings)}" - ) - - clause_key, values = list(clause.items())[0] - - if not isinstance(clause_key, str): - raise TypeError( - "If type of 'properties' is `tuple` then first element's key should be of type " - "`str`!" - ) - - clause_with_settings = clause_key + "(" - try: - for key, value in sorted(settings.items(), key=lambda key_value: key_value[0]): - if not isinstance(key, str): - raise TypeError( - "If type of 'properties' is `tuple` then the second elements () " - "should have all the keys of type `str`!" - ) - clause_with_settings += key + ": " + dumps(value) + " " - except TypeError: - raise TypeError( - "If type of 'properties' is `tuple` then the second elements () " - "should have all the keys of type `str`!" - ) from None - clause_with_settings += ")" - - self._additional[clause_with_settings] = set() - if isinstance(values, str): - self._additional[clause_with_settings].add(values) - return - if not isinstance(values, list): - raise TypeError( - "If type of 'properties' is `tuple` then first element's dict values must be " - f"either of type `str` or `list` of `str`! Given: {type(values)}!" - ) - if len(values) == 0: - raise ValueError( - "If type of 'properties' is `tuple` and first element's dict value is of type " - "`list` then at least one element should be present!" - ) - for value in values: - if not isinstance(value, str): - raise TypeError( - "If type of 'properties' is `tuple` and first element's dict value is of type " - " `list` then all items must be of type `str`!" - ) - self._additional[clause_with_settings].add(value) diff --git a/weaviate/gql/multi_get.py b/weaviate/gql/multi_get.py deleted file mode 100644 index 2f7c375dc..000000000 --- a/weaviate/gql/multi_get.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -GraphQL `Get` command. -""" - -from typing import List -from weaviate.gql.filter import ( - GraphQL, -) -from weaviate.connect import Connection -from .get import GetBuilder - - -class MultiGetBuilder(GraphQL): - """ - GetBuilder class used to create GraphQL queries. - """ - - def __init__(self, get_builder: List[GetBuilder], connection: Connection): - """ - Initialize a MultiGetBuilder class instance. - - Parameters - ---------- - get_builder : list of GetBuilder - GetBuilder objects for a single request each. - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - - Examples - -------- - To create a 'multi_get' object using several 'get' request at the same time use: - - >>> client.query.multi_get( - ... [ - ... client.query.get("Ship", ["name"]).with_alias("one"), - ... client.query.get("Ship", ["size"]).with_alias("two"), - ... client.query.get("Person", ["name"]) - ... ] - with_alias() needs to be used if the same 'class_name' is used twice during the same 'multi_get' request. - - Raises - ------ - TypeError - If 'get_builder' is of wrong type. - """ - get_names = [] - super().__init__(connection) - if not isinstance(get_builder, List): - raise TypeError(f"get_builder must be of type List but was {type(get_builder)}") - for get in get_builder: - if not isinstance(get, GetBuilder): - raise TypeError( - f"All objects in 'get_builder' must be of type 'GetBuilder' but at least one object was {type(get)}" - ) - if get.name not in get_names: - get_names.append(get.name) - else: - raise TypeError( - f"Objects in 'get_builder' can not have duplicate names but two were named: '{get.name}'. Queries can be renamed using an alias." - ) - self.get_builder: List[GetBuilder] = get_builder - - def build(self) -> str: - """ - Build query filter as a string. - - Returns - ------- - str - The GraphQL query as a string. - """ - query = "{Get{" - - for get in self.get_builder: - query += get.build(wrap_get=False) - return query + "}}" diff --git a/weaviate/gql/query.py b/weaviate/gql/query.py deleted file mode 100644 index 2f0845018..000000000 --- a/weaviate/gql/query.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -GraphQL query module. -""" - -from typing import List, Any, Dict, Optional - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection -from .aggregate import AggregateBuilder -from .get import GetBuilder, PROPERTIES -from .multi_get import MultiGetBuilder -from ..util import _decode_json_response_dict - - -class Query: - """ - Query class used to make `get` and/or `aggregate` GraphQL queries. - """ - - def __init__(self, connection: Connection): - """ - Initialize a Classification class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - """ - - self._connection = connection - - def get( - self, - class_name: str, - properties: Optional[PROPERTIES] = None, - ) -> GetBuilder: - """ - Instantiate a GetBuilder for GraphQL `get` requests. - - Parameters - ---------- - class_name : str - Class name of the objects to interact with. - properties : list of str and ReferenceProperty, str or None - Properties of the objects to get, by default None - - Returns - ------- - GetBuilder - A GetBuilder to make GraphQL `get` requests from weaviate. - """ - return GetBuilder(class_name, properties, self._connection) - - def multi_get( - self, - get_builder: List[GetBuilder], - ) -> MultiGetBuilder: - """ - Instantiate a MultiGetBuilder for GraphQL `multi_get` requests. - Bundles multiple get requests into one. - - Parameters - ---------- - get_builder : list of GetBuilder - List of GetBuilder objects for a single request each. - - Returns - ------- - MultiGetBuilder - A MultiGetBuilder to make GraphQL `get` multiple requests from weaviate. - """ - - return MultiGetBuilder(get_builder, self._connection) - - def aggregate(self, class_name: str) -> AggregateBuilder: - """ - Instantiate an AggregateBuilder for GraphQL `aggregate` requests. - - Parameters - ---------- - class_name : str - Class name of the objects to be aggregated. - - Returns - ------- - AggregateBuilder - An AggregateBuilder to make GraphQL `aggregate` requests from weaviate. - """ - - return AggregateBuilder(class_name, self._connection) - - def raw(self, gql_query: str) -> Dict[str, Any]: - """ - Allows to send simple graph QL string queries. - Be cautious of injection risks when generating query strings. - - Parameters - ---------- - gql_query : str - GraphQL query as a string. - - Returns - ------- - dict - Data response of the query. - - Examples - -------- - >>> query = \""" - ... { - ... Get { - ... Article(limit: 2) { - ... title - ... hasAuthors { - ... ... on Author { - ... name - ... } - ... } - ... } - ... } - ... } - ... \""" - >>> client.query.raw(query) - { - "data": { - "Get": { - "Article": [ - { - "hasAuthors": [ - { - "name": "Jonathan Wilson" - } - ], - "title": "Sergio Ag\u00fcero has been far more than a great goalscorer for - Manchester City" - }, - { - "hasAuthors": [ - { - "name": "Emma Elwick-Bates" - } - ], - "title": "At Swarovski, Giovanna Engelbert Is Crafting Jewels As Exuberantly - Joyful As She Is" - } - ] - } - }, - "errors": null - } - - Raises - ------ - TypeError - If 'gql_query' is not of type str. - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.UnexpectedStatusCodeException - If weaviate reports a none OK status. - """ - - if not isinstance(gql_query, str): - raise TypeError("Query is expected to be a string") - - json_query = {"query": gql_query} - - try: - response = self._connection.post(path="/graphql", weaviate_object=json_query) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Query not executed.") from conn_err - - res = _decode_json_response_dict(response, "GQL query failed") - assert res is not None - return res diff --git a/weaviate/schema/__init__.py b/weaviate/schema/__init__.py deleted file mode 100644 index fdbd158fb..000000000 --- a/weaviate/schema/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Module used to manipulate schemas. -""" - -__all__ = ["Schema", "Tenant", "TenantActivityStatus"] - -from .crud_schema import Schema, Tenant, TenantActivityStatus diff --git a/weaviate/schema/crud_schema.py b/weaviate/schema/crud_schema.py deleted file mode 100644 index f544a28c5..000000000 --- a/weaviate/schema/crud_schema.py +++ /dev/null @@ -1,1059 +0,0 @@ -""" -Schema class definition. -""" - -from dataclasses import dataclass -from enum import Enum -from typing import Any, Union, Optional, List, Dict, cast - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection -from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.schema.properties import Property -from weaviate.util import ( - _get_dict_from_object, - _is_sub_schema, - _capitalize_first_letter, - _decode_json_response_dict, - _decode_json_response_list, -) - -CLASS_KEYS = { - "class", - "vectorIndexType", - "vectorIndexConfig", - "moduleConfig", - "description", - "vectorizer", - "properties", - "invertedIndexConfig", - "shardingConfig", - "replicationConfig", - "multiTenancyConfig", -} - -PROPERTY_KEYS = { - "dataType", - "name", - "moduleConfig", - "description", - "indexInverted", - "tokenization", - "indexFilterable", - "indexSearchable", -} - -_PRIMITIVE_WEAVIATE_TYPES_SET = { - "string", - "string[]", - "int", - "int[]", - "boolean", - "boolean[]", - "number", - "number[]", - "date", - "date[]", - "text", - "text[]", - "geoCoordinates", - "blob", - "phoneNumber", - "uuid", - "uuid[]", - "object", - "object[]", -} - - -class TenantActivityStatus(str, Enum): - """ - TenantActivityStatus class used to describe the activity status of a tenant in Weaviate. - - Attributes - ---------- - HOT: The tenant is fully active and can be used. - COLD: The tenant is not active, files stored locally. - """ - - HOT = "HOT" - COLD = "COLD" - - -@dataclass -class Tenant: - """ - Tenant class used to describe a tenant in Weaviate. - - Attributes - ---------- - activity_status : TenantActivityStatus, optional - default: "HOT" - name: the name of the tenant. - """ - - name: str - activity_status: TenantActivityStatus = TenantActivityStatus.HOT - - def _to_weaviate_object(self) -> Dict[str, str]: - return { - "activityStatus": self.activity_status, - "name": self.name, - } - - @classmethod - def _from_weaviate_object(cls, weaviate_object: Dict[str, Any]) -> "Tenant": - return cls( - name=weaviate_object["name"], - activity_status=TenantActivityStatus(weaviate_object.get("activityStatus", "HOT")), - ) - - -class Schema: - """ - Schema class used to interact and manipulate schemas or classes. - - Attributes - ---------- - property : weaviate.schema.properties.Property - A Property object to create new schema property/ies. - """ - - def __init__(self, connection: Connection): - """ - Initialize a Schema class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running Weaviate instance. - """ - - self._connection = connection - self.property = Property(self._connection) - - def create(self, schema: Union[dict, str]) -> None: - """ - Create the schema of the Weaviate instance, with all classes at once. - - Parameters - ---------- - schema : dict or str - Schema as a Python dict, or the path to a JSON file, or the URL of a JSON file. - - Examples - -------- - >>> article_class = { - ... "class": "Article", - ... "description": "An article written by an Author", - ... "properties": [ - ... { - ... "name": "title", - ... "dataType": ["text"], - ... "description": "The title the article", - ... }, - ... { - ... "name": "hasAuthors", - ... "dataType": ["Author"], - ... "description": "Authors this article has", - ... } - ... ] - ... } - >>> author_class = { - ... "class": "Author", - ... "description": "An Author class to store the author information", - ... "properties": [ - ... { - ... "name": "name", - ... "dataType": ["text"], - ... "description": "The name of the author", - ... }, - ... { - ... "name": "wroteArticles", - ... "dataType": ["Article"], - ... "description": "The articles of the author", - ... } - ... ] - ... } - >>> client.schema.create({"classes": [article_class, author_class]}) - - If you have your schema saved in the './schema/my_schema.json' you can create it - directly from the file. - - >>> client.schema.create('./schema/my_schema.json') - - Raises - ------ - TypeError - If the 'schema' is neither a string nor a dict. - ValueError - If 'schema' can not be converted into a Weaviate schema. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - weaviate.SchemaValidationException - If the 'schema' could not be validated against the standard format. - """ - - loaded_schema = _get_dict_from_object(schema) - self._create_classes_with_primitives(loaded_schema["classes"]) - self._create_complex_properties_from_classes(loaded_schema["classes"]) - - def create_class(self, schema_class: Union[dict, str]) -> None: - """ - Create a single class as part of the schema in Weaviate. - - Parameters - ---------- - schema_class : dict or str - Class as a Python dict, or the path to a JSON file, or the URL of a JSON file. - - Examples - -------- - >>> author_class_schema = { - ... "class": "Author", - ... "description": "An Author class to store the author information", - ... "properties": [ - ... { - ... "name": "name", - ... "dataType": ["text"], - ... "description": "The name of the author", - ... }, - ... { - ... "name": "wroteArticles", - ... "dataType": ["Article"], - ... "description": "The articles of the author", - ... } - ... ] - ... } - >>> client.schema.create_class(author_class_schema) - - If you have your class schema saved in the './schema/my_schema.json' you can create it - directly from the file. - - >>> client.schema.create_class('./schema/my_schema.json') - - Raises - ------ - TypeError - If the 'schema_class' is neither a string nor a dict. - ValueError - If 'schema_class' can not be converted into a Weaviate schema. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - weaviate.SchemaValidationException - If the 'schema_class' could not be validated against the standard format. - """ - - loaded_schema_class = _get_dict_from_object(schema_class) - self._create_class_with_primitives(loaded_schema_class) - self._create_complex_properties_from_class(loaded_schema_class) - - def delete_class(self, class_name: str) -> None: - """ - Delete a schema class from Weaviate. This deletes all associated data. - - Parameters - ---------- - class_name : str - The class that should be deleted from Weaviate. - - Examples - -------- - >>> client.schema.delete_class('Author') - - Raises - ------ - TypeError - If 'class_name' argument not of type str. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - - if not isinstance(class_name, str): - raise TypeError(f"Class name was {type(class_name)} instead of str") - - path = f"/schema/{_capitalize_first_letter(class_name)}" - try: - response = self._connection.delete(path=path) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Deletion of class.") from conn_err - if response.status_code != 200: - raise UnexpectedStatusCodeException("Delete class from schema", response) - - def delete_all(self) -> None: - """ - Remove the entire schema from the Weaviate instance and all data associated with it. - - Examples - -------- - >>> client.schema.delete_all() - """ - - schema = self.get() - classes = schema.get("classes", []) - for _class in classes: - self.delete_class(_class["class"]) - - def exists(self, class_name: str) -> bool: - """ - Check if class exists in Weaviate. - - Parameters - ---------- - class_name : str - The class whose existence is being checked. - - Examples - -------- - >>> client.schema.exists(class_name="Exists") - True - - >>> client.schema.exists(class_name="DoesNotExists") - False - - Returns - ------- - bool - True if the class exists, - False otherwise. - """ - - if not isinstance(class_name, str): - raise TypeError( - f"'class_name' argument must be of type `str`! Given type: {type(class_name)}." - ) - - path = f"/schema/{_capitalize_first_letter(class_name)}" - - try: - response = self._connection.get(path=path) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Checking class existence could not be done." - ) from conn_err - if response.status_code == 200: - return True - elif response.status_code == 404: - return False - - raise UnexpectedStatusCodeException("Check if class exists", response) - - def contains(self, schema: Optional[Union[dict, str]] = None) -> bool: - """ - Check if Weaviate already contains a schema. - - Parameters - ---------- - schema : dict or str, optional - Schema as a Python dict, or the path to a JSON file or the URL of a JSON file. - If a schema is given it is checked if this specific schema is already loaded. - It will test only this schema. If the given schema is a subset of the loaded - schema it will still return true, by default None. - - Examples - -------- - >>> schema = client.schema.get() - >>> client.schema.contains(schema) - True - >>> schema = client.schema.get() - >>> schema['classes'].append( - { - "class": "Animal", - "description": "An Animal", - "properties": [ - { - "name": "type", - "dataType": ["text"], - "description": "The animal type", - } - ] - } - ) - >>> client.schema.contains(schema) - False - - Returns - ------- - bool - True if a schema is present, - False otherwise. - """ - - loaded_schema = self.get() - - if schema is not None: - sub_schema = _get_dict_from_object(schema) - return _is_sub_schema(sub_schema, loaded_schema) - - if len(loaded_schema["classes"]) == 0: - return False - return True - - def update_config(self, class_name: str, config: dict) -> None: - """ - Update a schema configuration for a specific class. - - Parameters - ---------- - class_name : str - The class for which to update the schema configuration. - config : dict - The configurations to update (MUST follow schema format). - - Example - ------- - In the example below we have a Weaviate instance with a class 'Test'. - - >>> client.schema.get('Test') - { - 'class': 'Test', - ... - 'vectorIndexConfig': { - 'ef': -1, - ... - }, - ... - } - >>> client.schema.update_config( - ... class_name='Test', - ... config={ - ... 'vectorIndexConfig': { - ... 'ef': 100, - ... } - ... } - ... ) - >>> client.schema.get('Test') - { - 'class': 'Test', - ... - 'vectorIndexConfig': { - 'ef': 100, - ... - }, - ... - } - - NOTE: When updating schema configuration, the 'config' MUST be sub-set of the schema, - starting at the top level. In the example above we update 'ef' value, and for this we - included the 'vectorIndexConfig' top level too. - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - - class_name = _capitalize_first_letter(class_name) - class_schema = self.get(class_name) - new_class_schema = _update_nested_dict(class_schema, config) - - path = "/schema/" + class_name - try: - response = self._connection.put(path=path, weaviate_object=new_class_schema) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Class schema configuration could not be updated." - ) from conn_err - if response.status_code != 200: - raise UnexpectedStatusCodeException("Update class schema configuration", response) - - def get(self, class_name: Optional[str] = None) -> dict: - """ - Get the schema from Weaviate. - - Parameters - ---------- - class_name : str, optional - The class for which to return the schema. If NOT provided the whole schema is returned, - otherwise only the schema of this class is returned. By default None. - - Returns - ------- - dict - A dict containing the schema. The schema may be empty. - To see if a schema has already been loaded, use the `contains` method. - - Examples - -------- - No schema present in client - - >>> client.schema.get() - {'classes': []} - - Schema present in client - - >>> client.schema.get() - { - "classes": [ - { - "class": "Animal", - "description": "An Animal", - "invertedIndexConfig": { - "cleanupIntervalSeconds": 60 - }, - "properties": [ - { - "dataType": ["text"], - "description": "The animal type", - "name": "type" - } - ], - "vectorIndexConfig": { - "cleanupIntervalSeconds": 300, - "maxConnections": 64, - "efConstruction": 128, - "vectorCacheMaxObjects": 500000 - }, - "vectorIndexType": "hnsw", - "vectorizer": "text2vec-contextionary", - "replicationConfig": { - "factor": 1, - } - } - ] - } - - >>> client.schema.get('Animal') - { - "class": "Animal", - "description": "An Animal", - "invertedIndexConfig": { - "cleanupIntervalSeconds": 60 - }, - "properties": [ - { - "dataType": ["text"], - "description": "The animal type", - "name": "type" - } - ], - "vectorIndexConfig": { - "cleanupIntervalSeconds": 300, - "maxConnections": 64, - "efConstruction": 128, - "vectorCacheMaxObjects": 500000 - }, - "vectorIndexType": "hnsw", - "vectorizer": "text2vec-contextionary", - "replicationConfig": { - "factor": 1, - } - } - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - - path = "/schema" - if class_name is not None: - if not isinstance(class_name, str): - raise TypeError( - "'class_name' argument must be of type `str`! " - f"Given type: {type(class_name)}" - ) - path = f"/schema/{_capitalize_first_letter(class_name)}" - - try: - response = self._connection.get(path=path) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Schema could not be retrieved.") from conn_err - - res = _decode_json_response_dict(response, "Get schema") - assert res is not None - return res - - def get_class_shards(self, class_name: str) -> list: - """ - Get the status of all shards in an index. - - Parameters - ---------- - class_name : str - The class for which to return the status of all shards in an index. - - Returns - ------- - list - The list of shards configuration. - - Examples - -------- - Schema contains a single class: Article - - >>> client.schema.get_class_shards('Article') - [{'name': '2rPgsA2yngW3', 'status': 'READY'}] - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - - if not isinstance(class_name, str): - raise TypeError( - "'class_name' argument must be of type `str`! " f"Given type: {type(class_name)}." - ) - path = f"/schema/{_capitalize_first_letter(class_name)}/shards" - - try: - response = self._connection.get(path=path) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Class shards' status could not be retrieved due to connection error." - ) from conn_err - - res = _decode_json_response_list(response, "Get shards' status") - assert res is not None - return res - - def update_class_shard( - self, - class_name: str, - status: str, - shard_name: Optional[str] = None, - ) -> list: - """ - Get the status of all shards in an index. - - Parameters - ---------- - class_name : str - The class for which to update the status of all shards in an index. - status : str - The new status of the shard. The available options are: 'READY' and 'READONLY'. - shard_name : str or None, optional - The shard name for which to update the status of the class of the shard. If None then - all the shards are going to be updated to the 'status'. By default None. - - Returns - ------- - list - The updated statuses. - - Examples - -------- - Schema contains a single class: Article - - >>> client.schema.get_class_shards('Article') - [{'name': 'node1', 'status': 'READY'}, {'name': 'node2', 'status': 'READY'}] - - For a specific shard: - - >>> client.schema.update_class_shard('Article', 'READONLY', 'node2') - {'status': 'READONLY'} - >>> client.schema.get_class_shards('Article') - [{'name': 'node1', 'status': 'READY'}, {'name': 'node2', 'status': 'READONLY'}] - - For all shards of the class: - - >>> client.schema.update_class_shard('Article', 'READONLY') - [{'status': 'READONLY'},{'status': 'READONLY'}] - >>> client.schema.get_class_shards('Article') - [{'name': 'node1', 'status': 'READONLY'}, {'name': 'node2', 'status': 'READONLY'}] - - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - - if not isinstance(class_name, str): - raise TypeError( - "'class_name' argument must be of type `str`! " f"Given type: {type(class_name)}." - ) - if not isinstance(shard_name, str) and shard_name is not None: - raise TypeError( - "'shard_name' argument must be of type `str`! " f"Given type: {type(shard_name)}." - ) - if not isinstance(status, str): - raise TypeError( - "'status' argument must be of type `str`! " f"Given type: {type(status)}." - ) - - if shard_name is None: - shards_config = self.get_class_shards( - class_name=class_name, - ) - shard_names = [shard_config["name"] for shard_config in shards_config] - else: - shard_names = [shard_name] - - data = {"status": status} - - to_return = [] - - for _shard_name in shard_names: - path = f"/schema/{_capitalize_first_letter(class_name)}/shards/{_shard_name}" - try: - response = self._connection.put( - path=path, - weaviate_object=data, - ) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - f"Class shards' status could not be updated for shard '{_shard_name}' due to " - "connection error." - ) from conn_err - - to_return.append( - _decode_json_response_dict(response, f"Update shard '{_shard_name}' status") - ) - - if shard_name is None: - return to_return - return cast(list, to_return[0]) - - def _create_complex_properties_from_class(self, schema_class: dict) -> None: - """ - Add cross-references to an already existing class. - - Parameters - ---------- - schema_class : dict - Description of the class that should be added. - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - - if "properties" not in schema_class: - # Class has no properties - nothing to do - return - for property_ in schema_class["properties"]: - if _property_is_primitive(property_["dataType"]): - continue - - # Create the property object. All complex dataTypes should be capitalized. - schema_property = { - "dataType": [_capitalize_first_letter(dtype) for dtype in property_["dataType"]], - "name": property_["name"], - } - - for property_field in PROPERTY_KEYS - {"name", "dataType"}: - if property_field in property_: - schema_property[property_field] = property_[property_field] - - path = "/schema/" + _capitalize_first_letter(schema_class["class"]) + "/properties" - try: - response = self._connection.post(path=path, weaviate_object=schema_property) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Property may not have been created properly." - ) from conn_err - if response.status_code != 200: - raise UnexpectedStatusCodeException("Add properties to classes", response) - - def _create_complex_properties_from_classes(self, schema_classes_list: list) -> None: - """ - Add cross-references to already existing classes. - - Parameters - ---------- - schema_classes_list : list - A list of classes as they are found in a schema JSON description. - """ - - for schema_class in schema_classes_list: - self._create_complex_properties_from_class(schema_class) - - def _create_class_with_primitives(self, weaviate_class: dict) -> None: - """ - Create class with only primitives. - - Parameters - ---------- - weaviate_class : dict - A single Weaviate formatted class - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - - # Create the class - schema_class = { - "class": _capitalize_first_letter(weaviate_class["class"]), - "properties": [], - } - - for class_field in CLASS_KEYS - {"class", "properties"}: - if class_field in weaviate_class: - schema_class[class_field] = weaviate_class[class_field] - - if "properties" in weaviate_class: - schema_class["properties"] = _get_primitive_properties(weaviate_class["properties"]) - - # Add the item - try: - response = self._connection.post(path="/schema", weaviate_object=schema_class) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Class may not have been created properly.") from conn_err - if response.status_code != 200: - raise UnexpectedStatusCodeException("Create class", response) - - def _create_classes_with_primitives(self, schema_classes_list: list) -> None: - """ - Create all the classes in the list and primitive properties. - This function does not create references, - to avoid references to classes that do not yet exist. - - Parameters - ---------- - schema_classes_list : list - A list of classes as they are found in a schema JSON description. - """ - - for weaviate_class in schema_classes_list: - self._create_class_with_primitives(weaviate_class) - - def add_class_tenants(self, class_name: str, tenants: List[Tenant]) -> None: - """ - Add class's tenants in Weaviate. - - Parameters - ---------- - class_name : str - The class for which we add tenants. - tenants : List[Tenant] - List of Tenants. - - Examples - -------- - >>> tenants = [ Tenant(name="Tenant1"), Tenant(name="Tenant2") ] - >>> client.schema.add_class_tenants("class_name", tenants) - - Raises - ------ - TypeError - If 'tenants' has not the correct type. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - - loaded_tenants = [tenant._to_weaviate_object() for tenant in tenants] - - path = f"/schema/{_capitalize_first_letter(class_name)}/tenants" - try: - response = self._connection.post(path=path, weaviate_object=loaded_tenants) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Classes tenants may not have been added properly." - ) from conn_err - if response.status_code != 200: - raise UnexpectedStatusCodeException("Add classes tenants", response) - - def remove_class_tenants(self, class_name: str, tenants: List[str]) -> None: - """ - Remove class's tenants in Weaviate. - - Parameters - ---------- - class_name : str - The class for which we remove tenants. - tenants : List[str] - List of tenant names to remove from the given class. - - Examples - -------- - >>> client.schema.remove_class_tenants("class_name", ["Tenant1", "Tenant2"]) - - Raises - ------ - TypeError - If 'tenants' has not the correct type. - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - path = f"/schema/{_capitalize_first_letter(class_name)}/tenants" - try: - response = self._connection.delete(path=path, weaviate_object=tenants) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError( - "Classes tenants may not have been deleted." - ) from conn_err - if response.status_code != 200: - raise UnexpectedStatusCodeException("Delete classes tenants", response) - - def get_class_tenants(self, class_name: str) -> List[Tenant]: - """Get class's tenants in Weaviate. - - Parameters - ---------- - class_name : str - The class for which we get tenants. - - Examples - -------- - >>> client.schema.get_class_tenants("class_name") - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - path = f"/schema/{_capitalize_first_letter(class_name)}/tenants" - try: - response = self._connection.get(path=path) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Could not get class tenants.") from conn_err - - tenant_resp = _decode_json_response_list(response, "Get class tenants") - assert tenant_resp is not None - return [Tenant._from_weaviate_object(tenant) for tenant in tenant_resp] - - def update_class_tenants(self, class_name: str, tenants: List[Tenant]) -> None: - """Update class tenants. - - Use this when you want to move tenants from one activity state to another. - - Parameters - ---------- - class_name : str - The class for which we update tenants. - tenants : List[Tenant] - List of Tenants. - - Examples - -------- - >>> client.schema.add_class_tenants( - "class_name", - [ - Tenant(activity_status=TenantActivityStatus.HOT, name="Tenant1")), - Tenant(activity_status=TenantActivityStatus.COLD, name="Tenant2")) - Tenant(name="Tenant3") - ] - ) - >>> client.schema.update_class_tenants( - "class_name", - [ - Tenant(activity_status=TenantActivityStatus.COLD, name="Tenant1")), - Tenant(activity_status=TenantActivityStatus.HOT, name="Tenant2")) - ] - ) - >>> client.schema.get_class_tenants("class_name") - [ - Tenant(activity_status=TenantActivityStatus.COLD, name="Tenant1")), - Tenant(activity_status=TenantActivityStatus.HOT, name="Tenant2")), - Tenant(activity_status=TenantActivityStatus.HOT, name="Tenant3")) - ] - - - Raises - ------ - requests.ConnectionError - If the network connection to Weaviate fails. - weaviate.UnexpectedStatusCodeException - If Weaviate reports a non-OK status. - """ - path = f"/schema/{_capitalize_first_letter(class_name)}/tenants" - loaded_tenants = [tenant._to_weaviate_object() for tenant in tenants] - try: - response = self._connection.put(path=path, weaviate_object=loaded_tenants) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Could not update class tenants.") from conn_err - if response.status_code != 200: - raise UnexpectedStatusCodeException("Update classes tenants", response) - - -def _property_is_primitive(data_type_list: list) -> bool: - """ - Check if the property is primitive. - - Parameters - ---------- - data_type_list : list - Data types to be checked if are primitive. - - Returns - ------- - bool - True if it only consists of primitive data types, - False otherwise. - """ - - if len(set(data_type_list) - _PRIMITIVE_WEAVIATE_TYPES_SET) == 0: - return True - return False - - -def _get_primitive_properties(properties_list: list) -> list: - """ - Filter the list of properties for only primitive properties. - - Parameters - ---------- - properties_list : list - A list of properties to extract the primitive properties. - - Returns - ------- - list - A list of properties containing only primitives. - """ - - primitive_properties = [] - for property_ in properties_list: - if not _property_is_primitive(property_["dataType"]): - # property is complex and therefore will be ignored - continue - primitive_properties.append(property_) - return primitive_properties - - -def _update_nested_dict(dict_1: dict, dict_2: dict) -> dict: - """ - Update `dict_1` with elements from `dict_2` in a nested manner. - If a value of a key is a dict, it is going to be updated and not replaced by the whole dict. - - Parameters - ---------- - dict_1 : dict - The dictionary to be updated. - dict_2 : dict - The dictionary that contains values to be updated. - - Returns - ------- - dict - The updated `dict_1`. - """ - for key, value in dict_2.items(): - if key not in dict_1: - dict_1[key] = value - continue - if isinstance(value, dict): - _update_nested_dict(dict_1[key], value) - else: - dict_1.update({key: value}) - return dict_1 diff --git a/weaviate/schema/properties/__init__.py b/weaviate/schema/properties/__init__.py deleted file mode 100644 index e4fcf0316..000000000 --- a/weaviate/schema/properties/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Module used to manipulate schema properties. -""" - -__all__ = ["Property"] - -from .crud_properties import Property diff --git a/weaviate/schema/properties/crud_properties.py b/weaviate/schema/properties/crud_properties.py deleted file mode 100644 index 2d5f488fb..000000000 --- a/weaviate/schema/properties/crud_properties.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -Property class definition. -""" - -from requests.exceptions import ConnectionError as RequestsConnectionError - -from weaviate.connect import Connection -from weaviate.exceptions import UnexpectedStatusCodeException -from weaviate.util import _get_dict_from_object, _capitalize_first_letter - - -class Property: - """ - Property class used to create object properties. - """ - - def __init__(self, connection: Connection): - """ - Initialize a Property class instance. - - Parameters - ---------- - connection : weaviate.connect.Connection - Connection object to an active and running weaviate instance. - """ - - self._connection = connection - - def create(self, schema_class_name: str, schema_property: dict) -> None: - """ - Create a class property. - - Parameters - ---------- - schema_class_name : str - The name of the class in the schema to which the property - should be added. - schema_property : dict - The property that should be added. - - Examples - -------- - >>> property_age = { - ... "dataType": [ - ... "int" - ... ], - ... "description": "The Author's age", - ... "name": "age" - ... } - >>> client.schema.property.create('Author', property_age) - - Raises - ------ - TypeError - If 'schema_class_name' is of wrong type. - weaviate.exceptions.UnexpectedStatusCodeException - If weaviate reports a none OK status. - requests.ConnectionError - If the network connection to weaviate fails. - weaviate.SchemaValidationException - If the 'schema_property' is not valid. - """ - - if not isinstance(schema_class_name, str): - raise TypeError(f"Class name must be of type str but is {type(schema_class_name)}") - - loaded_schema_property = _get_dict_from_object(schema_property) - - schema_class_name = _capitalize_first_letter(schema_class_name) - - path = f"/schema/{schema_class_name}/properties" - try: - response = self._connection.post(path=path, weaviate_object=loaded_schema_property) - except RequestsConnectionError as conn_err: - raise RequestsConnectionError("Property was created properly.") from conn_err - if response.status_code != 200: - raise UnexpectedStatusCodeException("Add property to class", response) diff --git a/weaviate/util.py b/weaviate/util.py index 12db74ebc..47b40ba40 100644 --- a/weaviate/util.py +++ b/weaviate/util.py @@ -5,7 +5,6 @@ import base64 import datetime import io -import json import os import re import uuid as uuid_lib @@ -184,102 +183,6 @@ def file_decoder_b64(encoded_file: str) -> bytes: return base64.b64decode(encoded_file.encode("utf-8")) -def generate_local_beacon( - to_uuid: Union[str, uuid_lib.UUID], - class_name: Optional[str] = None, -) -> dict: - """ - Generates a beacon with the given uuid and class name (only for Weaviate >= 1.14.0). - - Parameters - ---------- - to_uuid : str or uuid.UUID - The UUID for which to create a local beacon. - class_name : Optional[str], optional - The class name of the `to_uuid` object. Used with Weaviate >= 1.14.0. - For Weaviate < 1.14.0 use None value. - - Returns - ------- - dict - The local beacon. - - Raises - ------ - TypeError - If 'to_uuid' is not of type str. - ValueError - If the 'to_uuid' is not valid. - """ - - if isinstance(to_uuid, str): - try: - uuid = str(uuid_lib.UUID(to_uuid)) - except ValueError: - raise ValueError("Uuid does not have the proper form") from None - elif isinstance(to_uuid, uuid_lib.UUID): - uuid = str(to_uuid) - else: - raise TypeError("Expected to_object_uuid of type str or uuid.UUID") - - if class_name is None: - return {"beacon": f"weaviate://localhost/{uuid}"} # noqa: E231 - return { - "beacon": f"weaviate://localhost/{_capitalize_first_letter(class_name)}/{uuid}" # noqa: E231 - } - - -def _get_dict_from_object(object_: Union[str, dict]) -> dict: - """ - Takes an object that should describe a dict - e.g. a schema or an object and tries to retrieve the dict. - - Parameters - ---------- - object_ : str or dict - The object from which to retrieve the dict. - Can be a python dict, or the path to a json file or a url of a json file. - - Returns - ------- - dict - The object as a dict. - - Raises - ------ - TypeError - If 'object_' is neither a string nor a dict. - ValueError - If no dict can be retrieved from object. - """ - - # check if objects files is url - if object_ is None: - raise TypeError("argument is None") - - if isinstance(object_, dict): - # Object is already a dict - return object_ - if isinstance(object_, str): - if validators.url(object_): - # Object is URL - response = requests.get(object_) - if response.status_code == 200: - return cast(dict, response.json()) - raise ValueError("Could not download file " + object_) - - if not os.path.isfile(object_): - # Object is neither file nor URL - raise ValueError("No file found at location " + object_) - # Object is file - with open(object_, "r") as file: - return cast(dict, json.load(file)) - raise TypeError( - "Argument is not of the supported types. Supported types are " - "url or file path as string or schema as dict." - ) - - def is_weaviate_object_url(url: str) -> bool: """ Checks if the input follows a normal Weaviate 'beacon' like this: diff --git a/weaviate/warnings.py b/weaviate/warnings.py index 5956f112a..99e949645 100644 --- a/weaviate/warnings.py +++ b/weaviate/warnings.py @@ -181,21 +181,6 @@ def palm_to_google_gen() -> None: stacklevel=1, ) - @staticmethod - def weaviate_v3_client_is_deprecated() -> None: - warnings.warn( - message="""Dep016: Python client v3 `weaviate.Client(...)` connections and methods are deprecated and will - be removed by 2024-11-30. - - Upgrade your code to use Python client v4 `weaviate.WeaviateClient` connections and methods. - - For Python Client v4 usage, see: https://weaviate.io/developers/weaviate/client-libraries/python - - For code migration, see: https://weaviate.io/developers/weaviate/client-libraries/python/v3_v4_migration - - If you have to use v3 code, install the v3 client and pin the v3 dependency in your requirements file: `weaviate-client>=3.26.7;<4.0.0`""", - category=DeprecationWarning, - stacklevel=1, - ) - @staticmethod def vector_index_config_in_config_update() -> None: warnings.warn(