class MilvusOnlineStoreCreator(OnlineStoreCreator):
    """Provision a disposable Milvus container for universal online-store tests.

    The container object is built in ``__init__`` but only started by
    ``create_online_store``; ``teardown`` stops it again.
    """

    # Docker image and the container port exposed to the host.
    _IMAGE = "milvusdb/milvus:v2.2.9"
    _PORT = 19530
    # Readiness polling budget: number of attempts and pause between them.
    _MAX_ATTEMPTS = 12
    _RETRY_DELAY_SECONDS = 5

    def __init__(self, project_name: str, **kwargs):
        """Prepare (but do not start) the Milvus container.

        Args:
            project_name: Forwarded to ``OnlineStoreCreator.__init__``.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        super().__init__(project_name)
        self.container = DockerContainer(self._IMAGE).with_exposed_ports(self._PORT)

    def create_online_store(self) -> Dict[str, Any]:
        """Start the container, wait for Milvus, and return the store config.

        Returns:
            The online-store configuration dict pointing at the container.

        Raises:
            RuntimeError: If no connection could be established within the
                polling budget. The container is stopped before raising so a
                failed start-up does not leave it running.
        """
        self.container.start()
        host = "localhost"
        try:
            port = self.container.get_exposed_port(self._PORT)
            if not self._wait_until_ready(host, port):
                raise RuntimeError(
                    "Cannot connect to Milvus server after multiple attempts"
                )
        except Exception:
            # Do not leak a running container when start-up fails.
            self.container.stop()
            raise

        return {
            "type": "milvus",
            "host": host,
            "port": int(port),
            "index_type": "IVF_FLAT",
            "metric_type": "L2",
            "embedding_dim": 128,  # Adjust based on your embedding dimension
            "vector_enabled": True,
        }

    def _wait_until_ready(self, host: str, port: Any) -> bool:
        """Poll Milvus until the ``default`` connection alias is established.

        Returns:
            True as soon as a connection exists, False once every attempt
            has failed.
        """
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                print(
                    f"Attempting to connect to Milvus at {host}:{port}, attempt {attempt}"
                )
                connections.connect(alias="default", host=host, port=port)
                if connections.has_connection(alias="default"):
                    print("Successfully connected to Milvus")
                    return True
            except Exception as e:
                # Any failure here just means "not ready yet"; report and retry.
                print(f"Connection attempt failed: {e}")
            # Sleeping after the last attempt would only delay the failure.
            if attempt < self._MAX_ATTEMPTS:
                time.sleep(self._RETRY_DELAY_SECONDS)
        return False

    def teardown(self):
        """Stop the Milvus container."""
        self.container.stop()
def test_milvus_get_online_documents() -> None:
    """
    Test retrieving documents from the online store in local mode using Milvus.

    The online store is configured for host "localhost", port 19530.
    NOTE(review): presumably this needs a Milvus server reachable at that
    address - confirm how it is provisioned for this test.
    """
    n = 10  # number of samples - note: we'll actually double it
    vector_length = 8
    runner = CliRunner()
    with runner.local_repo(
        get_example_repo("example_feature_repo_1.py"), "file"
    ) as store:
        # Configure the online store to use Milvus
        new_config = RepoConfig(
            project=store.config.project,
            registry=store.config.registry,
            provider=store.config.provider,
            online_store=MilvusOnlineStoreConfig(
                type="milvus",
                host="localhost",
                port=19530,
                index_type="IVF_FLAT",
                metric_type="L2",
                embedding_dim=vector_length,
                vector_enabled=True,
            ),
            entity_key_serialization_version=store.config.entity_key_serialization_version,
        )
        store = FeatureStore(config=new_config, repo_path=store.repo_path)
        # Apply the new configuration
        store.apply([])

        # Write some data to the feature view
        document_embeddings_fv = store.get_feature_view(name="document_embeddings")

        provider: Provider = store._get_provider()

        # ``ValueProto`` is already the ``Value`` message class (imported as
        # ``Value as ValueProto``), so it is instantiated directly.
        item_keys = [
            EntityKeyProto(
                join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)]
            )
            for i in range(n)
        ]
        data = []
        for item_key in item_keys:
            embedding_vector = np.random.random(vector_length).tolist()
            data.append(
                (
                    item_key,
                    {
                        "Embeddings": ValueProto(
                            float_list_val=FloatListProto(val=embedding_vector)
                        )
                    },
                    _utc_now(),
                    _utc_now(),
                )
            )

        # First batch: written through the provider's low-level API.
        provider.online_write_batch(
            config=store.config,
            table=document_embeddings_fv,
            data=data,
            progress=None,
        )

        # Second batch: written through the FeatureStore DataFrame API.
        documents_df = pd.DataFrame(
            {
                "item_id": list(range(n)),
                "Embeddings": [
                    np.random.random(vector_length).tolist() for _ in range(n)
                ],
                "event_timestamp": [_utc_now() for _ in range(n)],
            }
        )

        store.write_to_online_store(
            feature_view_name="document_embeddings",
            df=documents_df,
        )

        # For Milvus, get the collection and check the number of entities.
        # Both batches use item_id 0..n-1; the assertion expects every written
        # row to be counted, i.e. 2 * n entities in total.
        collection = provider._online_store._get_collection(
            store.config, document_embeddings_fv
        )
        record_count = collection.num_entities
        assert record_count == len(data) + documents_df.shape[0]

        query_embedding = np.random.random(vector_length).tolist()

        # Retrieve online documents using Milvus
        result = store.retrieve_online_documents(
            feature="document_embeddings:Embeddings", query=query_embedding, top_k=3
        ).to_dict()

        assert "Embeddings" in result
        assert "distance" in result
        assert len(result["distance"]) == 3