diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index eb2074bcca..3ccfc06fb6 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -956,6 +956,10 @@ def find_neighbors( queries: List[List[float]], num_neighbors: int = 10, filter: Optional[List[Namespace]] = [], + per_crowding_attribute_neighbor_count: Optional[int] = None, + approx_num_neighbors: Optional[int] = None, + fraction_leaf_nodes_to_search_override: Optional[float] = None, + return_full_datapoint: bool = False, ) -> List[List[MatchNeighbor]]: """Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint. @@ -979,25 +983,58 @@ def find_neighbors( For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints that satisfy "red color" but not include datapoints with "squared shape". Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail. + + per_crowding_attribute_neighbor_count (int): + Optional. Crowding is a constraint on a neighbor list produced + by nearest neighbor search requiring that no more than some + value k' of the k neighbors returned have the same value of + crowding_attribute. It's used for improving result diversity. + This field is the maximum number of matches with the same crowding tag. + + approx_num_neighbors (int): + Optional. The number of neighbors to find via approximate search + before exact reordering is performed. If not set, the default + value from scam config is used; if set, this value must be > 0. + + fraction_leaf_nodes_to_search_override (float): + Optional. The fraction of the number of leaves to search, set at + query time allows user to tune search performance. This value + increase result in both search accuracy and latency increase. + The value should be between 0.0 and 1.0. + + return_full_datapoint (bool): + Optional. If set to true, the full datapoints (including all + vector values and of the nearest neighbors are returned. + Note that returning full datapoint will significantly increase the + latency and cost of the query. + Returns: List[List[MatchNeighbor]] - A list of nearest neighbors for each query. """ if not self._public_match_client: raise ValueError( - "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method." + "Please make sure index has been deployed to public endpoint,and follow the example usage to call this method." ) # Create the FindNeighbors request find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest() find_neighbors_request.index_endpoint = self.resource_name find_neighbors_request.deployed_index_id = deployed_index_id + find_neighbors_request.return_full_datapoint = return_full_datapoint for query in queries: find_neighbors_query = ( gca_match_service_v1beta1.FindNeighborsRequest.Query() ) find_neighbors_query.neighbor_count = num_neighbors + find_neighbors_query.per_crowding_attribute_neighbor_count = ( + per_crowding_attribute_neighbor_count + ) + find_neighbors_query.approximate_neighbor_count = approx_num_neighbors + find_neighbors_query.fraction_leaf_nodes_to_search_override = ( + fraction_leaf_nodes_to_search_override + ) datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query) for namespace in filter: restrict = gca_index_v1beta1.IndexDatapoint.Restriction() diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 37d0a7af79..48e7c3c506 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -234,6 +234,8 @@ _TEST_IDS = ["123", "456", "789"] _TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3 _TEST_APPROX_NUM_NEIGHBORS = 2 +_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE = 0.8 +_TEST_RETURN_FULL_DATAPOINT = True def uuid_mock(): @@ -954,6 +956,10 @@ def test_index_public_endpoint_match_queries( queries=_TEST_QUERIES, num_neighbors=_TEST_NUM_NEIGHBOURS, filter=_TEST_FILTER, + per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, ) find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest( @@ -972,8 +978,12 @@ def test_index_public_endpoint_match_queries( ) ], ), + per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, ) ], + return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, ) index_public_endpoint_match_queries_mock.assert_called_with(