Skip to content

Commit

Permalink
feat: Support filters in matching engine vector matching (#1608)
Browse files Browse the repository at this point in the history
* feat: support filter in index_endpoint.match()

* fix type error

* Add unit test for index_endpoint.match()

* update docstring example

* Update docstring

Co-authored-by: nayaknishant <[email protected]>
  • Loading branch information
jaycee-li and nayaknishant authored Aug 26, 2022
1 parent 66b5471 commit d591d3e
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple

from google.auth import credentials as auth_credentials
Expand Down Expand Up @@ -51,6 +51,25 @@ class MatchNeighbor:
distance: float


@dataclass
class Namespace:
    """A filtering rule that decides which datapoints are eligible for a matching query.

    When several Namespaces are supplied, the overall query is an AND across
    namespaces.

    Args:
        name (str):
            Required. The name of this Namespace.
        allow_tokens (List[str]):
            Optional. The allowed tokens in the namespace.
        deny_tokens (List[str]):
            Optional. The denied tokens in the namespace. When a token is
            denied, then matches will be excluded whenever the other
            datapoint has that token.

    For example, if a query specifies
    [Namespace("color", ["red", "blue"], ["purple"])], then that query will
    match datapoints that are red or blue, but if those points are also
    purple, then they will be excluded even if they are red/blue.
    """

    name: str
    # default_factory gives every instance its own list, so tokens appended to
    # one Namespace never leak into another.
    allow_tokens: List[str] = field(default_factory=list)
    deny_tokens: List[str] = field(default_factory=list)


class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
"""Matching Engine index endpoint resource for Vertex AI."""

Expand Down Expand Up @@ -796,7 +815,11 @@ def description(self) -> str:
return self._gca_resource.description

def match(
self, deployed_index_id: str, queries: List[List[float]], num_neighbors: int = 1
self,
deployed_index_id: str,
queries: List[List[float]],
num_neighbors: int = 1,
filter: Optional[List[Namespace]] = [],
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index.
Expand All @@ -808,6 +831,11 @@ def match(
num_neighbors (int):
Required. The number of nearest neighbors to be retrieved from database for
each query.
filter (List[Namespace]):
Optional. A list of Namespaces for filtering the matching results.
For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
that satisfy "red color" but not include datapoints with "squared shape".
Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
Expand Down Expand Up @@ -836,16 +864,22 @@ def match(
match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex()
)
batch_request_for_index.deployed_index_id = deployed_index_id
batch_request_for_index.requests.extend(
[
match_service_pb2.MatchRequest(
num_neighbors=num_neighbors,
deployed_index_id=deployed_index_id,
float_val=query,
)
for query in queries
]
)
# Translate the user-facing Namespace filters into proto restricts once; they
# are identical for every query. `filter` is annotated Optional, so None is
# treated the same as "no filtering" instead of raising on iteration.
restricts = [
    match_service_pb2.Namespace(
        name=namespace.name,
        allow_tokens=namespace.allow_tokens,
        deny_tokens=namespace.deny_tokens,
    )
    for namespace in filter or []
]

# One MatchRequest per query embedding, each carrying the same restricts.
requests = [
    match_service_pb2.MatchRequest(
        num_neighbors=num_neighbors,
        deployed_index_id=deployed_index_id,
        float_val=query,
        restricts=restricts,
    )
    for query in queries
]

batch_request_for_index.requests.extend(requests)
batch_request.requests.append(batch_request_for_index)

# Perform the request
Expand Down
15 changes: 15 additions & 0 deletions tests/system/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import uuid

from google.cloud import aiplatform
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
Namespace,
)

from tests.system.aiplatform import e2e_base

Expand Down Expand Up @@ -161,6 +164,8 @@
-0.021106,
]

# Filter for the filtered `match` check, which is currently commented out
# (see the TODO in the test).
_TEST_FILTER = [Namespace("name", ["allow_token"], ["deny_token"])]


class TestMatchingEngine(e2e_base.TestEndToEnd):

Expand Down Expand Up @@ -283,6 +288,16 @@ def test_create_get_list_matching_engine_index(self, shared_state):

# assert results[0][0].id == 870

# TODO: Test `my_index_endpoint.match` with filter.
# This requires uploading new content for the Matching Engine Index to Cloud Storage.
# results = my_index_endpoint.match(
# deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
# queries=[_TEST_MATCH_QUERY],
# num_neighbors=1,
# filter=_TEST_FILTER,
# )
# assert results[0][0].id == 9999

# Undeploy index
my_index_endpoint = my_index_endpoint.undeploy_index(
deployed_index_id=deployed_index.id
Expand Down
75 changes: 75 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
Namespace,
)
from google.cloud.aiplatform.compat.types import (
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
index_endpoint as gca_index_endpoint,
Expand All @@ -37,6 +41,8 @@

from google.protobuf import field_mask_pb2

import grpc

import pytest

# project
Expand Down Expand Up @@ -210,6 +216,9 @@
]
]
# Number of neighbors requested per query in the match test.
_TEST_NUM_NEIGHBOURS = 1
# Single namespace filter passed to `match`; the test expects it to appear as
# a `restricts` entry on the outgoing MatchRequest.
_TEST_FILTER = [
    Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
]


def uuid_mock():
Expand Down Expand Up @@ -380,6 +389,33 @@ def create_index_endpoint_mock():
yield create_index_endpoint_mock


@pytest.fixture
def index_endpoint_match_queries_mock():
    """Patch the gRPC unary-unary call so a match request gets a canned response.

    Yields the mock standing in for ``_UnaryUnaryMultiCallable.__call__``. It
    returns a BatchMatchResponse for deployed index "1" holding one neighbor
    (id "1", distance 0.1).
    """
    neighbor = match_service_pb2.MatchResponse.Neighbor(id="1", distance=0.1)
    per_index_response = (
        match_service_pb2.BatchMatchResponse.BatchMatchResponsePerIndex(
            deployed_index_id="1",
            responses=[match_service_pb2.MatchResponse(neighbor=[neighbor])],
        )
    )
    canned_response = match_service_pb2.BatchMatchResponse(
        responses=[per_index_response]
    )

    with patch.object(
        grpc._channel._UnaryUnaryMultiCallable, "__call__"
    ) as match_call_mock:
        match_call_mock.return_value = canned_response
        yield match_call_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestMatchingEngineIndexEndpoint:
def setup_method(self):
Expand Down Expand Up @@ -617,3 +653,42 @@ def test_delete_index_endpoint_with_force(
delete_index_endpoint_mock.assert_called_once_with(
name=_TEST_INDEX_ENDPOINT_NAME
)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
    """match() should issue one BatchMatchRequest carrying the filter as restricts."""
    aiplatform.init(project=_TEST_PROJECT)
    endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
    )

    endpoint.match(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        queries=_TEST_QUERIES,
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        filter=_TEST_FILTER,
    )

    # Expected wire request, assembled from the innermost message outwards.
    expected_restrict = match_service_pb2.Namespace(
        name="class",
        allow_tokens=["token_1"],
        deny_tokens=["token_2"],
    )
    expected_match_request = match_service_pb2.MatchRequest(
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        float_val=_TEST_QUERIES[0],
        restricts=[expected_restrict],
    )
    expected_per_index = (
        match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
            requests=[expected_match_request],
        )
    )
    expected_batch_request = match_service_pb2.BatchMatchRequest(
        requests=[expected_per_index]
    )

    index_endpoint_match_queries_mock.assert_called_with(expected_batch_request)

0 comments on commit d591d3e

Please sign in to comment.