From 7ca484da0431699c460358584b9e8be102d9cc46 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Thu, 16 Nov 2023 10:25:47 -0800 Subject: [PATCH] feat: add `upsert_datapoints()` to `MatchingEngineIndex` to support streaming update index. PiperOrigin-RevId: 583089201 --- google/cloud/aiplatform/compat/__init__.py | 2 + .../cloud/aiplatform/compat/types/__init__.py | 2 + .../matching_engine/matching_engine_index.py | 49 +++++++++++--- .../aiplatform/test_matching_engine_index.py | 66 +++++++++++++++++-- 4 files changed, 105 insertions(+), 14 deletions(-) diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py index f4a5cdde26..168728e5a2 100644 --- a/google/cloud/aiplatform/compat/__init__.py +++ b/google/cloud/aiplatform/compat/__init__.py @@ -88,6 +88,7 @@ types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1 types.index = types.index_v1beta1 types.index_endpoint = types.index_endpoint_v1beta1 + types.index_service = types.index_service_v1beta1 types.io = types.io_v1beta1 types.job_service = types.job_service_v1beta1 types.job_state = types.job_state_v1beta1 @@ -189,6 +190,7 @@ types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1 types.index = types.index_v1 types.index_endpoint = types.index_endpoint_v1 + types.index_service = types.index_service_v1 types.io = types.io_v1 types.job_service = types.job_service_v1 types.job_state = types.job_state_v1 diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index f543fb0114..f299de6537 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -125,6 +125,7 @@ hyperparameter_tuning_job as hyperparameter_tuning_job_v1, index as index_v1, index_endpoint as index_endpoint_v1, + index_service as index_service_v1, io as io_v1, job_service as job_service_v1, job_state as job_state_v1, @@ -204,6 +205,7 @@ matching_engine_deployed_index_ref_v1, index_v1, index_endpoint_v1, + index_service_v1, metadata_service_v1, metadata_schema_v1, metadata_store_v1, diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py index 9e30b7f1b6..d7a1c742ca 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py @@ -21,7 +21,7 @@ from google.protobuf import field_mask_pb2 from google.cloud.aiplatform import base from google.cloud.aiplatform.compat.types import ( - index_service_v1beta1 as gca_index_service_v1beta1, + index_service as gca_index_service, matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref, matching_engine_index as gca_matching_engine_index, encryption_spec as gca_encryption_spec, @@ -665,6 +665,42 @@ def create_brute_force_index( encryption_spec_key_name=encryption_spec_key_name, ) + def upsert_datapoints( + self, + datapoints: Sequence[gca_matching_engine_index.IndexDatapoint], + ) -> "MatchingEngineIndex": + """Upsert datapoints to this index. + + Args: + datapoints (Sequence[gca_matching_engine_index.IndexDatapoint]): + Required. Datapoints to be upserted to this index. + + Returns: + MatchingEngineIndex - Index resource object + + """ + + self.wait() + + _LOGGER.log_action_start_against_resource( + "Upserting datapoints", + "index", + self, + ) + + self.api_client.upsert_datapoints( + gca_index_service.UpsertDatapointsRequest( + index=self.resource_name, + datapoints=datapoints, + ) + ) + + _LOGGER.log_action_completed_against_resource( + "index", "Upserted datapoints", self + ) + + return self + def remove_datapoints( self, datapoint_ids: Sequence[str], @@ -678,6 +714,7 @@ def remove_datapoints( Returns: MatchingEngineIndex - Index resource object """ + self.wait() _LOGGER.log_action_start_against_resource( @@ -686,19 +723,13 @@ def remove_datapoints( self, ) - remove_lro = self.api_client.remove_datapoints( - gca_index_service_v1beta1.RemoveDatapointsRequest( + self.api_client.remove_datapoints( + gca_index_service.RemoveDatapointsRequest( index=self.resource_name, datapoint_ids=datapoint_ids, ) ) - _LOGGER.log_action_started_against_resource_with_lro( - "Remove datapoints", "index", self.__class__, remove_lro - ) - - self._gca_resource = remove_lro.result(timeout=None) - _LOGGER.log_action_completed_against_resource( "index", "Removed datapoints", self ) diff --git a/tests/unit/aiplatform/test_matching_engine_index.py b/tests/unit/aiplatform/test_matching_engine_index.py index 33072ad6fa..e2b6e51d71 100644 --- a/tests/unit/aiplatform/test_matching_engine_index.py +++ b/tests/unit/aiplatform/test_matching_engine_index.py @@ -35,7 +35,7 @@ from google.cloud.aiplatform.compat.types import ( index as gca_index, encryption_spec as gca_encryption_spec, - index_service_v1beta1 as gca_index_service_v1beta1, + index_service as gca_index_service, ) import constants as test_constants @@ -111,8 +111,42 @@ # Encryption spec _TEST_ENCRYPTION_SPEC_KEY_NAME = "TEST_ENCRYPTION_SPEC" -# Streaming update _TEST_DATAPOINT_IDS = ("1", "2") +_TEST_DATAPOINT_1 = gca_index.IndexDatapoint( + datapoint_id="0", + feature_vector=[0.00526886899, -0.0198396724], + restricts=[ + gca_index.IndexDatapoint.Restriction(namespace="Color", allow_list=["red"]) + ], + numeric_restricts=[ + gca_index.IndexDatapoint.NumericRestriction( + namespace="cost", + value_int=1, + ) + ], +) +_TEST_DATAPOINT_2 = gca_index.IndexDatapoint( + datapoint_id="1", + feature_vector=[0.00526886899, -0.0198396724], + numeric_restricts=[ + gca_index.IndexDatapoint.NumericRestriction( + namespace="cost", + value_double=0.1, + ) + ], + crowding_tag=gca_index.IndexDatapoint.CrowdingTag(crowding_attribute="crowding"), +) +_TEST_DATAPOINT_3 = gca_index.IndexDatapoint( + datapoint_id="2", + feature_vector=[0.00526886899, -0.0198396724], + numeric_restricts=[ + gca_index.IndexDatapoint.NumericRestriction( + namespace="cost", + value_float=1.1, + ) + ], +) +_TEST_DATAPOINTS = (_TEST_DATAPOINT_1, _TEST_DATAPOINT_2, _TEST_DATAPOINT_3) def uuid_mock(): @@ -196,13 +230,19 @@ def create_index_mock(): yield create_index_mock +@pytest.fixture +def upsert_datapoints_mock(): + with patch.object( + index_service_client.IndexServiceClient, "upsert_datapoints" + ) as upsert_datapoints_mock: + yield upsert_datapoints_mock + + @pytest.fixture def remove_datapoints_mock(): with patch.object( index_service_client.IndexServiceClient, "remove_datapoints" ) as remove_datapoints_mock: - remove_datapoints_lro_mock = mock.Mock(operation.Operation) - remove_datapoints_mock.return_value = remove_datapoints_lro_mock yield remove_datapoints_mock @@ -509,6 +549,22 @@ def test_create_brute_force_index_backward_compatibility(self, create_index_mock metadata=_TEST_REQUEST_METADATA, ) + @pytest.mark.usefixtures("get_index_mock") + def test_upsert_datapoints(self, upsert_datapoints_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID) + my_index.upsert_datapoints( + datapoints=_TEST_DATAPOINTS, + ) + + upsert_datapoints_request = gca_index_service.UpsertDatapointsRequest( + index=_TEST_INDEX_NAME, + datapoints=_TEST_DATAPOINTS, + ) + + upsert_datapoints_mock.assert_called_once_with(upsert_datapoints_request) + @pytest.mark.usefixtures("get_index_mock") def test_remove_datapoints(self, remove_datapoints_mock): aiplatform.init(project=_TEST_PROJECT) @@ -518,7 +574,7 @@ def test_remove_datapoints(self, remove_datapoints_mock): datapoint_ids=_TEST_DATAPOINT_IDS, ) - remove_datapoints_request = gca_index_service_v1beta1.RemoveDatapointsRequest( + remove_datapoints_request = gca_index_service.RemoveDatapointsRequest( index=_TEST_INDEX_NAME, datapoint_ids=_TEST_DATAPOINT_IDS, )