Skip to content

Commit

Permalink
Merge pull request #1425 from weaviate/wes_suport
Browse files Browse the repository at this point in the history
Weaviate embedding service support
  • Loading branch information
tsmith023 authored Nov 29, 2024
2 parents 7f36591 + 75ffad5 commit 6339fcb
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 56 deletions.
39 changes: 39 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,22 @@ def test_basic_config():
}
},
),
(
Configure.Vectorizer.text2vec_weaviate(
vectorize_collection_name=False,
model="Snowflake/snowflake-arctic-embed-m-v1.5",
base_url="https://api.embedding.weaviate.io",
dimensions=768,
),
{
"text2vec-weaviate": {
"vectorizeClassName": False,
"model": "Snowflake/snowflake-arctic-embed-m-v1.5",
"baseURL": "https://api.embedding.weaviate.io",
"dimensions": 768,
}
},
),
(
Configure.Vectorizer.img2vec_neural(
image_fields=["test"],
Expand Down Expand Up @@ -1495,6 +1511,29 @@ def test_vector_config_flat_pq() -> None:
}
},
),
(
[
Configure.NamedVectors.text2vec_weaviate(
name="test",
source_properties=["prop"],
base_url="https://api.embedding.weaviate.io",
dimensions=768,
)
],
{
"test": {
"vectorizer": {
"text2vec-weaviate": {
"properties": ["prop"],
"vectorizeClassName": True,
"baseURL": "https://api.embedding.weaviate.io",
"dimensions": 768,
}
},
"vectorIndexType": "hnsw",
}
},
),
(
[
Configure.NamedVectors.img2vec_neural(
Expand Down
4 changes: 1 addition & 3 deletions weaviate/collections/batch/grpc_batch_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from grpc.aio import AioRpcError # type: ignore


from weaviate.collections.classes.batch import (
DeleteManyObject,
DeleteManyReturn,
Expand All @@ -26,7 +25,6 @@ def __init__(self, connection: ConnectionV4, consistency_level: Optional[Consist
async def batch_delete(
self, name: str, filters: _Filters, verbose: bool, dry_run: bool, tenant: Optional[str]
) -> Union[DeleteManyReturn[List[DeleteManyObject]], DeleteManyReturn[None]]:
metadata = self._get_metadata()
try:
assert self._connection.grpc_stub is not None
res = await self._connection.grpc_stub.BatchDelete(
Expand All @@ -38,7 +36,7 @@ async def batch_delete(
tenant=tenant,
filters=_FilterToGRPC.convert(filters),
),
metadata=metadata,
metadata=self._connection.grpc_headers(),
timeout=self._connection.timeout_config.insert,
)
res = cast(batch_delete_pb2.BatchDeleteReply, res)
Expand Down
7 changes: 3 additions & 4 deletions weaviate/collections/batch/grpc_batch_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
import uuid as uuid_package
from typing import Any, Dict, List, Optional, Union, cast

from grpc.aio import AioRpcError # type: ignore
from google.protobuf.struct_pb2 import Struct
from grpc.aio import AioRpcError # type: ignore

from weaviate.collections.classes.batch import (
ErrorObject,
_BatchObject,
BatchObjectReturn,
)
from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.collections.classes.types import GeoCoordinate, PhoneNumber
from weaviate.collections.classes.internal import ReferenceToMulti, ReferenceInputs
from weaviate.collections.classes.types import GeoCoordinate, PhoneNumber
from weaviate.collections.grpc.shared import _BaseGRPC
from weaviate.connect import ConnectionV4
from weaviate.exceptions import (
Expand Down Expand Up @@ -135,15 +135,14 @@ async def objects(
async def __send_batch(
self, batch: List[batch_pb2.BatchObject], timeout: Union[int, float]
) -> Dict[int, str]:
metadata = self._get_metadata()
try:
assert self._connection.grpc_stub is not None
res = await self._connection.grpc_stub.BatchObjects(
batch_pb2.BatchObjectsRequest(
objects=batch,
consistency_level=self._consistency_level,
),
metadata=metadata,
metadata=self._connection.grpc_headers(),
timeout=timeout,
)
res = cast(batch_pb2.BatchObjectsReply, res)
Expand Down
25 changes: 25 additions & 0 deletions weaviate/collections/classes/config_named_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
_Text2VecVoyageConfig,
_Multi2VecCohereConfig,
_Multi2VecJinaConfig,
_Text2VecWeaviateConfig,
WeaviateModel,
)
from ...warnings import _Warnings

Expand Down Expand Up @@ -1221,6 +1223,29 @@ def text2vec_voyageai(
vector_index_config=vector_index_config,
)

@staticmethod
def text2vec_weaviate(
    name: str,
    *,
    source_properties: Optional[List[str]] = None,
    vector_index_config: Optional[_VectorIndexConfigCreate] = None,
    vectorize_collection_name: bool = True,
    model: Optional[Union[WeaviateModel, str]] = None,
    base_url: Optional[str] = None,
    dimensions: Optional[int] = None,
) -> _NamedVectorConfigCreate:
    """Create a named vector configuration backed by the `text2vec-weaviate` vectorizer.

    Arguments:
        `name`
            The name of the named vector.
        `source_properties`
            Forwarded unchanged to the named vector config as `source_properties`.
        `vector_index_config`
            Forwarded unchanged to the named vector config as `vector_index_config`.
        `vectorize_collection_name`
            Stored on the vectorizer config as `vectorizeClassName`. Defaults to `True`.
        `model`
            Stored on the vectorizer config as `model`. Defaults to `None`.
        `base_url`
            Stored on the vectorizer config as `baseURL`. Defaults to `None`.
        `dimensions`
            Stored on the vectorizer config as `dimensions`. Defaults to `None`.

    Returns:
        A `_NamedVectorConfigCreate` wrapping a `_Text2VecWeaviateConfig`.
    """
    # Build the module-specific settings first so the wrapper call stays flat.
    vectorizer_config = _Text2VecWeaviateConfig(
        model=model,
        baseURL=base_url,
        vectorizeClassName=vectorize_collection_name,
        dimensions=dimensions,
    )
    return _NamedVectorConfigCreate(
        name=name,
        source_properties=source_properties,
        vectorizer=vectorizer_config,
        vector_index_config=vector_index_config,
    )


class _NamedVectorsUpdate:
@staticmethod
Expand Down
30 changes: 30 additions & 0 deletions weaviate/collections/classes/config_vectorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"bedrock",
"sagemaker",
]
WeaviateModel: TypeAlias = Literal["Snowflake/snowflake-arctic-embed-m-v1.5"]


class Vectorizers(str, Enum):
Expand Down Expand Up @@ -95,6 +96,8 @@ class Vectorizers(str, Enum):
Weaviate module backed by Jina AI text-based embedding models.
`TEXT2VEC_VOYAGEAI`
Weaviate module backed by Voyage AI text-based embedding models.
`TEXT2VEC_WEAVIATE`
Weaviate module backed by Weaviate's self-hosted text-based embedding models.
`IMG2VEC_NEURAL`
Weaviate module backed by a ResNet-50 neural network for images.
`MULTI2VEC_CLIP`
Expand All @@ -121,6 +124,7 @@ class Vectorizers(str, Enum):
TEXT2VEC_TRANSFORMERS = "text2vec-transformers"
TEXT2VEC_JINAAI = "text2vec-jinaai"
TEXT2VEC_VOYAGEAI = "text2vec-voyageai"
TEXT2VEC_WEAVIATE = "text2vec-weaviate"
IMG2VEC_NEURAL = "img2vec-neural"
MULTI2VEC_CLIP = "multi2vec-clip"
MULTI2VEC_COHERE = "multi2vec-cohere"
Expand Down Expand Up @@ -343,6 +347,16 @@ class _Text2VecVoyageConfig(_VectorizerConfigCreate):
vectorizeClassName: bool


class _Text2VecWeaviateConfig(_VectorizerConfigCreate):
    # Settings for the `text2vec-weaviate` vectorizer module. The field names are
    # the camelCase keys that the factory methods populate from their snake_case
    # arguments (`base_url` -> `baseURL`, `vectorize_collection_name` -> `vectorizeClassName`).

    # Defaults to TEXT2VEC_WEAVIATE; declared frozen and excluded from the dumped fields.
    vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
        default=Vectorizers.TEXT2VEC_WEAVIATE,
        frozen=True,
        exclude=True,
    )
    # Required fields (no defaults): callers must pass each one explicitly, even if None.
    model: Optional[str]
    baseURL: Optional[str]
    vectorizeClassName: bool
    dimensions: Optional[int]


class _Text2VecOllamaConfig(_VectorizerConfigCreate):
vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
default=Vectorizers.TEXT2VEC_OLLAMA, frozen=True, exclude=True
Expand Down Expand Up @@ -1290,3 +1304,19 @@ def text2vec_voyageai(
truncate=truncate,
vectorizeClassName=vectorize_collection_name,
)

@staticmethod
def text2vec_weaviate(
    *,
    model: Optional[Union[WeaviateModel, str]] = None,
    base_url: Optional[str] = None,
    vectorize_collection_name: bool = True,
    dimensions: Optional[int] = None,
) -> _VectorizerConfigCreate:
    """Create a `_Text2VecWeaviateConfig` object for use with the `text2vec-weaviate` module.

    All arguments are keyword-only and are passed straight through to the config object.

    Arguments:
        `model`
            Stored on the config as `model`. Defaults to `None`.
        `base_url`
            Stored on the config as `baseURL`. Defaults to `None`.
        `vectorize_collection_name`
            Stored on the config as `vectorizeClassName`. Defaults to `True`.
        `dimensions`
            Stored on the config as `dimensions`. Defaults to `None`.

    Returns:
        The populated `_Text2VecWeaviateConfig`.
    """
    config = _Text2VecWeaviateConfig(
        model=model,
        baseURL=base_url,
        vectorizeClassName=vectorize_collection_name,
        dimensions=dimensions,
    )
    return config
20 changes: 1 addition & 19 deletions weaviate/collections/grpc/shared.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, List
from typing import Optional

from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.connect import ConnectionV4
Expand All @@ -14,24 +14,6 @@ def __init__(
self._connection = connection
self._consistency_level = self._get_consistency_level(consistency_level)

def _get_metadata(self) -> Optional[Tuple[Tuple[str, str], ...]]:
    """Assemble the metadata tuple from the connection's token and additional headers.

    The current bearer token, when non-empty, is emitted first under the
    ``authorization`` key. Every additional header whose value is not ``None``
    follows, with its key lower-cased.

    Returns:
        A tuple of ``(key, value)`` pairs, or ``None`` when there is nothing to send.
    """
    entries: List[Tuple[str, str]] = []

    access_token = self._connection.get_current_bearer_token()
    if access_token:
        entries.append(("authorization", access_token))

    # Headers explicitly set to None are skipped rather than sent as empty values.
    entries.extend(
        (key.lower(), val)
        for key, val in self._connection.additional_headers.items()
        if val is not None
    )

    # Preserve the original contract: None (not an empty tuple) when nothing was collected.
    return tuple(entries) if entries else None

@staticmethod
def _get_consistency_level(
consistency_level: Optional[ConsistencyLevel],
Expand Down
49 changes: 19 additions & 30 deletions weaviate/connect/v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
self.__loop = loop

self._headers = {"content-type": "application/json"}
self.__add_weaviate_embedding_service_header(connection_params.http.host)
if additional_headers is not None:
_validate_input(_ValidateArgument([dict], "additional_headers", additional_headers))
self.__additional_headers = additional_headers
Expand All @@ -141,6 +142,12 @@ def __init__(

self._prepare_grpc_headers()

def __add_weaviate_embedding_service_header(self, wcd_host: str) -> None:
    """Add the `X-Weaviate-Api-Key` and `X-Weaviate-Cluster-URL` entries to `self._headers`.

    The headers are only set when `wcd_host` passes `is_weaviate_domain` and the
    configured auth is an `AuthApiKey`; otherwise `self._headers` is left untouched.

    Arguments:
        `wcd_host`
            The HTTP host of the connection; it is prefixed with `https://` to form the cluster URL.
    """
    if is_weaviate_domain(wcd_host) and isinstance(self._auth, AuthApiKey):
        headers = self._headers
        headers["X-Weaviate-Api-Key"] = self._auth.api_key
        headers["X-Weaviate-Cluster-URL"] = "https://" + wcd_host

async def connect(self, skip_init_checks: bool) -> None:
self.__connected = True

Expand Down Expand Up @@ -655,7 +662,17 @@ def _prepare_grpc_headers(self) -> None:

if self._auth is not None:
if isinstance(self._auth, AuthApiKey):
self.__metadata_list.append(("authorization", self._auth.api_key))
if (
"X-Weaviate-Cluster-URL" in self._headers
and "X-Weaviate-Api-Key" in self._headers
):
self.__metadata_list.append(
("x-weaviate-cluster-url", self._headers["X-Weaviate-Cluster-URL"])
)
self.__metadata_list.append(
("x-weaviate-api-key", self._headers["X-Weaviate-Api-Key"])
)
self.__metadata_list.append(("authorization", "Bearer " + self._auth.api_key))
else:
self.__metadata_list.append(
("authorization", "dummy_will_be_refreshed_for_each_call")
Expand All @@ -667,7 +684,7 @@ def _prepare_grpc_headers(self) -> None:
self.__grpc_headers = None

def grpc_headers(self) -> Optional[Tuple[Tuple[str, str], ...]]:
if self._auth is None or not isinstance(self._auth, AuthApiKey):
if self._auth is None or isinstance(self._auth, AuthApiKey):
return self.__grpc_headers

assert self.__grpc_headers is not None
Expand All @@ -676,34 +693,6 @@ def grpc_headers(self) -> Optional[Tuple[Tuple[str, str], ...]]:
self.__metadata_list[len(self.__metadata_list) - 1] = ("authorization", access_token)
return tuple(self.__metadata_list)

# async def _ping_grpc(self) -> None:
# """Performs a grpc health check and raises WeaviateGRPCUnavailableError if not."""
# if not self.is_connected():
# raise WeaviateClosedClientError()
# assert self._grpc_channel is not None
# try:
# request = self._grpc_channel.request(
# "/grpc.health.v1.Health/Check",
# Cardinality.UNARY_UNARY,
# health_pb2.HealthCheckRequest,
# health_pb2.HealthCheckResponse,
# timeout=self.timeout_config.init,
# )
# async with request as stream:
# await stream.send_message(health_pb2.HealthCheckRequest())
# res = await stream.recv_message()
# await stream.end()
# if res is None or res.status != health_pb2.HealthCheckResponse.SERVING:
# self.__connected = False
# raise WeaviateGRPCUnavailableError(
# f"v{self.server_version}", self._connection_params._grpc_address
# )
# except Exception as e:
# self.__connected = False
# raise WeaviateGRPCUnavailableError(
# f"v{self.server_version}", self._connection_params._grpc_address
# ) from e

async def _ping_grpc(self) -> None:
"""Performs a grpc health check and raises WeaviateGRPCUnavailableError if not."""
if not self.is_connected():
Expand Down

0 comments on commit 6339fcb

Please sign in to comment.