diff --git a/.flake8 b/.flake8 index 29aabc67c..877c7cb55 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ [flake8] max-line-length = 100 exclude = .git, venv, .venv, .pytest_cache, dist, .idea, docs/conf.py, weaviate/collections/orm.py, weaviate/collections/classes/orm.py, weaviate/proto/**/*.py -ignore = D100, D104, D105, D107, E203, E266, E501, E731, W503 +ignore = D100, D104, D105, D107, E203, E266, E501, E704, E731, W503 per-file-ignores = weaviate/cluster/types.py:A005 weaviate/collections/classes/types.py:A005 @@ -14,4 +14,5 @@ per-file-ignores = # D104: Missing docstring in public package # D105: Missing docstring in magic method # D107: Missing docstring in __init__ +# E704: Multiple statements on one line (def) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 6a45bd895..9b0225a2e 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -58,7 +58,7 @@ jobs: cache: 'pip' # caching pip dependencies - run: pip install -r requirements-devel.txt - name: Run mypy - run: mypy --warn-unused-ignores --python-version ${{matrix.version}} ${{ matrix.folder }} + run: mypy --config-file ./pyproject.toml --warn-unused-ignores --python-version ${{matrix.version}} ${{ matrix.folder }} - uses: jakebailey/pyright-action@v2 with: version: 1.1.347 @@ -321,17 +321,17 @@ jobs: $WEAVIATE_126 ] steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Download build artifact to append to release uses: actions/download-artifact@v4 with: name: weaviate-python-client-wheel - run: | pip install weaviate_client-*.whl - pip install pytest pytest-asyncio pytest-benchmark pytest-profiling grpcio grpcio-tools pytest-xdist - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 + pip install -r requirements-devel.txt # install test dependencies - name: free space run: sudo rm -rf /usr/local/lib/android - run: rm -r weaviate diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 
1be85dfd5..9c9510262 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: no-commit-to-branch - id: trailing-whitespace @@ -19,7 +19,7 @@ repos: - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.1.0 hooks: - id: flake8 name: linting diff --git a/docs/changelog.rst b/docs/changelog.rst index 5702c17ca..b87c5e466 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,23 @@ Changelog ========= +Version 4.6.6 +-------------- + +This patch version includes: + +- Log batch errors +- Only the last 100k successfully added UUIDs are kept in memory to prevent OOM situations. +- Fix tenant creation with string input + +In the v3 copy that is part of v4: + +- Fixes GraphQL query injection vulnerability caused by incorrect escaping of backslashes in plain text input builder methods. Many thanks to `@adamleko <https://github.com/adamleko>`_, `@bismuthsalamander <https://github.com/bismuthsalamander>`_, and `@tardigrade-9 <https://github.com/tardigrade-9>`_ for their help in fixing this issue +- Fixes batch retry with tenants + + + + Version 4.6.5 -------------- @@ -529,6 +546,19 @@ This beta version includes: - No more builder methods or raw dictionaries - Join the discussion and contribute your feedback `here `_ +Version 3.26.5 +-------------- +This patch version includes + +- Fixes GraphQL query injection vulnerability caused by incorrect escaping of backslashes in plain text input builder methods +- Many thanks to `@adamleko <https://github.com/adamleko>`_, `@bismuthsalamander <https://github.com/bismuthsalamander>`_, and `@tardigrade-9 <https://github.com/tardigrade-9>`_ for their help in fixing this issue + +Version 3.26.4 +-------------- +This patch version includes + +- Fixes batch retry with tenants + Version 3.26.2 -------------- This patch version includes diff --git a/integration/test_collection.py b/integration/test_collection.py index 29efa942a..ad236e494 100644 --- a/integration/test_collection.py +++ b/integration/test_collection.py @@ -7,6 +7,7 @@ import pytest +import weaviate.classes as wvc 
from integration.conftest import CollectionFactory, CollectionFactoryGet, _sanitize_collection_name from integration.constants import WEAVIATE_LOGO_OLD_ENCODED, WEAVIATE_LOGO_NEW_ENCODED from weaviate.collections.classes.batch import ErrorObject @@ -51,8 +52,6 @@ ) from weaviate.types import UUID, UUIDS -import weaviate.classes as wvc - UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9") UUID2 = uuid.UUID("8ad0d33c-8db1-4437-87f3-72161ca2a51a") UUID3 = uuid.UUID("83d99755-9deb-4b16-8431-d1dff4ab0a75") @@ -863,107 +862,6 @@ def test_query_properties(collection_factory: CollectionFactory) -> None: assert len(objects) == 0 -def test_near_vector(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_banana = collection.data.insert({"Name": "Banana"}) - collection.data.insert({"Name": "Fruit"}) - collection.data.insert({"Name": "car"}) - collection.data.insert({"Name": "Mountain"}) - - banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) - - full_objects = collection.query.near_vector( - banana.vector["default"], return_metadata=MetadataQuery(distance=True, certainty=True) - ).objects - assert len(full_objects) == 4 - - objects_distance = collection.query.near_vector( - banana.vector["default"], distance=full_objects[2].metadata.distance - ).objects - assert len(objects_distance) == 3 - - objects_distance = collection.query.near_vector( - banana.vector["default"], certainty=full_objects[2].metadata.certainty - ).objects - assert len(objects_distance) == 3 - - -def test_near_vector_limit(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) 
- uuid_banana = collection.data.insert({"Name": "Banana"}) - collection.data.insert({"Name": "Fruit"}) - collection.data.insert({"Name": "car"}) - collection.data.insert({"Name": "Mountain"}) - - banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) - - objs = collection.query.near_vector(banana.vector["default"], limit=2).objects - assert len(objs) == 2 - - -def test_near_vector_offset(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_banana = collection.data.insert({"Name": "Banana"}) - uuid_fruit = collection.data.insert({"Name": "Fruit"}) - collection.data.insert({"Name": "car"}) - collection.data.insert({"Name": "Mountain"}) - - banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) - - objs = collection.query.near_vector(banana.vector["default"], offset=1).objects - assert len(objs) == 3 - assert objs[0].uuid == uuid_fruit - - -def test_near_vector_group_by_argument(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[ - Property(name="Name", data_type=DataType.TEXT), - Property(name="Count", data_type=DataType.INT), - ], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_banana1 = collection.data.insert({"Name": "Banana", "Count": 51}) - collection.data.insert({"Name": "Banana", "Count": 72}) - collection.data.insert({"Name": "car", "Count": 12}) - collection.data.insert({"Name": "Mountain", "Count": 1}) - - banana1 = collection.query.fetch_object_by_id(uuid_banana1, include_vector=True) - - ret = collection.query.near_vector( - banana1.vector["default"], - group_by=GroupBy( - prop="name", - number_of_groups=4, - objects_per_group=10, - ), - return_metadata=MetadataQuery(distance=True, 
certainty=True), - ) - - assert len(ret.objects) == 4 - assert ret.objects[0].belongs_to_group == "Banana" - assert ret.objects[1].belongs_to_group == "Banana" - assert ret.objects[2].belongs_to_group == "car" - assert ret.objects[3].belongs_to_group == "Mountain" - - def test_near_object(collection_factory: CollectionFactory) -> None: collection = collection_factory( properties=[Property(name="Name", data_type=DataType.TEXT)], diff --git a/integration/test_collection_config.py b/integration/test_collection_config.py index 8507b092b..356061595 100644 --- a/integration/test_collection_config.py +++ b/integration/test_collection_config.py @@ -1,6 +1,7 @@ from typing import Generator -import pytest +import pytest as pytest +from _pytest.fixtures import SubRequest import weaviate from integration.conftest import OpenAICollection, CollectionFactory @@ -898,3 +899,134 @@ def test_dynamic_collection(collection_factory: CollectionFactory) -> None: assert config.vector_index_config.flat.vector_cache_max_objects == 9876 assert isinstance(config.vector_index_config.flat.quantizer, _BQConfig) assert config.vector_index_config.flat.quantizer.rescore_limit == 11 + + +def test_config_unknown_module(request: SubRequest) -> None: + with weaviate.connect_to_local() as client: + collection_name = _sanitize_collection_name(request.node.name) + client.collections.delete(name=collection_name) + collection = client.collections.create_from_dict( + { + "class": collection_name, + "vectorizer": "none", + "moduleConfig": {"generative-dummy": {}, "reranker-dummy": {}}, + "properties": [ + {"name": "prop", "dataType": ["text"]}, + ], + } + ) + config = collection.config.get() + assert config.generative_config is not None + assert isinstance(config.generative_config.generative, str) + assert config.generative_config.generative == "generative-dummy" + + assert config.reranker_config is not None + assert isinstance(config.reranker_config.reranker, str) + assert config.reranker_config.reranker == 
"reranker-dummy" + + client.collections.delete(name=collection_name) + + collection2 = client.collections.create_from_config(config) + config2 = collection2.config.get() + assert config == config2 + assert config2.generative_config is not None + assert isinstance(config2.generative_config.generative, str) + assert config2.generative_config.generative == "generative-dummy" + + assert config2.reranker_config is not None + assert isinstance(config2.reranker_config.reranker, str) + assert config2.reranker_config.reranker == "reranker-dummy" + + client.collections.delete(name=collection_name) + + +def test_create_custom_module(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + generative_config=Configure.Generative.custom( + "generative-anyscale", module_config={"temperature": 0.5} + ) + ) + config = collection.config.get() + + collection2 = collection_factory( + generative_config=Configure.Generative.anyscale(temperature=0.5) + ) + config2 = collection2.config.get() + + assert config.generative_config == config2.generative_config + assert isinstance(config.generative_config.generative, str) + assert config.generative_config.generative == "generative-anyscale" + assert config.generative_config.model == {"temperature": 0.5} + + +def test_create_custom_reranker(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + reranker_config=Configure.Reranker.custom( + "reranker-cohere", module_config={"model": "rerank-english-v2.0"} + ) + ) + config = collection.config.get() + + collection2 = collection_factory( + reranker_config=Configure.Reranker.cohere(model="rerank-english-v2.0") + ) + config2 = collection2.config.get() + + assert config.reranker_config == config2.reranker_config + assert isinstance(config.reranker_config.reranker, str) + assert config.reranker_config.reranker == "reranker-cohere" + assert config.reranker_config.model == {"model": "rerank-english-v2.0"} + + +def 
test_create_custom_vectorizer(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.custom( + "text2vec-contextionary", module_config={"vectorizeClassName": False} + ), + ) + config = collection.config.get() + + collection2 = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + config2 = collection2.config.get() + + assert config.vectorizer_config == config2.vectorizer_config + assert isinstance(config.vectorizer_config.vectorizer, str) + assert config.vectorizer_config.vectorizer == "text2vec-contextionary" + assert not config.vectorizer_config.vectorize_collection_name + + +def test_create_custom_vectorizer_named(collection_factory: CollectionFactory) -> None: + collection_dummy = collection_factory("dummy") + if collection_dummy._connection._weaviate_version.is_lower_than(1, 24, 0): + pytest.skip("Named index is not supported in Weaviate versions lower than 1.24.0") + + collection = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=[ + Configure.NamedVectors.custom( + "name", + module_name="text2vec-contextionary", + module_config={"vectorizeClassName": False}, + ) + ], + ) + config = collection.config.get() + + collection2 = collection_factory( + properties=[Property(name="text", data_type=DataType.TEXT)], + vectorizer_config=[ + Configure.NamedVectors.text2vec_contextionary("name", vectorize_collection_name=False) + ], + ) + config2 = collection2.config.get() + + assert config.vector_config == config2.vector_config + assert len(config.vector_config) == 1 + assert config.vector_config["name"].vectorizer.vectorizer == "text2vec-contextionary" + assert config.vector_config["name"].vectorizer.model == {"vectorizeClassName": False} diff --git 
a/integration/test_collection_near_vector.py b/integration/test_collection_near_vector.py new file mode 100644 index 000000000..1b866c12e --- /dev/null +++ b/integration/test_collection_near_vector.py @@ -0,0 +1,177 @@ +import uuid +from typing import Any + +import numpy as np +import pandas as pd +import polars as pl +import pytest + +from integration.conftest import CollectionFactory +from weaviate.collections.classes.config import ( + Configure, + DataType, + Property, +) +from weaviate.collections.classes.grpc import ( + GroupBy, + MetadataQuery, +) + +UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9") +UUID2 = uuid.UUID("8ad0d33c-8db1-4437-87f3-72161ca2a51a") +UUID3 = uuid.UUID("83d99755-9deb-4b16-8431-d1dff4ab0a75") + + +def test_near_vector(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana = collection.data.insert({"Name": "Banana"}) + collection.data.insert({"Name": "Fruit"}) + collection.data.insert({"Name": "car"}) + collection.data.insert({"Name": "Mountain"}) + + banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) + + full_objects = collection.query.near_vector( + banana.vector["default"], return_metadata=MetadataQuery(distance=True, certainty=True) + ).objects + assert len(full_objects) == 4 + + objects_distance = collection.query.near_vector( + banana.vector["default"], distance=full_objects[2].metadata.distance + ).objects + assert len(objects_distance) == 3 + + objects_distance = collection.query.near_vector( + banana.vector["default"], certainty=full_objects[2].metadata.certainty + ).objects + assert len(objects_distance) == 3 + + +def test_near_vector_limit(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], 
+ vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana = collection.data.insert({"Name": "Banana"}) + collection.data.insert({"Name": "Fruit"}) + collection.data.insert({"Name": "car"}) + collection.data.insert({"Name": "Mountain"}) + + banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) + + objs = collection.query.near_vector(banana.vector["default"], limit=2).objects + assert len(objs) == 2 + + +def test_near_vector_offset(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana = collection.data.insert({"Name": "Banana"}) + uuid_fruit = collection.data.insert({"Name": "Fruit"}) + collection.data.insert({"Name": "car"}) + collection.data.insert({"Name": "Mountain"}) + + banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) + + objs = collection.query.near_vector(banana.vector["default"], offset=1).objects + assert len(objs) == 3 + assert objs[0].uuid == uuid_fruit + + +def test_near_vector_group_by_argument(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[ + Property(name="Name", data_type=DataType.TEXT), + Property(name="Count", data_type=DataType.INT), + ], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana1 = collection.data.insert({"Name": "Banana", "Count": 51}) + collection.data.insert({"Name": "Banana", "Count": 72}) + collection.data.insert({"Name": "car", "Count": 12}) + collection.data.insert({"Name": "Mountain", "Count": 1}) + + banana1 = collection.query.fetch_object_by_id(uuid_banana1, include_vector=True) + + ret = collection.query.near_vector( + banana1.vector["default"], + group_by=GroupBy( + prop="name", + 
number_of_groups=4, + objects_per_group=10, + ), + return_metadata=MetadataQuery(distance=True, certainty=True), + ) + + assert len(ret.objects) == 4 + assert ret.objects[0].belongs_to_group == "Banana" + assert ret.objects[1].belongs_to_group == "Banana" + assert ret.objects[2].belongs_to_group == "car" + assert ret.objects[3].belongs_to_group == "Mountain" + + +@pytest.mark.parametrize( + "near_vector", [[1, 0], [1.0, 0.0], np.array([1, 0]), pl.Series([1, 0]), pd.Series([1, 0])] +) +def test_near_vector_with_other_input( + collection_factory: CollectionFactory, near_vector: Any +) -> None: + collection = collection_factory(vectorizer_config=Configure.Vectorizer.none()) + + uuid1 = collection.data.insert({}, vector=[1, 0]) + collection.data.insert({}, vector=[0, 1]) + + ret = collection.query.near_vector( + near_vector, + distance=0.1, + ) + assert len(ret.objects) == 1 + assert ret.objects[0].uuid == uuid1 + + +@pytest.mark.parametrize( + "near_vector", + [ + {"first": [1, 0], "second": [1, 0, 0]}, + {"first": np.array([1, 0]), "second": [1, 0, 0]}, + {"first": pl.Series([1, 0]), "second": [1, 0, 0]}, + {"first": pd.Series([1, 0]), "second": [1, 0, 0]}, + [np.array([1, 0]), [1, 0, 0]], + [pl.Series([1, 0]), [1, 0, 0]], + [pd.Series([1, 0]), [1, 0, 0]], + {"first": [1.0, 0.0], "second": [1.0, 0.0, 0.0]}, + ], +) +def test_near_vector_with_named_vector_other_input( + collection_factory: CollectionFactory, near_vector: Any +) -> None: + dummy = collection_factory("dummy") + if dummy._connection._weaviate_version.is_lower_than(1, 26, 0): + pytest.skip("Named vectors are supported in versions higher than 1.26.0") + + collection = collection_factory( + vectorizer_config=[ + Configure.NamedVectors.none("first"), + Configure.NamedVectors.none("second"), + ] + ) + + uuid1 = collection.data.insert({}, vector={"first": [1, 0], "second": [1, 0, 0]}) + collection.data.insert({}, vector={"first": [0, 1], "second": [0, 0, 1]}) + + ret = collection.query.near_vector(near_vector, 
distance=0.1, target_vector=["first", "second"]) + assert len(ret.objects) == 1 + assert ret.objects[0].uuid == uuid1 diff --git a/integration/test_tenants.py b/integration/test_tenants.py index f244c980d..8738c7c6b 100644 --- a/integration/test_tenants.py +++ b/integration/test_tenants.py @@ -14,10 +14,10 @@ ) from weaviate.collections.classes.tenants import ( Tenant, - TenantInput, + TenantCreate, TenantActivityStatus, ) -from weaviate.collections.tenants import TenantInputType +from weaviate.collections.tenants import TenantCreateInputType from weaviate.exceptions import WeaviateInvalidInputError, WeaviateUnsupportedFeatureError @@ -325,14 +325,15 @@ def test_autotenant_toggling(collection_factory: CollectionFactory) -> None: [ "tenant", Tenant(name="tenant"), - TenantInput(name="tenant"), + TenantCreate(name="tenant"), ["tenant"], [Tenant(name="tenant")], - [TenantInput(name="tenant")], + [TenantCreate(name="tenant")], ], ) def test_tenants_create( - collection_factory: CollectionFactory, tenants: Union[TenantInputType, List[TenantInputType]] + collection_factory: CollectionFactory, + tenants: Union[TenantCreateInputType, List[TenantCreateInputType]], ) -> None: collection = collection_factory( vectorizer_config=Configure.Vectorizer.none(), @@ -350,14 +351,12 @@ def test_tenants_create( [ "tenant", Tenant(name="tenant"), - TenantInput(name="tenant"), ["tenant"], [Tenant(name="tenant")], - [TenantInput(name="tenant")], ], ) def test_tenants_remove( - collection_factory: CollectionFactory, tenants: Union[TenantInputType, List[TenantInputType]] + collection_factory: CollectionFactory, tenants: Union[str, Tenant, List[Union[str, Tenant]]] ) -> None: collection = collection_factory( vectorizer_config=Configure.Vectorizer.none(), @@ -377,10 +376,12 @@ def test_tenants_remove( Tenant(name="1", activity_status=TenantActivityStatus.FREEZING), Tenant(name="1", activity_status=TenantActivityStatus.UNFREEZING), Tenant(name="1", 
activity_status=TenantActivityStatus.UNFROZEN), + Tenant(name="1", activity_status=TenantActivityStatus.FROZEN), [ Tenant(name="1", activity_status=TenantActivityStatus.FREEZING), Tenant(name="2", activity_status=TenantActivityStatus.UNFREEZING), Tenant(name="3", activity_status=TenantActivityStatus.UNFROZEN), + Tenant(name="4", activity_status=TenantActivityStatus.FROZEN), ], ], ) diff --git a/integration_v3/test_injection.py b/integration_v3/test_injection.py new file mode 100644 index 000000000..21eacf5d5 --- /dev/null +++ b/integration_v3/test_injection.py @@ -0,0 +1,63 @@ +import pytest +import weaviate +import requests +import json + + +def injection_template(n: int) -> str: + return "Liver" + ("\\" * n) + '"}}){{answer}}}}{payload}#' + + +@pytest.mark.parametrize("n_backslashes", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +def test_gql_injection(n_backslashes: int) -> None: + client = weaviate.Client(url="http://localhost:8080") + client.schema.delete_class("Question") + client.schema.delete_class("Hacked") + class_obj = { + "class": "Question", + "vectorizer": "text2vec-contextionary", + "properties": [ + {"name": "answer", "dataType": ["string"], "tokenization": "field"}, + {"name": "question", "dataType": ["string"]}, + {"name": "category", "dataType": ["string"]}, + ], + } + + class_obj2 = { + "class": "Hacked", + "vectorizer": "text2vec-contextionary", + "properties": [ + {"name": "answer", "dataType": ["string"]}, + {"name": "question", "dataType": ["string"]}, + {"name": "category", "dataType": ["string"]}, + ], + } + client.schema.create_class(class_obj) + client.schema.create_class(class_obj2) + + resp = requests.get( + "https://raw.githubusercontent.com/weaviate-tutorials/quickstart/main/data/jeopardy_tiny.json" + ) + data = json.loads(resp.text) + + client.batch.configure(batch_size=100) + with client.batch as batch: + for _, d in enumerate(data): + properties = { + "answer": d["Answer"], + "question": d["Question"], + "category": d["Category"], + } + 
batch.add_data_object(data_object=properties, class_name="Question") + batch.add_data_object(data_object=properties, class_name="Hacked") + + injection_payload = client.query.get("Hacked", ["answer"]).build() + query = client.query.get("Question", ["question", "answer", "category"]).with_where( + { + "path": ["answer"], + "operator": "NotEqual", + "valueText": injection_template(n_backslashes).format(payload=injection_payload[1:]), + } + ) + res = query.do() + assert "Hacked" not in res["data"]["Get"] diff --git a/requirements-devel.txt b/requirements-devel.txt index 223ce052d..426b7936a 100644 --- a/requirements-devel.txt +++ b/requirements-devel.txt @@ -32,6 +32,7 @@ py-spy>=0.3.14 numpy>=1.24.4,<2.0.0 pandas>=2.0.3,<3.0.0 +pandas-stubs>=2.0.3,<3.0.0 polars>=0.20.26,<0.21.0 fastapi>=0.111.0,<1.0.0 diff --git a/run-mypy.sh b/run-mypy.sh index 885cd5eff..1fcf5ebd6 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -1,8 +1,10 @@ #!/usr/bin/env bash +python3 -m venv venv +source venv/bin/activate pip install -r requirements-devel.txt >/dev/null 2>&1 echo "Static checking ./weaviate:" mypy --config-file ./pyproject.toml ./weaviate echo "Static checking ./integration:" -mypy --config-file ./pyproject.toml ./integration \ No newline at end of file +mypy --config-file ./pyproject.toml --warn-unused-ignores ./integration diff --git a/test/collection/test_config.py b/test/collection/test_config.py index 8d6dd1ca5..7ba07248e 100644 --- a/test/collection/test_config.py +++ b/test/collection/test_config.py @@ -752,6 +752,26 @@ def test_config_with_vectorizer_and_properties( } }, ), + ( + Configure.Generative.anthropic( + model="model", + max_tokens=100, + stop_sequences=["stop"], + temperature=0.5, + top_k=10, + top_p=0.5, + ), + { + "generative-anthropic": { + "model": "model", + "maxTokens": 100, + "stopSequences": ["stop"], + "temperature": 0.5, + "topK": 10, + "topP": 0.5, + } + }, + ), ] diff --git a/test/collection/test_validator.py b/test/collection/test_validator.py new file 
mode 100644 index 000000000..66cfaafef --- /dev/null +++ b/test/collection/test_validator.py @@ -0,0 +1,38 @@ +from typing import Any, List + +import numpy as np +import pandas as pd +import polars as pl +import pytest + +from weaviate.exceptions import WeaviateInvalidInputError +from weaviate.validator import _validate_input, _ValidateArgument, _ExtraTypes + + +@pytest.mark.parametrize( + "inputs,expected,error", + [ + (1, [int], False), + (1.0, [int], True), + ([1, 1], [List], False), + (np.array([1, 2, 3]), [_ExtraTypes.NUMPY], False), + (np.array([1, 2, 3]), [_ExtraTypes.NUMPY, List], False), + (np.array([1, 2, 3]), [List], True), + ([1, 1], [List, _ExtraTypes.NUMPY], False), + (pd.array([1, 1]), [_ExtraTypes.PANDAS, List], False), + (pd.Series([1, 1]), [_ExtraTypes.PANDAS, List], False), + (pl.Series([1, 1]), [_ExtraTypes.POLARS, List], False), + ( + pl.Series([1, 1]), + [_ExtraTypes.POLARS, _ExtraTypes.PANDAS, _ExtraTypes.NUMPY, List], + False, + ), + (pl.Series([1, 1]), [_ExtraTypes.PANDAS, _ExtraTypes.NUMPY, List], True), + ], +) +def test_validator(inputs: Any, expected: List[Any], error: bool) -> None: + if error: + with pytest.raises(WeaviateInvalidInputError): + _validate_input(_ValidateArgument(expected=expected, name="test", value=inputs)) + else: + _validate_input(_ValidateArgument(expected=expected, name="test", value=inputs)) diff --git a/test/test_embedded.py b/test/test_embedded.py index bcec54cf9..bd0dbe13e 100644 --- a/test/test_embedded.py +++ b/test/test_embedded.py @@ -118,7 +118,7 @@ def test_embedded_end_to_end(options: EmbeddedDB, tmp_path): embedded_db.ensure_running() assert embedded_db.is_listening() is True - with patch("builtins.print") as mocked_print: + with patch("weaviate.logger.logger.info") as mocked_print: embedded_db.start() mocked_print.assert_called_once_with( f"embedded weaviate is already listening on port {options.port}" diff --git a/test/test_util.py b/test/test_util.py index c14d47700..25b8ccc04 100644 --- 
a/test/test_util.py +++ b/test/test_util.py @@ -22,6 +22,7 @@ is_weaviate_too_old, is_weaviate_client_too_old, MINIMUM_NO_WARNING_VERSION, + _sanitize_str, ) schema_set = { @@ -575,3 +576,19 @@ def test_is_weaviate_too_old(version: str, too_old: bool): ) def test_is_weaviate_client_too_old(current_version: str, latest_version: str, too_old: bool): assert is_weaviate_client_too_old(current_version, latest_version) is too_old + + +@pytest.mark.parametrize( + "in_str, out_str", + [ + ('"', '\\"'), + ('"', '\\"'), + ('\\"', '\\"'), + ('\\"', '\\"'), + ('\\\\"', '\\\\\\"'), + ('\\\\"', '\\\\\\"'), + ('\\\\\\"', '\\\\\\"'), + ], +) +def test_sanitize_str(in_str: str, out_str: str) -> None: + assert _sanitize_str(in_str) == f'"{out_str}"' diff --git a/weaviate/classes/tenants.py b/weaviate/classes/tenants.py index e9c0309fc..884a07124 100644 --- a/weaviate/classes/tenants.py +++ b/weaviate/classes/tenants.py @@ -1,15 +1,20 @@ from weaviate.collections.classes.tenants import ( Tenant, - TenantInput, + TenantCreate, + TenantUpdate, TenantActivityStatus, - TenantActivityStatusInput, + TenantCreateActivityStatus, + TenantUpdateActivityStatus, ) -from weaviate.collections.tenants import TenantInputType +from weaviate.collections.tenants import TenantCreateInputType, TenantUpdateInputType __all__ = [ "Tenant", - "TenantInput", + "TenantCreate", + "TenantUpdate", "TenantActivityStatus", - "TenantActivityStatusInput", - "TenantInputType", + "TenantCreateActivityStatus", + "TenantUpdateActivityStatus", + "TenantCreateInputType", + "TenantUpdateInputType", ] diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 8671c6dc2..f9c2c15ea 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -38,6 +38,7 @@ from weaviate.connect import ConnectionV4 from weaviate.event_loop import _EventLoop from weaviate.exceptions import WeaviateBatchValidationError, EmptyResponseException +from weaviate.logger import logger from 
weaviate.types import UUID, VECTORS from weaviate.util import _decode_json_response_dict from weaviate.warnings import _Warnings @@ -187,6 +188,8 @@ def __init__( self.__loop = event_loop self.__objs_count = 0 + self.__objs_logs_count = 0 + self.__refs_logs_count = 0 if isinstance(self.__batching_mode, _FixedSizeBatching): self.__recommended_num_objects = self.__batching_mode.batch_size @@ -448,13 +451,19 @@ def __dynamic_batching(self) -> None: async def __send_batch( self, objs: List[_BatchObject], refs: List[_BatchReference], readd_rate_limit: bool ) -> None: - if len(objs) > 0: + if (n_objs := len(objs)) > 0: start = time.time() try: response_obj = await self.__batch_grpc.objects( objects=objs, timeout=DEFAULT_REQUEST_TIMEOUT ) except Exception as e: + logger.warn( + { + "message": "Failed to insert objects in batch. Inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects.", + "error": repr(e), + } + ) errors_obj = { idx: ErrorObject(message=repr(e), object_=obj) for idx, obj in enumerate(objs) } @@ -543,17 +552,29 @@ async def __send_batch( ) self.__uuid_lookup_lock.release() + if (n_obj_errs := len(response_obj.errors)) > 0 and n_obj_errs < 30: + logger.error( + { + "message": f"Failed to send {n_obj_errs} objects in a batch of {n_objs}. Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects.", + } + ) + self.__objs_logs_count += 1 + if self.__objs_logs_count > 30: + logger.error( + { + "message": "There have been more than 30 failed object batches. 
Further errors will not be logged.", + } + ) self.__results_lock.acquire() self.__results_for_wrapper.results.objs += response_obj self.__results_for_wrapper.failed_objects.extend(response_obj.errors.values()) self.__results_lock.release() self.__took_queue.append(time.time() - start) - if len(refs) > 0: + if (n_refs := len(refs)) > 0: start = time.time() try: response_ref = await self.__batch_rest.references(references=refs) - except Exception as e: errors_ref = { idx: ErrorReference(message=repr(e), reference=ref) @@ -564,6 +585,20 @@ async def __send_batch( errors=errors_ref, has_errors=True, ) + if (n_ref_errs := len(response_ref.errors)) > 0 and n_ref_errs < 30: + logger.error( + { + "message": f"Failed to send {n_ref_errs} references in a batch of {n_refs}. Please inspect client.batch.failed_references or collection.batch.failed_references for the failed references.", + "errors": response_ref.errors, + } + ) + self.__refs_logs_count += 1 + if self.__refs_logs_count > 30: + logger.error( + { + "message": "There have been more than 30 failed reference batches. 
Further errors will not be logged.", + } + ) self.__results_lock.acquire() self.__results_for_wrapper.results.refs += response_ref self.__results_for_wrapper.failed_references.extend(response_ref.errors.values()) diff --git a/weaviate/collections/batch/batch_wrapper.py b/weaviate/collections/batch/batch_wrapper.py index f91c30f9d..f2dab4535 100644 --- a/weaviate/collections/batch/batch_wrapper.py +++ b/weaviate/collections/batch/batch_wrapper.py @@ -12,6 +12,7 @@ from weaviate.collections.classes.config import ConsistencyLevel from weaviate.connect import ConnectionV4 from weaviate.event_loop import _EventLoopSingleton +from weaviate.logger import logger from weaviate.util import _capitalize_first_letter, _decode_json_response_list @@ -62,7 +63,7 @@ async def is_ready(how_many: int) -> bool: ) return all(all(readiness) for readiness in readinesses) except Exception as e: - print( + logger.warn( f"Error while getting class shards statuses: {e}, trying again with 2**n={2**how_many}s exponential backoff with n={how_many}" ) if how_many_failures == how_many: @@ -73,10 +74,10 @@ async def is_ready(how_many: int) -> bool: count = 0 while not self._event_loop.run_until_complete(is_ready, count): if count % 20 == 0: # print every 5s - print("Waiting for async indexing to finish...") + logger.debug("Waiting for async indexing to finish...") time.sleep(0.25) count += 1 - print("Async indexing finished!") + logger.debug("Async indexing finished!") async def __get_shards_readiness(self, shard: Shard) -> List[bool]: path = f"/schema/{_capitalize_first_letter(shard.collection)}/shards{'' if shard.tenant is None else f'?tenant={shard.tenant}'}" diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index 7da6557dc..fa8a0d560 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -22,6 +22,7 @@ _ConfigCreateModel, _ConfigUpdateModel, _QuantizerConfigUpdate, + _EnumLikeStr, ) from 
weaviate.collections.classes.config_named_vectors import ( _NamedVectorConfigCreate, @@ -29,26 +30,23 @@ _NamedVectors, _NamedVectorsUpdate, ) +from weaviate.collections.classes.config_vector_index import VectorIndexType as VectorIndexTypeAlias from weaviate.collections.classes.config_vector_index import ( _QuantizerConfigCreate, _VectorIndexConfigCreate, + _VectorIndexConfigDynamicCreate, _VectorIndexConfigDynamicUpdate, - _VectorIndexConfigHNSWCreate, _VectorIndexConfigFlatCreate, - _VectorIndexConfigHNSWUpdate, _VectorIndexConfigFlatUpdate, - _VectorIndexConfigDynamicCreate, + _VectorIndexConfigHNSWCreate, + _VectorIndexConfigHNSWUpdate, _VectorIndexConfigSkipCreate, _VectorIndexConfigUpdate, - VectorIndexType as VectorIndexTypeAlias, -) -from weaviate.collections.classes.config_vectorizers import ( - _Vectorizer, - _VectorizerConfigCreate, - CohereModel, - Vectorizers as VectorizersAlias, - VectorDistances as VectorDistancesAlias, ) +from weaviate.collections.classes.config_vectorizers import CohereModel +from weaviate.collections.classes.config_vectorizers import VectorDistances as VectorDistancesAlias +from weaviate.collections.classes.config_vectorizers import Vectorizers as VectorizersAlias +from weaviate.collections.classes.config_vectorizers import _Vectorizer, _VectorizerConfigCreate from weaviate.exceptions import WeaviateInvalidInputError from weaviate.util import _capitalize_first_letter from weaviate.warnings import _Warnings @@ -153,6 +151,8 @@ class GenerativeSearches(str, Enum): See the [docs](https://weaviate.io/developers/weaviate/modules/reader-generator-modules) for more details. Attributes: + `AWS` + Weaviate module backed by AWS Bedrock generative models. `OPENAI` Weaviate module backed by OpenAI and Azure-OpenAI generative models. `COHERE` @@ -161,9 +161,12 @@ class GenerativeSearches(str, Enum): Weaviate module backed by PaLM generative models. `AWS` Weaviate module backed by AWS Bedrock generative models. 
+ `ANTHROPIC` + Weaviate module backed by Anthropic generative models. """ AWS = "generative-aws" + ANTHROPIC = "generative-anthropic" ANYSCALE = "generative-anyscale" COHERE = "generative-cohere" MISTRAL = "generative-mistral" @@ -387,19 +390,26 @@ class _MultiTenancyConfigUpdate(_ConfigUpdateModel): class _GenerativeConfigCreate(_ConfigCreateModel): - generative: GenerativeSearches + generative: Union[GenerativeSearches, _EnumLikeStr] class _GenerativeAnyscale(_GenerativeConfigCreate): - generative: GenerativeSearches = Field( + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( default=GenerativeSearches.ANYSCALE, frozen=True, exclude=True ) temperature: Optional[float] model: Optional[str] +class _GenerativeCustom(_GenerativeConfigCreate): + module_config: Dict[str, Any] + + def _to_dict(self) -> Dict[str, Any]: + return self.module_config + + class _GenerativeOctoai(_GenerativeConfigCreate): - generative: GenerativeSearches = Field( + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( default=GenerativeSearches.OCTOAI, frozen=True, exclude=True ) baseURL: Optional[str] @@ -409,7 +419,7 @@ class _GenerativeOctoai(_GenerativeConfigCreate): class _GenerativeMistral(_GenerativeConfigCreate): - generative: GenerativeSearches = Field( + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( default=GenerativeSearches.MISTRAL, frozen=True, exclude=True ) temperature: Optional[float] @@ -418,7 +428,7 @@ class _GenerativeMistral(_GenerativeConfigCreate): class _GenerativeOllama(_GenerativeConfigCreate): - generative: GenerativeSearches = Field( + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( default=GenerativeSearches.OLLAMA, frozen=True, exclude=True ) model: Optional[str] @@ -426,7 +436,7 @@ class _GenerativeOllama(_GenerativeConfigCreate): class _GenerativeOpenAIConfigBase(_GenerativeConfigCreate): - generative: GenerativeSearches = Field( + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( 
default=GenerativeSearches.OPENAI, frozen=True, exclude=True ) baseURL: Optional[AnyHttpUrl] @@ -453,7 +463,7 @@ class _GenerativeAzureOpenAIConfig(_GenerativeOpenAIConfigBase): class _GenerativeCohereConfig(_GenerativeConfigCreate): - generative: GenerativeSearches = Field( + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( default=GenerativeSearches.COHERE, frozen=True, exclude=True ) baseURL: Optional[AnyHttpUrl] @@ -472,7 +482,7 @@ def _to_dict(self) -> Dict[str, Any]: class _GenerativePaLMConfig(_GenerativeConfigCreate): - generative: GenerativeSearches = Field( + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( default=GenerativeSearches.PALM, frozen=True, exclude=True ) apiEndpoint: Optional[str] @@ -485,7 +495,7 @@ class _GenerativePaLMConfig(_GenerativeConfigCreate): class _GenerativeAWSConfig(_GenerativeConfigCreate): - generative: GenerativeSearches = Field( + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( default=GenerativeSearches.AWS, frozen=True, exclude=True ) region: str @@ -494,27 +504,52 @@ class _GenerativeAWSConfig(_GenerativeConfigCreate): endpoint: Optional[str] +class _GenerativeAnthropicConfig(_GenerativeConfigCreate): + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( + default=GenerativeSearches.ANTHROPIC, frozen=True, exclude=True + ) + model: Optional[str] + maxTokens: Optional[int] + stopSequences: Optional[List[str]] + temperature: Optional[float] + topK: Optional[int] + topP: Optional[float] + + class _RerankerConfigCreate(_ConfigCreateModel): - reranker: Rerankers + reranker: Union[Rerankers, _EnumLikeStr] RerankerCohereModel = Literal["rerank-english-v2.0", "rerank-multilingual-v2.0"] class _RerankerCohereConfig(_RerankerConfigCreate): - reranker: Rerankers = Field(default=Rerankers.COHERE, frozen=True, exclude=True) + reranker: Union[Rerankers, _EnumLikeStr] = Field( + default=Rerankers.COHERE, frozen=True, exclude=True + ) model: Optional[Union[RerankerCohereModel, str]] = 
Field(default=None) +class _RerankerCohereCustom(_RerankerConfigCreate): + module_config: Dict[str, Any] + + def _to_dict(self) -> Dict[str, Any]: + return self.module_config + + class _RerankerTransformersConfig(_RerankerConfigCreate): - reranker: Rerankers = Field(default=Rerankers.TRANSFORMERS, frozen=True, exclude=True) + reranker: Union[Rerankers, _EnumLikeStr] = Field( + default=Rerankers.TRANSFORMERS, frozen=True, exclude=True + ) RerankerVoyageAIModel = Literal["rerank-lite-1", "rerank-1"] class _RerankerVoyageAIConfig(_RerankerConfigCreate): - reranker: Rerankers = Field(default=Rerankers.VOYAGEAI, frozen=True, exclude=True) + reranker: Union[Rerankers, _EnumLikeStr] = Field( + default=Rerankers.VOYAGEAI, frozen=True, exclude=True + ) model: Optional[Union[RerankerVoyageAIModel, str]] = Field(default=None) @@ -532,6 +567,13 @@ def anyscale( ) -> _GenerativeConfigCreate: return _GenerativeAnyscale(model=model, temperature=temperature) + @staticmethod + def custom( + module_name: str, + module_config: Dict[str, Any], + ) -> _GenerativeConfigCreate: + return _GenerativeCustom(generative=_EnumLikeStr(module_name), module_config=module_config) + @staticmethod def mistral( model: Optional[str] = None, @@ -767,6 +809,41 @@ def aws( endpoint=endpoint, ) + @staticmethod + def anthropic( + model: Optional[str] = None, + max_tokens: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + ) -> _GenerativeConfigCreate: + """ + Create a `_GenerativeAnthropicConfig` object for use when performing AI generation using the `generative-anthropic` module. + + Arguments: + `model` + The model to use. Defaults to `None`, which uses the server-defined default + `max_tokens` + The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default + `stop_sequences` + The stop sequences to use. 
Defaults to `None`, which uses the server-defined default + `temperature` + The temperature to use. Defaults to `None`, which uses the server-defined default + `top_k` + The top K to use. Defaults to `None`, which uses the server-defined default + `top_p` + The top P to use. Defaults to `None`, which uses the server-defined default + """ + return _GenerativeAnthropicConfig( + model=model, + maxTokens=max_tokens, + stopSequences=stop_sequences, + temperature=temperature, + topK=top_k, + topP=top_p, + ) + class _Reranker: """Use this factory class to create the correct object for the `reranker_config` argument in the `collections.create()` method. @@ -784,6 +861,13 @@ def transformers() -> _RerankerConfigCreate: """ return _RerankerTransformersConfig(reranker=Rerankers.TRANSFORMERS) + @staticmethod + def custom(module_name: str, module_config: Dict[str, Any]) -> _RerankerConfigCreate: + """Create a `_RerankerCohereCustom` object for use when reranking using a custom module.""" + return _RerankerCohereCustom( + reranker=_EnumLikeStr(module_name), module_config=module_config + ) + @staticmethod def cohere( model: Optional[Union[RerankerCohereModel, str]] = None, @@ -1212,7 +1296,7 @@ def vector_index_type() -> str: @dataclass class _GenerativeConfig(_ConfigBase): - generative: GenerativeSearches + generative: Union[GenerativeSearches, str] model: Dict[str, Any] @@ -1221,7 +1305,7 @@ class _GenerativeConfig(_ConfigBase): @dataclass class _VectorizerConfig(_ConfigBase): - vectorizer: Vectorizers + vectorizer: Union[Vectorizers, str] model: Dict[str, Any] vectorize_collection_name: bool @@ -1232,7 +1316,7 @@ class _VectorizerConfig(_ConfigBase): @dataclass class _RerankerConfig(_ConfigBase): model: Dict[str, Any] - reranker: Rerankers + reranker: Union[Rerankers, str] RerankerConfig = _RerankerConfig @@ -1240,7 +1324,7 @@ class _RerankerConfig(_ConfigBase): @dataclass class _NamedVectorizerConfig(_ConfigBase): - vectorizer: Vectorizers + vectorizer: Union[Vectorizers, 
str] model: Dict[str, Any] source_properties: Optional[List[str]] @@ -1283,7 +1367,7 @@ class _CollectionConfig(_ConfigBase): ] vector_index_type: Optional[VectorIndexType] vectorizer_config: Optional[VectorizerConfig] - vectorizer: Optional[Vectorizers] + vectorizer: Optional[Union[Vectorizers, str]] vector_config: Optional[Dict[str, _NamedVectorConfig]] def to_dict(self) -> dict: @@ -1341,7 +1425,7 @@ class _CollectionConfigSimple(_ConfigBase): references: List[ReferencePropertyConfig] reranker_config: Optional[RerankerConfig] vectorizer_config: Optional[VectorizerConfig] - vectorizer: Optional[Vectorizers] + vectorizer: Optional[Union[Vectorizers, str]] vector_config: Optional[Dict[str, _NamedVectorConfig]] @@ -1409,7 +1493,9 @@ def _check_name(cls, v: str) -> str: raise ValueError(f"Property name '{v}' is reserved and cannot be used") return v - def _to_dict(self, vectorizer: Optional[Vectorizers] = None) -> Dict[str, Any]: + def _to_dict( + self, vectorizer: Optional[Union[Vectorizers, _EnumLikeStr]] = None + ) -> Dict[str, Any]: ret_dict = super()._to_dict() ret_dict["dataType"] = [ret_dict["dataType"]] if vectorizer is not None and vectorizer != Vectorizers.NONE: diff --git a/weaviate/collections/classes/config_base.py b/weaviate/collections/classes/config_base.py index 7d721671a..07aae6cae 100644 --- a/weaviate/collections/classes/config_base.py +++ b/weaviate/collections/classes/config_base.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from enum import Enum from typing import Any, Dict, cast + from pydantic import BaseModel, ConfigDict, Field @@ -68,3 +69,12 @@ class _QuantizerConfigUpdate(_ConfigUpdateModel): @abstractmethod def quantizer_name() -> str: ... 
+ + +@dataclass +class _EnumLikeStr: + string: str + + @property + def value(self) -> str: + return self.string diff --git a/weaviate/collections/classes/config_methods.py b/weaviate/collections/classes/config_methods.py index 2ab1461e1..185a0c7ed 100644 --- a/weaviate/collections/classes/config_methods.py +++ b/weaviate/collections/classes/config_methods.py @@ -49,9 +49,13 @@ def __get_rerank_config(schema: Dict[str, Any]) -> Optional[_RerankerConfig]: ) == 1 ): + try: + reranker = Rerankers(rerankers[0]) + except ValueError: + reranker = rerankers[0] return _RerankerConfig( model=schema["moduleConfig"][rerankers[0]], - reranker=Rerankers(rerankers[0]), + reranker=reranker, ) else: return None @@ -66,8 +70,13 @@ def __get_generative_config(schema: Dict[str, Any]) -> Optional[_GenerativeConfi ) == 1 ): + try: + generative = GenerativeSearches(generators[0]) + except ValueError: + generative = generators[0] + return _GenerativeConfig( - generative=GenerativeSearches(generators[0]), + generative=generative, model=schema["moduleConfig"][generators[0]], ) else: @@ -75,22 +84,26 @@ def __get_generative_config(schema: Dict[str, Any]) -> Optional[_GenerativeConfi def __get_vectorizer_config(schema: Dict[str, Any]) -> Optional[_VectorizerConfig]: - if __get_vectorizer(schema) is not None and schema.get("vectorizer", "none") != "none": + if __is_vectorizer_present(schema) is not None and schema.get("vectorizer", "none") != "none": vec_config: Dict[str, Any] = schema["moduleConfig"].pop(schema["vectorizer"]) + try: + vectorizer = Vectorizers(schema["vectorizer"]) + except ValueError: + vectorizer = schema["vectorizer"] return _VectorizerConfig( vectorize_collection_name=vec_config.pop("vectorizeClassName", False), model=vec_config, - vectorizer=Vectorizers(schema["vectorizer"]), + vectorizer=vectorizer, ) else: return None -def __get_vectorizer(schema: Dict[str, Any]) -> Optional[Vectorizers]: +def __is_vectorizer_present(schema: Dict[str, Any]) -> bool: # ignore single 
vectorizer config if named vectors are present if "vectorConfig" in schema: - return None - return Vectorizers(schema.get("vectorizer")) + return False + return True def __get_vector_index_type(schema: Dict[str, Any]) -> Optional[VectorIndexType]: @@ -197,10 +210,14 @@ def __get_vector_config( vector_index_config = __get_vector_index_config(named_vector) assert vector_index_config is not None + try: + vec: Union[str, Vectorizers] = Vectorizers(vectorizer_str) + except ValueError: + vec = vectorizer_str named_vectors[name] = _NamedVectorConfig( vectorizer=_NamedVectorizerConfig( - vectorizer=Vectorizers(vectorizer_str), + vectorizer=vec, model=vec_config, source_properties=props, ), @@ -211,6 +228,17 @@ def __get_vector_config( return None +def __get_vectorizer(schema: Dict[str, Any]) -> Optional[Union[str, Vectorizers]]: + if "vectorConfig" in schema: + return None + + vectorizer = str(schema["vectorizer"]) + try: + return Vectorizers(vectorizer) + except ValueError: + return vectorizer + + def _collection_config_simple_from_json(schema: Dict[str, Any]) -> _CollectionConfigSimple: return _CollectionConfigSimple( name=schema["class"], diff --git a/weaviate/collections/classes/config_named_vectors.py b/weaviate/collections/classes/config_named_vectors.py index 725c7a5dc..13b3e94e1 100644 --- a/weaviate/collections/classes/config_named_vectors.py +++ b/weaviate/collections/classes/config_named_vectors.py @@ -1,11 +1,12 @@ -from typing import Any, Dict, List, Literal, Optional, Union import warnings +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import AnyHttpUrl, Field from weaviate.collections.classes.config_base import ( _ConfigCreateModel, _ConfigUpdateModel, + _EnumLikeStr, ) from weaviate.collections.classes.config_vector_index import ( _VectorIndexConfigCreate, @@ -46,6 +47,7 @@ Vectorizers, VoyageModel, _map_multi2vec_fields, + _VectorizerCustomConfig, ) @@ -110,6 +112,38 @@ def none( vector_index_config=vector_index_config, ) + 
@staticmethod + def custom( + name: str, + *, + module_name: str, + source_properties: Optional[List[str]] = None, + vector_index_config: Optional[_VectorIndexConfigCreate] = None, + module_config: Optional[Dict[str, Any]] = None, + ) -> _NamedVectorConfigCreate: + """Create a named vector using a custom vectorizer module. + + Arguments: + `name` + The name of the named vector. + `module_name` + The name of the custom module to use. + `module_config` + The configuration of the custom module to use. + `source_properties` + Which properties should be included when vectorizing. By default all text properties are included. + `vector_index_config` + The configuration for Weaviate's vector index. Use wvc.config.Configure.VectorIndex to create a vector index configuration. None by default + """ + return _NamedVectorConfigCreate( + name=name, + source_properties=source_properties, + vectorizer=_VectorizerCustomConfig( + vectorizer=_EnumLikeStr(module_name), module_config=module_config + ), + vector_index_config=vector_index_config, + ) + @staticmethod def text2vec_cohere( name: str, diff --git a/weaviate/collections/classes/config_vectorizers.py b/weaviate/collections/classes/config_vectorizers.py index 653ec2dac..778183430 100644 --- a/weaviate/collections/classes/config_vectorizers.py +++ b/weaviate/collections/classes/config_vectorizers.py @@ -1,11 +1,11 @@ +import warnings from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union, cast -import warnings from pydantic import AnyHttpUrl, BaseModel, Field, field_validator from typing_extensions import TypeAlias -from weaviate.collections.classes.config_base import _ConfigCreateModel +from weaviate.collections.classes.config_base import _ConfigCreateModel, _EnumLikeStr CohereModel: TypeAlias = Literal[ "embed-multilingual-v2.0", @@ -133,22 +133,33 @@ class VectorDistances(str, Enum): class _VectorizerConfigCreate(_ConfigCreateModel): - vectorizer: Vectorizers = 
Field(default=..., exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(default=..., exclude=True) class _Text2VecContextionaryConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field( + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( default=Vectorizers.TEXT2VEC_CONTEXTIONARY, frozen=True, exclude=True ) vectorizeClassName: bool +class _VectorizerCustomConfig(_VectorizerConfigCreate): + module_config: Optional[Dict[str, Any]] + + def _to_dict(self) -> Dict[str, Any]: + if self.module_config is None: + return {} + return self.module_config + + class _Text2VecContextionaryConfigCreate(_Text2VecContextionaryConfig, _VectorizerConfigCreate): pass class _Text2VecAWSConfig(_VectorizerConfigCreate): - vectorizer: Vectorizers = Field(default=Vectorizers.TEXT2VEC_AWS, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_AWS, frozen=True, exclude=True + ) model: Optional[str] endpoint: Optional[str] region: str @@ -167,7 +178,9 @@ class _Text2VecAWSConfigCreate(_Text2VecAWSConfig, _VectorizerConfigCreate): class _Text2VecAzureOpenAIConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field(default=Vectorizers.TEXT2VEC_OPENAI, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_OPENAI, frozen=True, exclude=True + ) baseURL: Optional[AnyHttpUrl] resourceName: str deploymentId: str @@ -185,7 +198,7 @@ class _Text2VecAzureOpenAIConfigCreate(_Text2VecAzureOpenAIConfig, _VectorizerCo class _Text2VecHuggingFaceConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field( + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( default=Vectorizers.TEXT2VEC_HUGGINGFACE, frozen=True, exclude=True ) model: Optional[str] @@ -221,7 +234,9 @@ class _Text2VecHuggingFaceConfigCreate(_Text2VecHuggingFaceConfig, _VectorizerCo class _Text2VecOpenAIConfig(_ConfigCreateModel): - vectorizer: Vectorizers = 
Field(default=Vectorizers.TEXT2VEC_OPENAI, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_OPENAI, frozen=True, exclude=True + ) baseURL: Optional[AnyHttpUrl] dimensions: Optional[int] model: Optional[str] @@ -243,7 +258,9 @@ class _Text2VecOpenAIConfigCreate(_Text2VecOpenAIConfig, _VectorizerConfigCreate class _Text2VecCohereConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field(default=Vectorizers.TEXT2VEC_COHERE, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_COHERE, frozen=True, exclude=True + ) baseURL: Optional[AnyHttpUrl] model: Optional[str] truncate: Optional[CohereTruncation] @@ -261,7 +278,9 @@ class _Text2VecCohereConfigCreate(_Text2VecCohereConfig, _VectorizerConfigCreate class _Text2VecPalmConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field(default=Vectorizers.TEXT2VEC_PALM, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_PALM, frozen=True, exclude=True + ) projectId: str apiEndpoint: Optional[str] modelId: Optional[str] @@ -274,7 +293,7 @@ class _Text2VecPalmConfigCreate(_Text2VecPalmConfig, _VectorizerConfigCreate): class _Text2VecTransformersConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field( + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( default=Vectorizers.TEXT2VEC_TRANSFORMERS, frozen=True, exclude=True ) poolingStrategy: Literal["masked_mean", "cls"] @@ -289,7 +308,9 @@ class _Text2VecTransformersConfigCreate(_Text2VecTransformersConfig, _Vectorizer class _Text2VecGPT4AllConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field(default=Vectorizers.TEXT2VEC_GPT4ALL, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_GPT4ALL, frozen=True, exclude=True + ) vectorizeClassName: bool @@ -298,7 +319,9 @@ class 
_Text2VecGPT4AllConfigCreate(_Text2VecGPT4AllConfig, _VectorizerConfigCrea class _Text2VecJinaConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field(default=Vectorizers.TEXT2VEC_JINAAI, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_JINAAI, frozen=True, exclude=True + ) model: Optional[str] vectorizeClassName: bool @@ -308,7 +331,7 @@ class _Text2VecJinaConfigCreate(_Text2VecJinaConfig, _VectorizerConfigCreate): class _Text2VecVoyageConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field( + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( default=Vectorizers.TEXT2VEC_VOYAGEAI, frozen=True, exclude=True ) model: Optional[str] @@ -322,21 +345,27 @@ class _Text2VecVoyageConfigCreate(_Text2VecVoyageConfig, _VectorizerConfigCreate class _Text2VecOctoConfig(_VectorizerConfigCreate): - vectorizer: Vectorizers = Field(default=Vectorizers.TEXT2VEC_OCTOAI, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_OCTOAI, frozen=True, exclude=True + ) model: Optional[str] baseURL: Optional[str] vectorizeClassName: bool class _Text2VecOllamaConfig(_VectorizerConfigCreate): - vectorizer: Vectorizers = Field(default=Vectorizers.TEXT2VEC_OLLAMA, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.TEXT2VEC_OLLAMA, frozen=True, exclude=True + ) model: Optional[str] apiEndpoint: Optional[str] vectorizeClassName: bool class _Img2VecNeuralConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field(default=Vectorizers.IMG2VEC_NEURAL, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.IMG2VEC_NEURAL, frozen=True, exclude=True + ) imageFields: List[str] @@ -373,7 +402,9 @@ def _to_dict(self) -> Dict[str, Any]: class _Multi2VecClipConfig(_Multi2VecBase): - vectorizer: Vectorizers = Field(default=Vectorizers.MULTI2VEC_CLIP, frozen=True, 
exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.MULTI2VEC_CLIP, frozen=True, exclude=True + ) inferenceUrl: Optional[str] @@ -382,7 +413,9 @@ class _Multi2VecClipConfigCreate(_Multi2VecClipConfig, _VectorizerConfigCreate): class _Multi2VecPalmConfig(_Multi2VecBase, _VectorizerConfigCreate): - vectorizer: Vectorizers = Field(default=Vectorizers.MULTI2VEC_PALM, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.MULTI2VEC_PALM, frozen=True, exclude=True + ) videoFields: Optional[List[Multi2VecField]] projectId: str location: Optional[str] @@ -393,7 +426,9 @@ class _Multi2VecPalmConfig(_Multi2VecBase, _VectorizerConfigCreate): class _Multi2VecBindConfig(_Multi2VecBase): - vectorizer: Vectorizers = Field(default=Vectorizers.MULTI2VEC_BIND, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.MULTI2VEC_BIND, frozen=True, exclude=True + ) audioFields: Optional[List[Multi2VecField]] depthFields: Optional[List[Multi2VecField]] IMUFields: Optional[List[Multi2VecField]] @@ -406,7 +441,9 @@ class _Multi2VecBindConfigCreate(_Multi2VecBindConfig, _VectorizerConfigCreate): class _Ref2VecCentroidConfig(_ConfigCreateModel): - vectorizer: Vectorizers = Field(default=Vectorizers.REF2VEC_CENTROID, frozen=True, exclude=True) + vectorizer: Union[Vectorizers, _EnumLikeStr] = Field( + default=Vectorizers.REF2VEC_CENTROID, frozen=True, exclude=True + ) referenceProperties: List[str] method: Literal["mean"] @@ -652,6 +689,15 @@ def text2vec_contextionary(vectorize_collection_name: bool = True) -> _Vectorize """ return _Text2VecContextionaryConfigCreate(vectorizeClassName=vectorize_collection_name) + @staticmethod + def custom( + module_name: str, module_config: Optional[Dict[str, Any]] = None + ) -> _VectorizerConfigCreate: + """Create a `_VectorizerCustomConfig` object for use when vectorizing using a custom model.""" + return 
_VectorizerCustomConfig( + vectorizer=_EnumLikeStr(module_name), module_config=module_config + ) + @staticmethod def text2vec_cohere( model: Optional[Union[CohereModel, str]] = None, diff --git a/weaviate/collections/classes/grpc.py b/weaviate/collections/classes/grpc.py index ad38b3b44..9c1f7c7c9 100644 --- a/weaviate/collections/classes/grpc.py +++ b/weaviate/collections/classes/grpc.py @@ -6,8 +6,8 @@ from weaviate.collections.classes.types import _WeaviateInput from weaviate.proto.v1 import search_get_pb2 +from weaviate.str_enum import BaseEnum from weaviate.types import INCLUDE_VECTOR, UUID -from weaviate.util import BaseEnum class HybridFusion(str, BaseEnum): diff --git a/weaviate/collections/classes/tenants.py b/weaviate/collections/classes/tenants.py index 2737c049d..10c9cad2c 100644 --- a/weaviate/collections/classes/tenants.py +++ b/weaviate/collections/classes/tenants.py @@ -16,6 +16,8 @@ class TenantActivityStatus(str, Enum): The tenant is in the process of being frozen. `UNFREEZING` The tenant is in the process of being unfrozen. + `UNFROZEN` + The tenant has been pulled from the cloud and is not yet active nor inactive. """ HOT = "HOT" @@ -48,8 +50,22 @@ def activity_status(self) -> TenantActivityStatus: return self.activityStatus -class TenantActivityStatusInput(str, Enum): - """TenantActivityStatus class used to describe the activity status of a tenant in Weaviate. +class TenantCreateActivityStatus(str, Enum): + """TenantActivityStatus class used to describe the activity status of a tenant to create in Weaviate. + + Attributes: + `HOT` + The tenant is fully active and can be used. + `COLD` + The tenant is not active, files stored locally. + """ + + HOT = "HOT" + COLD = "COLD" + + +class TenantUpdateActivityStatus(str, Enum): + """TenantActivityStatus class used to describe the activity status of a tenant to update in Weaviate. 
Attributes: `HOT` @@ -65,23 +81,45 @@ class TenantActivityStatusInput(str, Enum): FROZEN = "FROZEN" -class TenantInput(BaseModel): - """Tenant class used to describe a tenant in Weaviate. +class TenantCreate(BaseModel): + """Tenant class used to describe a tenant to create in Weaviate. + + Attributes: + `name` + the name of the tenant. + `activity_status` + TenantCreateActivityStatus, default: "HOT" + """ + + model_config = ConfigDict(populate_by_name=True) + name: str + activityStatus: TenantCreateActivityStatus = Field( + default=TenantCreateActivityStatus.HOT, alias="activity_status" + ) + + @property + def activity_status(self) -> TenantCreateActivityStatus: + """Getter for the activity status of the tenant.""" + return self.activityStatus + + +class TenantUpdate(BaseModel): + """Tenant class used to describe a tenant to update in Weaviate. Attributes: `name` the name of the tenant. `activity_status` - TenantActivityStatusInput, default: "HOT" + TenantUpdateActivityStatus, default: "HOT" """ model_config = ConfigDict(populate_by_name=True) name: str - activityStatus: TenantActivityStatusInput = Field( - default=TenantActivityStatusInput.HOT, alias="activity_status" + activityStatus: TenantUpdateActivityStatus = Field( + default=TenantUpdateActivityStatus.HOT, alias="activity_status" ) @property - def activity_status(self) -> TenantActivityStatusInput: + def activity_status(self) -> TenantUpdateActivityStatus: """Getter for the activity status of the tenant.""" return self.activityStatus diff --git a/weaviate/collections/data/data.py b/weaviate/collections/data/data.py index 033196846..d9efb2f74 100644 --- a/weaviate/collections/data/data.py +++ b/weaviate/collections/data/data.py @@ -46,6 +46,7 @@ ) from weaviate.connect import ConnectionV4 from weaviate.connect.v4 import _ExpectedStatusCodes +from weaviate.logger import logger from weaviate.types import BEACON, UUID, VECTORS from weaviate.util import _datetime_to_string, _get_vector_v4 from weaviate.validator 
import _validate_input, _ValidateArgument @@ -357,33 +358,39 @@ async def insert_many( `weaviate.exceptions.WeaviateInsertManyAllFailedError`: If every object in the batch fails to be inserted. The exception message contains details about the failure. """ - return await self._batch_grpc.objects( - [ - ( - _BatchObject( - collection=self.name, - vector=obj.vector, - uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()), - properties=cast(dict, obj.properties), - tenant=self._tenant, - references=obj.references, - index=idx, - ) - if isinstance(obj, DataObject) - else _BatchObject( - collection=self.name, - vector=None, - uuid=str(uuid_package.uuid4()), - properties=cast(dict, obj), - tenant=self._tenant, - references=None, - index=idx, - ) + objs = [ + ( + _BatchObject( + collection=self.name, + vector=obj.vector, + uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()), + properties=cast(dict, obj.properties), + tenant=self._tenant, + references=obj.references, + index=idx, ) - for idx, obj in enumerate(objects) - ], - timeout=self._connection.timeout_config.insert, - ) + if isinstance(obj, DataObject) + else _BatchObject( + collection=self.name, + vector=None, + uuid=str(uuid_package.uuid4()), + properties=cast(dict, obj), + tenant=self._tenant, + references=None, + index=idx, + ) + ) + for idx, obj in enumerate(objects) + ] + res = await self._batch_grpc.objects(objs, timeout=self._connection.timeout_config.insert) + if (n_obj_errs := len(res.errors)) > 0: + logger.error( + { + "message": f"Failed to send {n_obj_errs} objects in a batch of {len(objs)}. 
Please inspect the errors variable of the returned object for more information.", + "errors": res.errors, + } + ) + return res async def replace( self, diff --git a/weaviate/collections/grpc/query.py b/weaviate/collections/grpc/query.py index 5c0d00d9e..cf797e96d 100644 --- a/weaviate/collections/grpc/query.py +++ b/weaviate/collections/grpc/query.py @@ -57,8 +57,8 @@ ) from weaviate.proto.v1 import search_get_pb2 from weaviate.types import NUMBER, UUID -from weaviate.util import _get_vector_v4 -from weaviate.validator import _ValidateArgument, _validate_input +from weaviate.util import _get_vector_v4, _is_1d_vector +from weaviate.validator import _ValidateArgument, _validate_input, _ExtraTypes # Can be found in the google.protobuf.internal.well_known_types.pyi stub file but is defined explicitly here for clarity. _PyValue: TypeAlias = Union[ @@ -332,7 +332,18 @@ def near_vector( if self._validate_arguments: _validate_input( [ - _ValidateArgument([List, Dict], "near_vector", near_vector), + _ValidateArgument( + [ + List, + Dict, + _ExtraTypes.PANDAS, + _ExtraTypes.POLARS, + _ExtraTypes.NUMPY, + _ExtraTypes.TF, + ], + "near_vector", + near_vector, + ), _ValidateArgument( [str, None, List, _MultiTargetVectorJoin], "target_vector", target_vector ), @@ -342,7 +353,13 @@ def near_vector( certainty, distance = self.__parse_near_options(certainty, distance) targets, target_vectors = self.__target_vector_to_grpc(target_vector) - + invalid_nv_exception = WeaviateInvalidInputError( + f"""near vector argument can be: + - a list of numbers + - a list of lists of numbers for multi target search + - a dictionary with target names as keys and lists of numbers as values + received: {near_vector}""" + ) if isinstance(near_vector, dict): if targets is None or len(targets.target_vectors) != len(near_vector): raise WeaviateInvalidInputError( @@ -351,17 +368,15 @@ def near_vector( vector_per_target: Dict[str, bytes] = {} for key, value in near_vector.items(): + nv = 
_get_vector_v4(value) + if ( - not isinstance(value, list) - or len(value) == 0 - or not isinstance(value[0], get_args(NUMBER)) + not isinstance(nv, list) + or len(nv) == 0 + or not isinstance(nv[0], get_args(NUMBER)) ): - raise WeaviateQueryError( - "The value of the near_vector dict must be a lists of numbers", - "GRPC", - ) + raise invalid_nv_exception - nv = _get_vector_v4(value) vector_per_target[key] = struct.pack("{}f".format(len(nv)), *nv) near_vector_grpc = search_get_pb2.NearVector( certainty=certainty, @@ -371,16 +386,13 @@ def near_vector( vector_per_target=vector_per_target, ) else: - if not isinstance(near_vector, list) or len(near_vector) == 0: - raise WeaviateInvalidInputError( - """near vector argument can be: - - a list of numbers - - a list of lists of numbers for multi target search - - a dictionary with target names as keys and lists of numbers as values""" - ) + if len(near_vector) == 0: + raise invalid_nv_exception - if isinstance(near_vector[0], get_args(NUMBER)): + if _is_1d_vector(near_vector): near_vector = _get_vector_v4(near_vector) + if not isinstance(near_vector, list): + raise invalid_nv_exception near_vector_grpc = search_get_pb2.NearVector( certainty=certainty, distance=distance, @@ -395,15 +407,13 @@ def near_vector( "The number of target vectors must be equal to the number of vectors." 
) for i, vector in enumerate(near_vector): + nv = _get_vector_v4(vector) if ( - not isinstance(vector, list) - or len(vector) == 0 - or not isinstance(vector[0], get_args(NUMBER)) + not isinstance(nv, list) + or len(nv) == 0 + or not isinstance(nv[0], get_args(NUMBER)) ): - raise WeaviateInvalidInputError( - "The value of the near_vector entry must be a lists of numbers" - ) - nv = _get_vector_v4(vector) + raise invalid_nv_exception vector_per_target_tmp[targets.target_vectors[i]] = struct.pack( "{}f".format(len(nv)), *nv ) diff --git a/weaviate/collections/tenants/__init__.py b/weaviate/collections/tenants/__init__.py index f3f876cb2..0ecda5f25 100644 --- a/weaviate/collections/tenants/__init__.py +++ b/weaviate/collections/tenants/__init__.py @@ -1,4 +1,11 @@ -from .tenants import _TenantsAsync, TenantInputType, TenantOutputType +from .tenants import _TenantsAsync, TenantCreateInputType, TenantOutputType, TenantUpdateInputType from .sync import _Tenants -__all__ = ["_Tenants", "_TenantsAsync", "TenantInputType", "TenantOutputType"] +__all__ = [ + "_Tenants", + "_TenantsAsync", + "TenantCreateInputType", + "TenantInputType", + "TenantOutputType", + "TenantUpdateInputType", +] diff --git a/weaviate/collections/tenants/sync.pyi b/weaviate/collections/tenants/sync.pyi index 53a691568..8afe10e12 100644 --- a/weaviate/collections/tenants/sync.pyi +++ b/weaviate/collections/tenants/sync.pyi @@ -1,14 +1,23 @@ from typing import Dict, List, Optional, Sequence, Union -from weaviate.collections.classes.tenants import Tenant, TenantInput -from weaviate.collections.tenants.tenants import _TenantsBase, TenantInputType, TenantOutputType +from weaviate.collections.classes.tenants import Tenant +from weaviate.collections.tenants.tenants import ( + _TenantsBase, + TenantCreateInputType, + TenantOutputType, + TenantUpdateInputType, +) class _Tenants(_TenantsBase): - def create(self, tenants: Union[TenantInputType, Sequence[TenantInputType]]) -> None: ... 
- def remove(self, tenants: Union[TenantInputType, Sequence[TenantInputType]]) -> None: ... + def create( + self, tenants: Union[TenantCreateInputType, Sequence[TenantCreateInputType]] + ) -> None: ... + def remove(self, tenants: Union[str, Tenant, Sequence[Union[str, Tenant]]]) -> None: ... def get(self) -> Dict[str, TenantOutputType]: ... - def get_by_names(self, tenants: Sequence[TenantInputType]) -> Dict[str, TenantOutputType]: ... - def get_by_name(self, tenant: TenantInputType) -> Optional[TenantOutputType]: ... + def get_by_names( + self, tenants: Sequence[Union[str, Tenant]] + ) -> Dict[str, TenantOutputType]: ... + def get_by_name(self, tenant: Union[str, Tenant]) -> Optional[TenantOutputType]: ... def update( - self, tenants: Union[Tenant, TenantInput, Sequence[Union[Tenant, TenantInput]]] + self, tenants: Union[TenantUpdateInputType, Sequence[TenantUpdateInputType]] ) -> None: ... - def exists(self, tenant: TenantInputType) -> bool: ... + def exists(self, tenant: Union[str, Tenant]) -> bool: ... diff --git a/weaviate/collections/tenants/tenants.py b/weaviate/collections/tenants/tenants.py index 21eab7383..b558d38a0 100644 --- a/weaviate/collections/tenants/tenants.py +++ b/weaviate/collections/tenants/tenants.py @@ -2,9 +2,11 @@ from weaviate.collections.classes.tenants import ( Tenant, - TenantInput, + TenantCreate, + TenantUpdate, TenantActivityStatus, - TenantActivityStatusInput, + TenantCreateActivityStatus, + TenantUpdateActivityStatus, ) from weaviate.collections.classes.config import ConsistencyLevel from weaviate.collections.grpc.tenants import _TenantsGRPC @@ -14,7 +16,9 @@ from weaviate.connect.v4 import _ExpectedStatusCodes -TenantInputType = Union[str, Tenant, TenantInput] + +TenantCreateInputType = Union[str, Tenant, TenantCreate] +TenantUpdateInputType = Union[str, Tenant, TenantUpdate] TenantOutputType = Tenant @@ -44,14 +48,16 @@ class _TenantsAsync(_TenantsBase): the `collection.tenants` class attribute. 
""" - async def create(self, tenants: Union[TenantInputType, Sequence[TenantInputType]]) -> None: + async def create( + self, tenants: Union[TenantCreateInputType, Sequence[TenantCreateInputType]] + ) -> None: """Create the specified tenants for a collection in Weaviate. The collection must have been created with multi-tenancy enabled. Arguments: `tenants` - A tenant name, `wvc.config.tenants.Tenant` object, or a list of tenants names + A tenant name, `wvc.config.tenants.Tenant`, `wvc.config.tenants.TenantCreateInput` object, or a list of tenants names and/or `wvc.config.tenants.Tenant` objects to add to the given collection. If a string is provided, the tenant will be added with the default activity status of `HOT`. @@ -70,8 +76,8 @@ async def create(self, tenants: Union[TenantInputType, Sequence[TenantInputType] expected=[ str, Tenant, - TenantInput, - Sequence[Union[str, Tenant, TenantInput]], + TenantCreate, + Sequence[Union[str, Tenant, TenantCreate]], ], name="tenants", value=tenants, @@ -82,14 +88,14 @@ async def create(self, tenants: Union[TenantInputType, Sequence[TenantInputType] path = "/schema/" + self._name + "/tenants" await self._connection.post( path=path, - weaviate_object=self.__map_input_tenants(tenants), + weaviate_object=self.__map_create_tenants(tenants), error_msg=f"Collection tenants may not have been added properly for {self._name}", status_codes=_ExpectedStatusCodes( ok_in=200, error=f"Add collection tenants for {self._name}" ), ) - async def remove(self, tenants: Union[TenantInputType, Sequence[TenantInputType]]) -> None: + async def remove(self, tenants: Union[str, Tenant, Sequence[Union[str, Tenant]]]) -> None: """Remove the specified tenants from a collection in Weaviate. The collection must have been created with multi-tenancy enabled. 
@@ -114,8 +120,7 @@ async def remove(self, tenants: Union[TenantInputType, Sequence[TenantInputType] expected=[ str, Tenant, - TenantInput, - Sequence[Union[str, Tenant, TenantInput]], + Sequence[Union[str, Tenant]], ], name="tenants", value=tenants, @@ -123,10 +128,17 @@ async def remove(self, tenants: Union[TenantInputType, Sequence[TenantInputType] ] ) + tenant_names: List[str] = [] + if isinstance(tenants, str) or isinstance(tenants, Tenant): + tenant_names = [tenants.name if isinstance(tenants, Tenant) else tenants] + else: + for tenant in tenants: + tenant_names.append(tenant.name if isinstance(tenant, Tenant) else tenant) + path = "/schema/" + self._name + "/tenants" await self._connection.delete( path=path, - weaviate_object=self.__map_input_tenant_names(tenants), + weaviate_object=tenant_names, error_msg=f"Collection tenants may not have been deleted for {self._name}", status_codes=_ExpectedStatusCodes( ok_in=200, error=f"Delete collection tenants for {self._name}" @@ -147,17 +159,14 @@ async def __get_with_rest(self) -> Dict[str, TenantOutputType]: return {tenant["name"]: Tenant(**tenant) for tenant in tenant_resp} async def __get_with_grpc( - self, tenants: Optional[Sequence[TenantInputType]] = None + self, tenants: Optional[Sequence[Union[str, Tenant]]] = None ) -> Dict[str, TenantOutputType]: response = await self._grpc.get( - names=[ - tenant.name - if isinstance(tenant, Tenant) or isinstance(tenant, TenantInput) - else tenant - for tenant in tenants - ] - if tenants is not None - else tenants + names=( + [tenant.name if isinstance(tenant, Tenant) else tenant for tenant in tenants] + if tenants is not None + else tenants + ) ) return { @@ -168,36 +177,60 @@ async def __get_with_grpc( for tenant in response.tenants } - def __map_input_tenant(self, tenant: TenantInputType) -> TenantInput: + def __map_create_tenant(self, tenant: TenantCreateInputType) -> TenantCreate: + if isinstance(tenant, str): + return TenantCreate(name=tenant) + if 
isinstance(tenant, Tenant): + if tenant.activity_status not in [ + TenantActivityStatus.HOT, + TenantActivityStatus.COLD, + ]: + raise WeaviateInvalidInputError( + f"Tenant activity status must be either 'HOT' or 'COLD'. Other statuses are read-only and cannot be set. Tenant: {tenant.name} had status: {tenant.activity_status}" + ) + activity_status = TenantCreateActivityStatus(tenant.activity_status) + return TenantCreate(name=tenant.name, activity_status=activity_status) + return tenant + + def __map_update_tenant(self, tenant: TenantUpdateInputType) -> TenantUpdate: if isinstance(tenant, str): - return TenantInput(name=tenant) + return TenantUpdate(name=tenant) if isinstance(tenant, Tenant): - if tenant.activity_status in [ - TenantActivityStatus.FREEZING, - TenantActivityStatus.UNFROZEN, - TenantActivityStatus.UNFREEZING, + if tenant.activity_status not in [ + TenantActivityStatus.HOT, + TenantActivityStatus.COLD, + TenantActivityStatus.FROZEN, ]: raise WeaviateInvalidInputError( f"Tenant activity status must be one of 'HOT', 'COLD' or 'FROZEN'. Other statuses are read-only and cannot be set. 
Tenant: {tenant.name} had status: {tenant.activity_status}" ) - activity_status = TenantActivityStatusInput(tenant.activity_status) - return TenantInput(name=tenant.name, activity_status=activity_status) - if isinstance(tenant, TenantInput): - return tenant + activity_status = TenantUpdateActivityStatus(tenant.activity_status) + return TenantUpdate(name=tenant.name, activity_status=activity_status) + return tenant - def __map_input_tenants( - self, tenant: Union[TenantInputType, Sequence[TenantInputType]] + def __map_create_tenants( + self, tenant: Union[str, Tenant, TenantCreate, Sequence[Union[str, Tenant, TenantCreate]]] ) -> List[dict]: - if isinstance(tenant, str) or isinstance(tenant, Tenant) or isinstance(tenant, TenantInput): - return [self.__map_input_tenant(tenant).model_dump()] - return [self.__map_input_tenant(t).model_dump() for t in tenant] + if ( + isinstance(tenant, str) + or isinstance(tenant, Tenant) + or isinstance(tenant, TenantCreate) + ): + return [self.__map_create_tenant(tenant).model_dump()] + else: + return [self.__map_create_tenant(t).model_dump() for t in tenant] - def __map_input_tenant_names( - self, tenant: Union[TenantInputType, Sequence[TenantInputType]] - ) -> List[str]: - if isinstance(tenant, str) or isinstance(tenant, Tenant) or isinstance(tenant, TenantInput): - return [self.__map_input_tenant(tenant).name] - return [self.__map_input_tenant(t).name for t in tenant] + def __map_update_tenants( + self, tenant: Union[str, Tenant, TenantUpdate, Sequence[Union[str, Tenant, TenantUpdate]]] + ) -> List[dict]: + if ( + isinstance(tenant, str) + or isinstance(tenant, Tenant) + or isinstance(tenant, TenantUpdate) + ): + return [self.__map_update_tenant(tenant).model_dump()] + else: + return [self.__map_update_tenant(t).model_dump() for t in tenant] async def get(self) -> Dict[str, TenantOutputType]: """Return all tenants currently associated with a collection in Weaviate. 
@@ -215,7 +248,9 @@ async def get(self) -> Dict[str, TenantOutputType]: else: return await self.__get_with_rest() - async def get_by_names(self, tenants: Sequence[TenantInputType]) -> Dict[str, TenantOutputType]: + async def get_by_names( + self, tenants: Sequence[Union[str, Tenant]] + ) -> Dict[str, TenantOutputType]: """Return named tenants currently associated with a collection in Weaviate. If the tenant does not exist, it will not be included in the response. @@ -236,14 +271,14 @@ async def get_by_names(self, tenants: Sequence[TenantInputType]) -> Dict[str, Te if self._validate_arguments: _validate_input( _ValidateArgument( - expected=[Sequence[Union[str, Tenant, TenantInput]]], + expected=[Sequence[Union[str, Tenant]]], name="names", value=tenants, ) ) return await self.__get_with_grpc(tenants=tenants) - async def get_by_name(self, tenant: TenantInputType) -> Optional[TenantOutputType]: + async def get_by_name(self, tenant: Union[str, Tenant]) -> Optional[TenantOutputType]: """Return a specific tenant associated with a collection in Weaviate. If the tenant does not exist, `None` will be returned. 
@@ -263,16 +298,10 @@ async def get_by_name(self, tenant: TenantInputType) -> Optional[TenantOutputTyp self._connection._weaviate_version.check_is_at_least_1_25_0("The 'get_by_name' method") if self._validate_arguments: _validate_input( - _ValidateArgument( - expected=[Union[str, Tenant, TenantInput]], name="tenant", value=tenant - ) + _ValidateArgument(expected=[Union[str, Tenant]], name="tenant", value=tenant) ) response = await self._grpc.get( - names=[ - tenant.name - if isinstance(tenant, Tenant) or isinstance(tenant, TenantInput) - else tenant - ] + names=[tenant.name if isinstance(tenant, Tenant) else tenant] ) if len(response.tenants) == 0: return None @@ -282,7 +311,7 @@ async def get_by_name(self, tenant: TenantInputType) -> Optional[TenantOutputTyp ) async def update( - self, tenants: Union[Tenant, TenantInput, Sequence[Union[Tenant, TenantInput]]] + self, tenants: Union[Tenant, TenantUpdate, Sequence[Union[Tenant, TenantUpdate]]] ) -> None: """Update the specified tenants for a collection in Weaviate. @@ -304,7 +333,7 @@ async def update( if self._validate_arguments: _validate_input( _ValidateArgument( - expected=[Tenant, TenantInput, Sequence[Union[Tenant, TenantInput]]], + expected=[Tenant, TenantUpdate, Sequence[Union[Tenant, TenantUpdate]]], name="tenants", value=tenants, ) @@ -313,14 +342,14 @@ async def update( path = "/schema/" + self._name + "/tenants" await self._connection.put( path=path, - weaviate_object=self.__map_input_tenants(tenants), + weaviate_object=self.__map_update_tenants(tenants), error_msg=f"Collection tenants may not have been updated properly for {self._name}", status_codes=_ExpectedStatusCodes( ok_in=200, error=f"Update collection tenants for {self._name}" ), ) - async def exists(self, tenant: TenantInputType) -> bool: + async def exists(self, tenant: Union[str, Tenant]) -> bool: """Check if a tenant exists for a collection in Weaviate. The collection must have been created with multi-tenancy enabled. 
@@ -343,16 +372,13 @@ async def exists(self, tenant: TenantInputType) -> bool: if self._validate_arguments: _validate_input( _ValidateArgument( - expected=[str, Tenant, TenantInput, Sequence[Union[str, Tenant, TenantInput]]], + expected=[str, Tenant, Sequence[Union[str, Tenant]]], name="tenant", value=tenant, ) ) - tenant_name = ( - tenant.name if isinstance(tenant, Tenant) or isinstance(tenant, TenantInput) else tenant - ) - + tenant_name = tenant.name if isinstance(tenant, Tenant) else tenant path = "/schema/" + self._name + "/tenants/" + tenant_name response = await self._connection.head( path=path, diff --git a/weaviate/embedded.py b/weaviate/embedded.py index 130373705..d737ec182 100644 --- a/weaviate/embedded.py +++ b/weaviate/embedded.py @@ -20,6 +20,7 @@ from weaviate import exceptions from weaviate.exceptions import WeaviateStartUpError +from weaviate.logger import logger from weaviate.util import _decode_json_response_dict DEFAULT_BINARY_PATH = str(Path.home() / ".cache/weaviate-embedded/") @@ -136,7 +137,7 @@ def ensure_weaviate_binary_exists(self) -> None: + str(hashlib.sha256(self.options.version.encode("utf-8")).hexdigest()), ) if not self._weaviate_binary_path.exists(): - print( + logger.info( f"Binary {self.options.binary_path} did not exist. Downloading binary from {self._download_url}" ) if self._download_url.endswith(".tar.gz"): @@ -185,7 +186,7 @@ def stop(self) -> None: self.process.terminate() self.process.wait() except ProcessLookupError: - print( + logger.info( f"""Tried to stop embedded weaviate process {self.process.pid}. Process was not found. 
So not doing anything""" ) @@ -193,7 +194,7 @@ def stop(self) -> None: def ensure_running(self) -> None: if self.is_listening() is False: - print( + logger.info( f"Embedded weaviate wasn't listening on ports http:{self.options.port} & grpc:{self.options.grpc_port}, so starting embedded weaviate again" ) self.start() @@ -243,7 +244,7 @@ def start(self) -> None: env=my_env, ) self.process = process - print(f"Started {self.options.binary_path}: process ID {self.process.pid}") + logger.info(f"Started {self.options.binary_path}: process ID {self.process.pid}") self.wait_till_listening() @abstractmethod @@ -264,7 +265,7 @@ def is_listening(self) -> bool: def start(self) -> None: if self.is_listening(): - print(f"embedded weaviate is already listening on port {self.options.port}") + logger.info(f"embedded weaviate is already listening on port {self.options.port}") return super().start() diff --git a/weaviate/gql/get.py b/weaviate/gql/get.py index 47060ef85..368d2cbde 100644 --- a/weaviate/gql/get.py +++ b/weaviate/gql/get.py @@ -7,6 +7,8 @@ from json import dumps from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +import grpc # type: ignore + from weaviate import util from weaviate.connect import Connection from weaviate.data.replication import ConsistencyLevel @@ -28,18 +30,16 @@ MediaType, Sort, ) +from weaviate.proto.v1 import search_get_pb2 +from weaviate.str_enum import BaseEnum +from weaviate.types import UUID from weaviate.util import ( image_encoder_b64, _capitalize_first_letter, get_valid_uuid, file_encoder_b64, - BaseEnum, ) from weaviate.warnings import _Warnings -from weaviate.types import UUID - -from weaviate.proto.v1 import search_get_pb2 -import grpc # type: ignore @dataclass diff --git a/weaviate/logger.py b/weaviate/logger.py new file mode 100644 index 000000000..755ea6f38 --- /dev/null +++ b/weaviate/logger.py @@ -0,0 +1,5 @@ +import os +from logging import getLogger + +logger = getLogger("weaviate-client") 
+logger.setLevel(os.getenv("WEAVIATE_LOG_LEVEL", "INFO")) diff --git a/weaviate/str_enum.py b/weaviate/str_enum.py new file mode 100644 index 000000000..78e7f6c06 --- /dev/null +++ b/weaviate/str_enum.py @@ -0,0 +1,19 @@ +# MetaEnum and BaseEnum are required to support `in` statements: +# 'ALL' in ConsistencyLevel == True +# 12345 in ConsistencyLevel == False +from enum import EnumMeta, Enum +from typing import Any + + +class MetaEnum(EnumMeta): + def __contains__(cls, item: Any) -> bool: + try: + # when item is type ConsistencyLevel + return item.name in cls.__members__.keys() + except AttributeError: + # when item is type str + return item in cls.__members__.keys() + + +class BaseEnum(Enum, metaclass=MetaEnum): + pass diff --git a/weaviate/util.py b/weaviate/util.py index 61df9a9e4..80ce8e6cb 100644 --- a/weaviate/util.py +++ b/weaviate/util.py @@ -9,7 +9,6 @@ import os import re import uuid as uuid_lib -from enum import Enum, EnumMeta from pathlib import Path from typing import Union, Sequence, Any, Optional, List, Dict, Generator, Tuple, cast @@ -26,6 +25,7 @@ WeaviateUnsupportedFeatureError, ) from weaviate.types import NUMBER, UUIDS, TIME +from weaviate.validator import _is_valid, _ExtraTypes from weaviate.warnings import _Warnings PYPI_PACKAGE_URL = "https://pypi.org/pypi/weaviate-client/json" @@ -36,23 +36,6 @@ BYTES_PER_CHUNK = 65535 # The number of bytes to read per chunk when encoding files ~ 64kb -# MetaEnum and BaseEnum are required to support `in` statements: -# 'ALL' in ConsistencyLevel == True -# 12345 in ConsistencyLevel == False -class MetaEnum(EnumMeta): - def __contains__(cls, item: Any) -> bool: - try: - # when item is type ConsistencyLevel - return item.name in cls.__members__.keys() - except AttributeError: - # when item is type str - return item in cls.__members__.keys() - - -class BaseEnum(Enum, metaclass=MetaEnum): - pass - - def image_encoder_b64(image_or_image_path: Union[str, io.BufferedReader]) -> str: """ Encode a image in a Weaviate 
understandable format from a binary read file or by providing @@ -461,7 +444,7 @@ def get_vector(vector: Sequence) -> List[float]: ) from None -def _get_vector_v4(vector: Sequence) -> List[float]: +def _get_vector_v4(vector: Any) -> List[float]: try: return get_vector(vector) except TypeError as e: @@ -703,7 +686,9 @@ def _sanitize_str(value: str) -> str: The sanitized string. """ value = strip_newlines(value) - value = re.sub(r'(? datetime.datetime: "".join(string.rsplit(":", 1) if string[-1] != "Z" else string), "%Y-%m-%dT%H:%M:%S%z", ) + + +def __is_list_type(inputs: Any) -> bool: + try: + if len(inputs) == 0: + return False + except TypeError: + return False + + return any( + _is_valid(types, inputs) + for types in [ + List, + _ExtraTypes.TF, + _ExtraTypes.PANDAS, + _ExtraTypes.NUMPY, + _ExtraTypes.POLARS, + ] + ) + + +def _is_1d_vector(inputs: Any) -> bool: + try: + if len(inputs) == 0: + return False + except TypeError: + return False + if __is_list_type(inputs): + return not __is_list_type(inputs[0]) # 2D vectors are not 1D vectors + return False diff --git a/weaviate/validator.py b/weaviate/validator.py index 1b8b44810..7fe11945c 100644 --- a/weaviate/validator.py +++ b/weaviate/validator.py @@ -2,6 +2,7 @@ from typing import Any, List, Sequence, Union, get_args, get_origin from weaviate.exceptions import WeaviateInvalidInputError +from weaviate.str_enum import BaseEnum @dataclass @@ -11,6 +12,13 @@ class _ValidateArgument: value: Any +class _ExtraTypes(str, BaseEnum): + NUMPY = "numpy" + PANDAS = "pandas" + POLARS = "polars" + TF = "tensorflow" + + def _validate_input(inputs: Union[List[_ValidateArgument], _ValidateArgument]) -> None: """Validate the values of the input arguments in comparison to the expected types defined in _ValidateArgument. 
@@ -20,15 +28,21 @@ def _validate_input(inputs: Union[List[_ValidateArgument], _ValidateArgument]) - if isinstance(inputs, _ValidateArgument): inputs = [inputs] for validate in inputs: - if not any(__is_valid(exp, validate.value) for exp in validate.expected): + if not any(_is_valid(exp, validate.value) for exp in validate.expected): raise WeaviateInvalidInputError( f"Argument '{validate.name}' must be one of: {validate.expected}, but got {type(validate.value)}" ) -def __is_valid(expected: Any, value: Any) -> bool: +def _is_valid(expected: Any, value: Any) -> bool: if expected is None: return value is None + + # check for types that are not installed + # https://stackoverflow.com/questions/12569452/how-to-identify-numpy-types-in-python + if isinstance(expected, _ExtraTypes): + return expected.value in type(value).__module__ + expected_origin = get_origin(expected) if expected_origin is Union: args = get_args(expected)