Skip to content

Commit

Permalink
Embedded changes:
Browse files Browse the repository at this point in the history
- Add `use_async_with_embedded`
- Update embedded weaviate version to latest
- Fix embedded tests with v4 client
  • Loading branch information
tsmith023 committed Jul 23, 2024
1 parent 141ffbf commit f74ef6b
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 30 deletions.
86 changes: 69 additions & 17 deletions integration_embedded/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,44 @@

import weaviate
from weaviate.classes.init import AdditionalConfig, Timeout
from weaviate.embedded import WEAVIATE_VERSION
from weaviate.exceptions import WeaviateClosedClientError


@pytest.mark.parametrize("timeout", [(1, 2), Timeout(query=1, insert=2, init=2)])
def test_client_with_extra_options(timeout: Union[Tuple[int, int], Timeout]) -> None:
def test_client_with_extra_options(
    timeout: Union[Tuple[int, int], Timeout], tmp_path_factory: pytest.TempPathFactory
) -> None:
    """Connect to an embedded instance with a custom timeout and check the stored config.

    Both parametrised forms (the ``(1, 2)`` tuple and the explicit ``Timeout``)
    are expected to compare equal to ``Timeout(query=1, insert=2, init=2)`` on
    the connection.
    """
    config = AdditionalConfig(timeout=timeout, trust_env=True)
    embedded_client = weaviate.connect_to_embedded(
        port=8070,
        grpc_port=50040,
        additional_config=config,
        environment_variables={"DISABLE_TELEMETRY": "true"},
        # Fresh temporary directories are used for the data and binary paths.
        persistence_data_path=tmp_path_factory.mktemp("data"),
        binary_path=tmp_path_factory.mktemp("bin"),
    )
    try:
        stored_timeout = embedded_client._connection.timeout_config
        assert stored_timeout == Timeout(query=1, insert=2, init=2)
    finally:
        # Always close the client, even when the assertion fails.
        embedded_client.close()


def test_connect_and_close_to_embedded() -> None:
def test_connect_and_close_to_embedded(tmp_path_factory: pytest.TempPathFactory) -> None:
# Can't use the default port values as they are already in use by the local instances
client = weaviate.connect_to_embedded(
port=8078, grpc_port=50151, environment_variables={"DISABLE_TELEMETRY": "true"}
port=30668,
grpc_port=50151,
environment_variables={"DISABLE_TELEMETRY": "true"},
persistence_data_path=tmp_path_factory.mktemp("data"),
binary_path=tmp_path_factory.mktemp("bin"),
)
try:
assert client.is_connected()
metadata = client.get_meta()
assert "1.23.7" == metadata["version"]
assert WEAVIATE_VERSION == metadata["version"]
assert client.is_ready()
assert "8078" == metadata["hostname"].split(":")[2]
assert "30668" == metadata["hostname"].split(":")[2]
assert client.is_live()

client.close()
Expand All @@ -42,14 +51,17 @@ def test_connect_and_close_to_embedded() -> None:
client.close()


def test_embedded_as_context_manager() -> None:
default_version = "1.23.7"
def test_embedded_as_context_manager(tmp_path_factory: pytest.TempPathFactory) -> None:
with weaviate.connect_to_embedded(
port=8077, grpc_port=50152, environment_variables={"DISABLE_TELEMETRY": "true"}
port=30668,
grpc_port=50152,
environment_variables={"DISABLE_TELEMETRY": "true"},
persistence_data_path=tmp_path_factory.mktemp("data"),
binary_path=tmp_path_factory.mktemp("bin"),
) as client:
assert client.is_connected()
metadata = client.get_meta()
assert default_version == metadata["version"]
assert WEAVIATE_VERSION == metadata["version"]
assert client.is_ready()
assert client.is_live()

Expand All @@ -58,17 +70,45 @@ def test_embedded_as_context_manager() -> None:
client.get_meta()


def test_embedded_with_wrong_version() -> None:
@pytest.mark.asyncio
async def test_embedded_with_async_as_context_manager(
    tmp_path_factory: pytest.TempPathFactory,
) -> None:
    """Use the async embedded client as an ``async with`` context manager.

    Inside the context the client must be connected, report the default
    embedded version, and be ready and live. After the context exits the
    client must be disconnected and further calls must raise
    ``WeaviateClosedClientError``.
    """
    async with weaviate.use_async_with_embedded(
        port=8076,
        grpc_port=50153,
        environment_variables={"DISABLE_TELEMETRY": "true"},
        persistence_data_path=tmp_path_factory.mktemp("data"),
        binary_path=tmp_path_factory.mktemp("bin"),
    ) as client:
        assert client.is_connected()
        metadata = await client.get_meta()
        assert WEAVIATE_VERSION == metadata["version"]
        # The async client's readiness/liveness checks must be awaited (the
        # helper docstrings use `await client.is_ready()`). Asserting on the
        # un-awaited call only tests that a coroutine object is truthy, which
        # always passes and leaves the coroutine never awaited.
        assert await client.is_ready()
        assert await client.is_live()

    assert not client.is_connected()
    with pytest.raises(WeaviateClosedClientError):
        await client.get_meta()


def test_embedded_with_wrong_version(tmp_path_factory: pytest.TempPathFactory) -> None:
    """Requesting a non-existent embedded version must raise ``WeaviateEmbeddedInvalidVersionError``."""
    invalid_version_error = weaviate.exceptions.WeaviateEmbeddedInvalidVersionError
    with pytest.raises(invalid_version_error):
        weaviate.connect_to_embedded(
            version="this_version_does_not_exist",
            environment_variables={"DISABLE_TELEMETRY": "true"},
            persistence_data_path=tmp_path_factory.mktemp("data"),
            binary_path=tmp_path_factory.mktemp("bin"),
        )


def test_embedded_already_running() -> None:
def test_embedded_already_running(tmp_path_factory: pytest.TempPathFactory) -> None:
client = weaviate.connect_to_embedded(
port=8096, grpc_port=50155, environment_variables={"DISABLE_TELEMETRY": "true"}
port=8096,
grpc_port=50155,
environment_variables={"DISABLE_TELEMETRY": "true"},
persistence_data_path=tmp_path_factory.mktemp("data"),
binary_path=tmp_path_factory.mktemp("bin"),
)
try:
assert client._connection.embedded_db is not None
Expand All @@ -82,22 +122,34 @@ def test_embedded_already_running() -> None:
client.close()


def test_embedded_startup_with_blocked_http_port() -> None:
def test_embedded_startup_with_blocked_http_port(tmp_path_factory: pytest.TempPathFactory) -> None:
    """A second embedded instance on an already-used HTTP port must fail to start.

    The first client occupies HTTP port 8098. The second connection attempt
    reuses that HTTP port with a different gRPC port and is expected to raise
    ``WeaviateStartUpError``.
    """
    first_client = weaviate.connect_to_embedded(
        port=8098,
        grpc_port=50096,
        environment_variables={"DISABLE_TELEMETRY": "true"},
        persistence_data_path=tmp_path_factory.mktemp("data"),
        binary_path=tmp_path_factory.mktemp("bin"),
    )
    try:
        startup_error = weaviate.exceptions.WeaviateStartUpError
        with pytest.raises(startup_error):
            # Same HTTP port as the running instance; only the gRPC port differs.
            weaviate.connect_to_embedded(
                port=8098,
                grpc_port=50097,
                environment_variables={"DISABLE_TELEMETRY": "true"},
                persistence_data_path=tmp_path_factory.mktemp("data"),
                binary_path=tmp_path_factory.mktemp("bin"),
            )
    finally:
        # Close the first client regardless of the outcome.
        first_client.close()


def test_embedded_startup_with_blocked_grpc_port() -> None:
def test_embedded_startup_with_blocked_grpc_port(tmp_path_factory: pytest.TempPathFactory) -> None:
client = weaviate.connect_to_embedded(
port=8099, grpc_port=50150, environment_variables={"DISABLE_TELEMETRY": "true"}
port=8099,
grpc_port=50150,
environment_variables={"DISABLE_TELEMETRY": "true"},
persistence_data_path=tmp_path_factory.mktemp("data"),
binary_path=tmp_path_factory.mktemp("bin"),
)
try:
with pytest.raises(weaviate.exceptions.WeaviateStartUpError):
Expand Down
2 changes: 2 additions & 0 deletions weaviate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
connect_to_wcs,
connect_to_weaviate_cloud,
use_async_with_custom,
use_async_with_embedded,
use_async_with_local,
use_async_with_weaviate_cloud,
)
Expand Down Expand Up @@ -75,6 +76,7 @@
"schema",
"types",
"use_async_with_custom",
"use_async_with_embedded",
"use_async_with_local",
"use_async_with_weaviate_cloud",
]
Expand Down
112 changes: 101 additions & 11 deletions weaviate/connect/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from weaviate.client import WeaviateAsyncClient, WeaviateClient
from weaviate.config import AdditionalConfig
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.embedded import EmbeddedOptions
from weaviate.embedded import EmbeddedOptions, WEAVIATE_VERSION
from weaviate.validator import _validate_input, _ValidateArgument


Expand Down Expand Up @@ -226,7 +226,7 @@ def connect_to_embedded(
grpc_port: int = 50050,
headers: Optional[Dict[str, str]] = None,
additional_config: Optional[AdditionalConfig] = None,
version: str = "1.23.7",
version: str = WEAVIATE_VERSION,
persistence_data_path: Optional[str] = None,
binary_path: Optional[str] = None,
environment_variables: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -449,10 +449,10 @@ def use_async_with_weaviate_cloud(
... cluster_url="rAnD0mD1g1t5.something.weaviate.cloud",
... auth_credentials=weaviate.classes.init.Auth.api_key("my-api-key"),
... )
>>> client.is_ready()
>>> await client.is_ready()
False # The connection is not ready yet, you must call `await client.connect()` to connect.
... await client.connect()
>>> client.is_ready()
>>> await client.is_ready()
True
>>> await client.close() # Close the connection when you are done with it.
>>> ################## With Context Manager #############################
Expand All @@ -461,7 +461,7 @@ def use_async_with_weaviate_cloud(
... cluster_url="rAnD0mD1g1t5.something.weaviate.cloud",
... auth_credentials=weaviate.classes.init.Auth.api_key("my-api-key"),
... ) as client:
... client.is_ready()
... await client.is_ready()
True
"""
cluster_url, grpc_host = __parse_weaviate_cloud_cluster_url(cluster_url)
Expand Down Expand Up @@ -523,10 +523,10 @@ def use_async_with_local(
... port=8080,
... grpc_port=50051,
... )
>>> client.is_ready()
>>> await client.is_ready()
False # The connection is not ready yet, you must call `await client.connect()` to connect.
... await client.connect()
>>> client.is_ready()
>>> await client.is_ready()
True
>>> await client.close() # Close the connection when you are done with it.
>>> ################## With Context Manager #############################
Expand All @@ -536,7 +536,7 @@ def use_async_with_local(
... port=8080,
... grpc_port=50051,
... ) as client:
... client.is_ready()
... await client.is_ready()
True
>>> # The connection is automatically closed when the context is exited.
"""
Expand All @@ -552,6 +552,96 @@ def use_async_with_local(
)


def use_async_with_embedded(
    hostname: str = "127.0.0.1",
    port: int = 8079,
    grpc_port: int = 50050,
    headers: Optional[Dict[str, str]] = None,
    additional_config: Optional[AdditionalConfig] = None,
    version: str = WEAVIATE_VERSION,
    persistence_data_path: Optional[str] = None,
    binary_path: Optional[str] = None,
    environment_variables: Optional[Dict[str, str]] = None,
) -> WeaviateAsyncClient:
    """
    Create an async client object ready to connect to an embedded Weaviate instance.

    If this is not sufficient for your customization needs then instantiate a `weaviate.WeaviateAsyncClient` instance directly.

    This method handles creating the `WeaviateAsyncClient` instance with the relevant options for an embedded Weaviate instance
    but you must manually call `await client.connect()`.
    Once you are done with the client you should call `await client.close()` to close the connection and free up resources.
    Alternatively, you can use the client as a context manager in an `async with` statement, which will automatically
    open/close the connection when the context is entered/exited. See the examples below for details.

    See [the docs](https://weaviate.io/developers/weaviate/installation/embedded#embedded-options) for more details.

    Arguments:
        `hostname`
            The hostname to use for the underlying REST & GraphQL API calls.
        `port`
            The port to use for the underlying REST and GraphQL API calls.
        `grpc_port`
            The port to use for the underlying gRPC API.
        `headers`
            Additional headers to include in the requests, e.g. API keys for Cloud vectorization.
        `additional_config`
            This includes many additional, rarely used config options. use wvc.init.AdditionalConfig() to configure.
        `version`
            Weaviate version to be used for the embedded instance.
        `persistence_data_path`
            Directory where the files making up the database are stored.
            When the XDG_DATA_HOME env variable is set, the default value is: `XDG_DATA_HOME/weaviate/`
            Otherwise it is: `~/.local/share/weaviate`
        `binary_path`
            Directory where to download the binary. If deleted, the client will download the binary again.
            When the XDG_CACHE_HOME env variable is set, the default value is: `XDG_CACHE_HOME/weaviate-embedded/`
            Otherwise it is: `~/.cache/weaviate-embedded`
        `environment_variables`
            Additional environment variables to be passed to the embedded instance for configuration.

    Returns
        `weaviate.WeaviateAsyncClient`
            The async client, not yet connected, configured for the embedded instance with the required parameters set appropriately.

    Examples:
        >>> import weaviate
        >>> client = weaviate.use_async_with_embedded(
        ...     port=8080,
        ...     grpc_port=50051,
        ... )
        >>> await client.is_ready()
        False # The connection is not ready yet, you must call `await client.connect()` to connect.
        >>> await client.connect()
        >>> await client.is_ready()
        True
        >>> await client.close() # Close the connection when you are done with it.
        >>> ################## With Context Manager #############################
        >>> import weaviate
        >>> async with weaviate.use_async_with_embedded(
        ...     port=8080,
        ...     grpc_port=50051,
        ... ) as client:
        ...     await client.is_ready()
        True
        >>> # The connection is automatically closed when the context is exited.
    """
    options = EmbeddedOptions(
        hostname=hostname,
        port=port,
        grpc_port=grpc_port,
        version=version,
        additional_env_vars=environment_variables,
    )
    # Only override the path settings when the caller supplied them, so that
    # the defaults defined on `EmbeddedOptions` stay in effect otherwise.
    if persistence_data_path is not None:
        options.persistence_data_path = persistence_data_path
    if binary_path is not None:
        options.binary_path = binary_path
    client = WeaviateAsyncClient(
        embedded_options=options,
        additional_headers=headers,
        additional_config=additional_config,
    )
    return client


def use_async_with_custom(
http_host: str,
http_port: int,
Expand Down Expand Up @@ -612,10 +702,10 @@ def use_async_with_custom(
... grpc_port=50051,
... grpc_secure=False,
... )
>>> client.is_ready()
>>> await client.is_ready()
False # The connection is not ready yet, you must call `await client.connect()` to connect.
... await client.connect()
>>> client.is_ready()
>>> await client.is_ready()
True
>>> await client.close() # Close the connection when you are done with it.
>>> ################## Async With Context Manager #############################
Expand All @@ -628,7 +718,7 @@ def use_async_with_custom(
... grpc_port=50051,
... grpc_secure=False,
... ) as client:
... client.is_ready()
... await client.is_ready()
True
>>> # The connection is automatically closed when the context is exited.
"""
Expand Down
5 changes: 3 additions & 2 deletions weaviate/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
DEFAULT_PORT = 8079
DEFAULT_GRPC_PORT = 50060

WEAVIATE_VERSION = "1.26.1"


@dataclass
class EmbeddedOptions:
persistence_data_path: str = os.environ.get("XDG_DATA_HOME", DEFAULT_PERSISTENCE_DATA_PATH)
binary_path: str = os.environ.get("XDG_CACHE_HOME", DEFAULT_BINARY_PATH)
version: str = "1.25.8"
version: str = WEAVIATE_VERSION
port: int = DEFAULT_PORT
hostname: str = "127.0.0.1"
additional_env_vars: Optional[Dict[str, str]] = None
Expand All @@ -52,7 +54,6 @@ def get_random_port() -> int:

class _EmbeddedBase:
def __init__(self, options: EmbeddedOptions) -> None:
self.data_bind_port = get_random_port()
self.options = options
self.grpc_port: int = options.grpc_port
self.process: Optional[subprocess.Popen[bytes]] = None
Expand Down

0 comments on commit f74ef6b

Please sign in to comment.