Skip to content

Commit

Permalink
Merge pull request #854 from weaviate/skip_parameter_validation
Browse files Browse the repository at this point in the history
Add option to skip input parameter validation
  • Loading branch information
dirkkul authored Feb 5, 2024
2 parents 4c80f68 + 9fac31e commit 12a2d54
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 162 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ jobs:
- run: |
pip install -r requirements-devel.txt
pip install .
- name: free space
run: sudo rm -rf /usr/local/lib/android
- name: start weaviate
run: /bin/bash ci/start_weaviate.sh ${{ matrix.versions.weaviate }}
- name: Run integration tests with auth secrets
Expand Down Expand Up @@ -158,6 +160,8 @@ jobs:
- run: |
pip install -r requirements-devel.txt
pip install .
- name: free space
run: sudo rm -rf /usr/local/lib/android
- name: start weaviate
run: /bin/bash ci/start_weaviate.sh ${{ matrix.versions.weaviate }}
- name: Run integration tests with auth secrets
Expand Down Expand Up @@ -256,6 +260,8 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: free space
run: sudo rm -rf /usr/local/lib/android
- run: rm -r weaviate
- name: start weaviate
run: /bin/bash ci/start_weaviate.sh ${{ matrix.server }}
Expand Down
3 changes: 3 additions & 0 deletions integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __call__(
name: str,
data_model_props: Optional[Type[Properties]] = None,
data_model_refs: Optional[Type[Properties]] = None,
skip_argument_validation: bool = False,
) -> Collection[Any, Any]:
"""Typing for fixture."""
...
Expand All @@ -169,6 +170,7 @@ def _factory(
name: str,
data_model_props: Optional[Type[Properties]] = None,
data_model_refs: Optional[Type[Properties]] = None,
skip_argument_validation: bool = False,
) -> Collection[Any, Any]:
nonlocal client_fixture, name_fixture
name_fixture = _sanitize_collection_name(name)
Expand All @@ -178,6 +180,7 @@ def _factory(
name=name_fixture,
data_model_properties=data_model_props,
data_model_references=data_model_refs,
skip_argument_validation=skip_argument_validation,
)
return collection

Expand Down
16 changes: 15 additions & 1 deletion integration/test_collection_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel
from pydantic.dataclasses import dataclass as pydantic_dataclass

from integration.conftest import CollectionFactoryGet
from integration.conftest import CollectionFactoryGet, CollectionFactory
from weaviate.collections import Collection
from weaviate.collections.data import _Data
from weaviate.exceptions import InvalidDataModelException
Expand Down Expand Up @@ -80,3 +80,17 @@ def test_get_with_wrong_generics(
assert error.value.args[0] == WRONG_PROPERTIES_ERROR_MSG
else:
assert error.value.args[0] == WRONG_REFERENCES_ERROR_MSG


def test_get_with_skip_validation(
collection_factory_get: CollectionFactoryGet, collection_factory: CollectionFactory
) -> None:
collection_dummy = collection_factory()

collection = collection_factory_get(collection_dummy.name, skip_argument_validation=True)
with pytest.raises(AttributeError):
collection.data.insert(properties=[])
with pytest.raises(TypeError):
collection.query.bm25(query=5) # type: ignore
with pytest.raises(TypeError):
collection.query.near_vector(vector="test") # type: ignore
2 changes: 1 addition & 1 deletion test/collection/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _test_query(query: Callable) -> None:


def test_bad_query_inputs(connection: ConnectionV4) -> None:
query = _QueryCollection(connection, "dummy", None, None, None, None)
query = _QueryCollection(connection, "dummy", None, None, None, None, True)
# fetch_objects
_test_query(lambda: query.fetch_objects(limit="thing"))
_test_query(lambda: query.fetch_objects(offset="wrong"))
Expand Down
3 changes: 2 additions & 1 deletion weaviate/collections/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@


class _CollectionBase:
def __init__(self, connection: ConnectionV4, name: str) -> None:
def __init__(self, connection: ConnectionV4, name: str, validate_arguments: bool) -> None:
self._connection = connection
self.name = _capitalize_first_letter(name)
self.__cluster = _Cluster(connection)
self._validate_arguments = validate_arguments

def shards(self) -> List[Shard]:
"""
Expand Down
44 changes: 30 additions & 14 deletions weaviate/collections/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from weaviate.collections.iterator import _ObjectIterator
from weaviate.collections.query import _GenerateCollection, _QueryCollection
from weaviate.collections.tenants import _Tenants
from weaviate.validator import _validate_input, _ValidateArgument
from weaviate.connect import ConnectionV4
from weaviate.validator import _validate_input, _ValidateArgument


class Collection(_CollectionBase, Generic[Properties, References]):
Expand Down Expand Up @@ -60,12 +60,13 @@ def __init__(
self,
connection: ConnectionV4,
name: str,
validate_arguments: bool,
consistency_level: Optional[ConsistencyLevel] = None,
tenant: Optional[str] = None,
properties: Optional[Type[Properties]] = None,
references: Optional[Type[References]] = None,
) -> None:
super().__init__(connection, name)
super().__init__(connection, name, validate_arguments)

self.aggregate = _AggregateCollection(
self._connection, self.name, consistency_level, tenant
Expand All @@ -78,15 +79,27 @@ def __init__(
self.config = _ConfigCollection(self._connection, self.name, tenant)
"""This namespace includes all the CRUD methods available to you when modifying the configuration of the collection in Weaviate."""
self.data = _DataCollection[Properties](
connection, self.name, consistency_level, tenant, properties
connection, self.name, consistency_level, tenant, validate_arguments, properties
)
"""This namespace includes all the CUD methods available to you when modifying the data of the collection in Weaviate."""
self.generate = _GenerateCollection(
connection, self.name, consistency_level, tenant, properties, references
connection,
self.name,
consistency_level,
tenant,
properties,
references,
validate_arguments,
)
"""This namespace includes all the querying methods available to you when using Weaviate's generative capabilities."""
self.query = _QueryCollection[Properties, References](
connection, self.name, consistency_level, tenant, properties, references
connection,
self.name,
consistency_level,
tenant,
properties,
references,
validate_arguments,
)
"""This namespace includes all the querying methods available to you when using Weaviate's standard query capabilities."""
self.tenants = _Tenants(connection, self.name)
Expand Down Expand Up @@ -120,6 +133,7 @@ def with_tenant(
return Collection[Properties, References](
self._connection,
self.name,
self._validate_arguments,
self.__consistency_level,
tenant.name if isinstance(tenant, Tenant) else tenant,
self.__properties,
Expand All @@ -140,18 +154,20 @@ def with_consistency_level(
`consistency_level`
The consistency level to use.
"""
_validate_input(
[
_ValidateArgument(
expected=[ConsistencyLevel, None],
name="consistency_level",
value=consistency_level,
)
]
)
if self._validate_arguments:
_validate_input(
[
_ValidateArgument(
expected=[ConsistencyLevel, None],
name="consistency_level",
value=consistency_level,
)
]
)
return Collection[Properties, References](
self._connection,
self.name,
self._validate_arguments,
consistency_level,
self.__tenant,
self.__properties,
Expand Down
24 changes: 17 additions & 7 deletions weaviate/collections/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def create(
vectorizer_config: Optional[_VectorizerConfigCreate] = None,
data_model_properties: Optional[Type[Properties]] = None,
data_model_references: Optional[Type[References]] = None,
skip_argument_validation: bool = False,
) -> Collection[Properties, References]:
"""Use this method to create a collection in Weaviate and immediately return a collection object.
Expand Down Expand Up @@ -86,6 +87,8 @@ def create(
The generic class that you want to use to represent the properties of objects in this collection. See the `get` method for more information.
`data_model_references`
The generic class that you want to use to represent the references of objects in this collection. See the `get` method for more information.
`skip_argument_validation`
If arguments to functions such as near_vector should be validated. Disable this if you need to squeeze out some extra performance.
Raises:
`weaviate.WeaviateConnectionError`
Expand Down Expand Up @@ -113,13 +116,19 @@ def create(
assert (
config.name == name
), f"Name of created collection ({name}) does not match given name ({config.name})"
return self.get(name, data_model_properties, data_model_references)
return self.get(
name,
data_model_properties,
data_model_references,
skip_argument_validation=skip_argument_validation,
)

def get(
self,
name: str,
data_model_properties: Optional[Type[Properties]] = None,
data_model_references: Optional[Type[References]] = None,
skip_argument_validation: bool = False,
) -> Collection[Properties, References]:
"""Use this method to return a collection object to be used when interacting with your Weaviate collection.
Expand All @@ -136,22 +145,23 @@ def get(
The generic class that you want to use to represent the objects of references in this collection when mutating objects through the `.query` namespace.
The generic provided in this argument will propagate to the methods in `.query` and allow you to do `mypy` static type checking on your codebase.
If you do not provide a generic, the methods in `.query` will return properties of referenced objects as `Dict[str, Any]`.
`skip_argument_validation`
If arguments to functions such as near_vector should be validated. Disable this if you need to squeeze out some extra performance.
Raises:
`weaviate.exceptions.InvalidDataModelException`
If the data model is not a valid data model, i.e., it is not a `dict` nor a `TypedDict`.
"""
_validate_input(
[_ValidateArgument(expected=[str], name="name", value=name)],
)
_check_properties_generic(data_model_properties)
_check_references_generic(data_model_references)
if not skip_argument_validation:
_validate_input([_ValidateArgument(expected=[str], name="name", value=name)])
_check_properties_generic(data_model_properties)
_check_references_generic(data_model_references)
name = _capitalize_first_letter(name)
return Collection[Properties, References](
self._connection,
name,
properties=data_model_properties,
references=data_model_references,
validate_arguments=not skip_argument_validation,
)

def delete(self, name: Union[str, List[str]]) -> None:
Expand Down
Loading

0 comments on commit 12a2d54

Please sign in to comment.