diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 3f4d836678a..db7d0a7c18a 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -38,6 +38,10 @@ Feature, Featurestore, ) +from google.cloud.aiplatform.matching_engine import ( + MatchingEngineIndex, + MatchingEngineIndexEndpoint, +) from google.cloud.aiplatform.metadata import metadata from google.cloud.aiplatform.models import Endpoint from google.cloud.aiplatform.models import Model @@ -105,6 +109,8 @@ "EntityType", "Feature", "Featurestore", + "MatchingEngineIndex", + "MatchingEngineIndexEndpoint", "ImageDataset", "HyperparameterTuningJob", "Model", diff --git a/google/cloud/aiplatform/_matching_engine/__init__.py b/google/cloud/aiplatform/matching_engine/__init__.py similarity index 80% rename from google/cloud/aiplatform/_matching_engine/__init__.py rename to google/cloud/aiplatform/matching_engine/__init__.py index 362ee40fc22..4616d01bbd0 100644 --- a/google/cloud/aiplatform/_matching_engine/__init__.py +++ b/google/cloud/aiplatform/matching_engine/__init__.py @@ -15,15 +15,15 @@ # limitations under the License. # -from google.cloud.aiplatform._matching_engine.matching_engine_index import ( +from google.cloud.aiplatform.matching_engine.matching_engine_index import ( MatchingEngineIndex, ) -from google.cloud.aiplatform._matching_engine.matching_engine_index_config import ( +from google.cloud.aiplatform.matching_engine.matching_engine_index_config import ( BruteForceConfig as MatchingEngineBruteForceAlgorithmConfig, MatchingEngineIndexConfig as MatchingEngineIndexConfig, TreeAhConfig as MatchingEngineTreeAhAlgorithmConfig, ) -from google.cloud.aiplatform._matching_engine.matching_engine_index_endpoint import ( +from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( MatchingEngineIndexEndpoint, ) diff --git a/google/cloud/aiplatform/_matching_engine/match_service_pb2.py b/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2.py similarity index 100% rename from google/cloud/aiplatform/_matching_engine/match_service_pb2.py rename to google/cloud/aiplatform/matching_engine/_protos/match_service_pb2.py diff --git a/google/cloud/aiplatform/_matching_engine/match_service_pb2_grpc.py b/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2_grpc.py similarity index 98% rename from google/cloud/aiplatform/_matching_engine/match_service_pb2_grpc.py rename to google/cloud/aiplatform/matching_engine/_protos/match_service_pb2_grpc.py index 61aa9fcc382..9c99081a162 100644 --- a/google/cloud/aiplatform/_matching_engine/match_service_pb2_grpc.py +++ b/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2_grpc.py @@ -16,7 +16,7 @@ # # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" -from google.cloud.aiplatform._matching_engine import match_service_pb2 +from google.cloud.aiplatform.matching_engine._protos import match_service_pb2 import grpc diff --git a/google/cloud/aiplatform/_matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py similarity index 99% rename from google/cloud/aiplatform/_matching_engine/matching_engine_index.py rename to google/cloud/aiplatform/matching_engine/matching_engine_index.py index 7ad9f85ad2f..82f37bae9eb 100644 --- a/google/cloud/aiplatform/_matching_engine/matching_engine_index.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py @@ -25,7 +25,7 @@ matching_engine_index as gca_matching_engine_index, ) from google.cloud.aiplatform import initializer -from google.cloud.aiplatform._matching_engine import matching_engine_index_config +from google.cloud.aiplatform.matching_engine import matching_engine_index_config from google.cloud.aiplatform import utils _LOGGER = base.Logger(__name__) diff --git a/google/cloud/aiplatform/_matching_engine/matching_engine_index_config.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_config.py similarity index 100% rename from google/cloud/aiplatform/_matching_engine/matching_engine_index_config.py rename to google/cloud/aiplatform/matching_engine/matching_engine_index_config.py diff --git a/google/cloud/aiplatform/_matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py similarity index 99% rename from google/cloud/aiplatform/_matching_engine/matching_engine_index_endpoint.py rename to google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index 0b50aa672b9..da155496aed 100644 --- a/google/cloud/aiplatform/_matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -21,14 +21,14 @@ from google.auth import credentials as auth_credentials from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer -from google.cloud.aiplatform import _matching_engine +from google.cloud.aiplatform import matching_engine from google.cloud.aiplatform import utils from google.cloud.aiplatform.compat.types import ( machine_resources as gca_machine_resources_compat, matching_engine_index_endpoint as gca_matching_engine_index_endpoint, ) -from google.cloud.aiplatform._matching_engine import match_service_pb2 -from google.cloud.aiplatform._matching_engine import match_service_pb2_grpc +from google.cloud.aiplatform.matching_engine._protos import match_service_pb2 +from google.cloud.aiplatform.matching_engine._protos import match_service_pb2_grpc from google.protobuf import field_mask_pb2 import grpc @@ -432,7 +432,7 @@ def _build_deployed_index( def deploy_index( self, - index: _matching_engine.MatchingEngineIndex, + index: matching_engine.MatchingEngineIndex, deployed_index_id: str, display_name: Optional[str] = None, machine_type: Optional[str] = None, diff --git a/tests/system/aiplatform/e2e_base.py b/tests/system/aiplatform/e2e_base.py index 9b12a47c60a..f204aaf3b08 100644 --- a/tests/system/aiplatform/e2e_base.py +++ b/tests/system/aiplatform/e2e_base.py @@ -20,6 +20,7 @@ import os import pytest import uuid + from typing import Any, Dict, Generator from google.api_core import exceptions @@ -29,8 +30,7 @@ from google.cloud.aiplatform import initializer _PROJECT = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT") -_PROJECT_NUMBER = os.getenv("PROJECT_NUMBER") -_VPC_NETWORK_NAME = os.getenv("private-net") +_VPC_NETWORK_URI = os.getenv("_VPC_NETWORK_URI") _LOCATION = "us-central1" @@ -136,7 +136,10 @@ def tear_down_resources(self, shared_state: Dict[str, Any]): # Bring all Endpoints to the front of the list # Ensures Models are undeployed first before we attempt deletion shared_state["resources"].sort( - key=lambda r: 1 if isinstance(r, aiplatform.Endpoint) else 2 + key=lambda r: 1 + if isinstance(r, aiplatform.Endpoint) + or isinstance(r, aiplatform.MatchingEngineIndexEndpoint) + else 2 ) for resource in shared_state["resources"]: @@ -146,6 +149,7 @@ def tear_down_resources(self, shared_state: Dict[str, Any]): ( aiplatform.Endpoint, aiplatform.Featurestore, + aiplatform.MatchingEngineIndexEndpoint, ), ): # For endpoint, undeploy model then delete endpoint diff --git a/tests/system/aiplatform/test_matching_engine_index.py b/tests/system/aiplatform/test_matching_engine_index.py index cc93380d7d2..db4f4ac6bf2 100644 --- a/tests/system/aiplatform/test_matching_engine_index.py +++ b/tests/system/aiplatform/test_matching_engine_index.py @@ -16,7 +16,6 @@ # import uuid -import pytest from google.cloud import aiplatform @@ -52,10 +51,6 @@ _TEST_INDEX_ENDPOINT_DISPLAY_NAME = "endpoint_name" _TEST_INDEX_ENDPOINT_DESCRIPTION = "my endpoint" -_TEST_INDEX_ENDPOINT_VPC_NETWORK = "projects/{}/global/networks/{}".format( - e2e_base._PROJECT_NUMBER, e2e_base._VPC_NETWORK_NAME -) - # DEPLOYED INDEX _TEST_DEPLOYED_INDEX_ID = f"deployed_index_id_{uuid.uuid4()}" _TEST_DEPLOYED_INDEX_DISPLAY_NAME = f"deployed_index_display_name_{uuid.uuid4()}" @@ -167,7 +162,6 @@ ] -@pytest.mark.skip(reason="TestMatchingEngine not available") class TestMatchingEngine(e2e_base.TestEndToEnd): _temp_prefix = "temp_vertex_sdk_e2e_matching_engine_test" @@ -226,9 +220,84 @@ def test_create_get_list_matching_engine_index(self, shared_state): assert updated_index.name == get_index.name + # Create endpoint and check that it is listed + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create( + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + network=e2e_base._VPC_NETWORK_URI, + labels=_TEST_LABELS, + ) + assert my_index_endpoint.resource_name in [ + index_endpoint.resource_name + for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list() + ] + + assert my_index_endpoint.labels == _TEST_LABELS + assert my_index_endpoint.display_name == _TEST_INDEX_ENDPOINT_DISPLAY_NAME + assert my_index_endpoint.description == _TEST_INDEX_ENDPOINT_DESCRIPTION + + shared_state["resources"].append(my_index_endpoint) + + # Deploy endpoint + my_index_endpoint = my_index_endpoint.deploy_index( + index=index, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + ) + + # Update endpoint + updated_index_endpoint = my_index_endpoint.update( + display_name=_TEST_DISPLAY_NAME_UPDATE, + description=_TEST_DESCRIPTION_UPDATE, + labels=_TEST_LABELS_UPDATE, + ) + + assert updated_index_endpoint.labels == _TEST_LABELS_UPDATE + assert updated_index_endpoint.display_name == _TEST_DISPLAY_NAME_UPDATE + assert updated_index_endpoint.description == _TEST_DESCRIPTION_UPDATE + + # Mutate deployed index + my_index_endpoint.mutate_deployed_index( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + min_replica_count=_TEST_MIN_REPLICA_COUNT_UPDATED, + max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED, + ) + + deployed_index = my_index_endpoint.deployed_indexes[0] + + assert deployed_index.id == _TEST_DEPLOYED_INDEX_ID + assert deployed_index.index == index.resource_name + assert ( + deployed_index.automatic_resources.min_replica_count + == _TEST_MIN_REPLICA_COUNT_UPDATED + ) + assert ( + deployed_index.automatic_resources.max_replica_count + == _TEST_MAX_REPLICA_COUNT_UPDATED + ) + + # TODO: Test `my_index_endpoint.match` request. This requires running this test in a VPC. + # results = my_index_endpoint.match( + # deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=[_TEST_MATCH_QUERY] + # ) + + # assert results[0][0].id == 870 + + # Undeploy index + my_index_endpoint = my_index_endpoint.undeploy_index( + deployed_index_id=deployed_index.id + ) + # Delete index and check that it is no longer listed index.delete() list_indexes = aiplatform.MatchingEngineIndex.list() assert get_index.resource_name not in [ index.resource_name for index in list_indexes ] + + # Delete index endpoint and check that it is no longer listed + my_index_endpoint.delete() + assert my_index_endpoint.resource_name not in [ + index_endpoint.resource_name + for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list() + ] diff --git a/tests/unit/aiplatform/test_matching_engine_index.py b/tests/unit/aiplatform/test_matching_engine_index.py index 160d085c4cc..bf4c3d12322 100644 --- a/tests/unit/aiplatform/test_matching_engine_index.py +++ b/tests/unit/aiplatform/test_matching_engine_index.py @@ -167,7 +167,6 @@ def create_index_mock(): yield create_index_mock -@pytest.mark.skip(reason="MatchingEngineIndex not available") class TestMatchingEngineIndex: def setup_method(self): reload(initializer) diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 9d639881216..cd3dd78ed24 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -383,7 +383,6 @@ def create_index_endpoint_mock(): yield create_index_endpoint_mock -@pytest.mark.skip(reason="MatchingEngineIndexEndpoint not available") class TestMatchingEngineIndexEndpoint: def setup_method(self): reload(initializer)