Skip to content

Commit

Permalink
feat: Add support for self-signed JWT for queries on private endpoints
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689941402
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Oct 25, 2024
1 parent 91c2120 commit 5025d03
Show file tree
Hide file tree
Showing 10 changed files with 504 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ def find_neighbors(
return_full_datapoint: bool = False,
numeric_filter: Optional[List[NumericNamespace]] = None,
embedding_ids: Optional[List[str]] = None,
signed_jwt: Optional[str] = None,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the
specified deployed index which is deployed to either public or private
Expand Down Expand Up @@ -1456,6 +1457,9 @@ def find_neighbors(
`embedding_ids` to lookup embedding values from dataset, if embedding
with `embedding_ids` exists in the dataset, do nearest neighbor search.
signed_jwt (str):
Optional. A signed JWT for accessing the private endpoint.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""
Expand All @@ -1471,6 +1475,7 @@ def find_neighbors(
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
numeric_filter=numeric_filter,
signed_jwt=signed_jwt,
)

# Create the FindNeighbors request
Expand Down Expand Up @@ -1570,6 +1575,7 @@ def read_index_datapoints(
*,
deployed_index_id: str,
ids: List[str] = [],
signed_jwt: Optional[str] = None,
) -> List[gca_index_v1beta1.IndexDatapoint]:
"""Reads the datapoints/vectors of the given IDs on the specified
deployed index which is deployed to public or private endpoint.
Expand All @@ -1587,6 +1593,8 @@ def read_index_datapoints(
Required. The ID of the DeployedIndex to match the queries against.
ids (List[str]):
Required. IDs of the datapoints to be searched for.
signed_jwt (str):
Optional. A signed JWT for accessing the private endpoint.
Returns:
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
"""
Expand All @@ -1595,6 +1603,7 @@ def read_index_datapoints(
embeddings = self._batch_get_embeddings(
deployed_index_id=deployed_index_id,
ids=ids,
signed_jwt=signed_jwt,
)

response = []
Expand Down Expand Up @@ -1641,6 +1650,7 @@ def _batch_get_embeddings(
*,
deployed_index_id: str,
ids: List[str] = [],
signed_jwt: Optional[str] = None,
) -> List[match_service_pb2.Embedding]:
"""
Reads the datapoints/vectors of the given IDs on the specified index
Expand All @@ -1651,6 +1661,8 @@ def _batch_get_embeddings(
Required. The ID of the DeployedIndex to match the queries against.
ids (List[str]):
Required. IDs of the datapoints to be searched for.
signed_jwt (str):
Optional. A signed JWT for accessing the private endpoint.
Returns:
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
"""
Expand All @@ -1665,7 +1677,10 @@ def _batch_get_embeddings(

for id in ids:
batch_request.id.append(id)
response = stub.BatchGetEmbeddings(batch_request)
metadata = None
if signed_jwt:
metadata = (("authorization", f"Bearer: {signed_jwt}"),)
response = stub.BatchGetEmbeddings(batch_request, metadata=metadata)

return response.embeddings

Expand All @@ -1680,6 +1695,7 @@ def match(
fraction_leaf_nodes_to_search_override: Optional[float] = None,
low_level_batch_size: int = 0,
numeric_filter: Optional[List[NumericNamespace]] = None,
signed_jwt: Optional[str] = None,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the
specified deployed index for private endpoint only.
Expand Down Expand Up @@ -1729,6 +1745,8 @@ def match(
results. For example:
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
will match datapoints that its cost is greater than 5.
signed_jwt (str):
Optional. A signed JWT for accessing the private endpoint.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
Expand Down Expand Up @@ -1809,7 +1827,10 @@ def match(
batch_request.requests.append(batch_request_for_index)

# Perform the request
response = stub.BatchMatch(batch_request)
metadata = None
if signed_jwt:
metadata = (("authorization", f"Bearer: {signed_jwt}"),)
response = stub.BatchMatch(batch_request, metadata=metadata)

# Wrap the results in MatchNeighbor objects and return
match_neighbors_response = []
Expand Down
14 changes: 14 additions & 0 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,20 @@ def mock_index_endpoint_find_neighbors(mock_index_endpoint):
yield mock_find_neighbors


@pytest.fixture
def mock_index_endpoint_match(mock_index_endpoint):
    """Patch ``match`` on the mocked index endpoint and yield the patch.

    The replacement returns ``None`` when called, so tests can only inspect
    how ``match`` was invoked, not any result from it.
    """
    with patch.object(
        mock_index_endpoint, "match", return_value=None
    ) as mock_match:
        yield mock_match


@pytest.fixture
def mock_index_endpoint_read_index_datapoints(mock_index_endpoint):
    """Patch ``read_index_datapoints`` on the mocked index endpoint.

    Yields the patch object; the replacement returns ``None`` when called.
    """
    with patch.object(
        mock_index_endpoint, "read_index_datapoints", return_value=None
    ) as mock_read_index_datapoints:
        yield mock_read_index_datapoints


@pytest.fixture
def mock_index_create_tree_ah_index(mock_index):
with patch.object(
Expand Down
1 change: 1 addition & 0 deletions samples/model-builder/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,4 @@
VECTOR_SEARCH_INDEX_LABELS = {"my_key": "my_value"}
VECTOR_SEARCH_GCS_URI = "gs://fake-dir"
VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME = "my-vector-search-index-endpoint"
# Placeholder value passed as `signed_jwt` by the JWT sample tests; it is not
# a real credential.
VECTOR_SEARCH_PRIVATE_ENDPOINT_SIGNED_JWT = "fake-signed-jwt"
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,53 @@ def vector_search_find_neighbors(
print(hybrid_resp)

# [END aiplatform_sdk_vector_search_find_neighbors_sample]


# [START aiplatform_sdk_vector_search_find_neighbors_jwt_sample]
def vector_search_find_neighbors_jwt(
    project: str,
    location: str,
    index_endpoint_name: str,
    deployed_index_id: str,
    queries: List[List[float]],
    num_neighbors: int,
    signed_jwt: str,
) -> List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]]:
    """Run a nearest-neighbor query, passing a signed JWT to the endpoint.

    Args:
        project (str): Required. Project ID
        location (str): Required. The region name
        index_endpoint_name (str): Required. Index endpoint to run the query
            against.
        deployed_index_id (str): Required. The ID of the DeployedIndex to run
            the queries against.
        queries (List[List[float]]): Required. A list of queries. Each query
            is a list of floats, representing a single embedding.
        num_neighbors (int): Required. The number of neighbors to return.
        signed_jwt (str): Required. The signed JWT token for the private
            endpoint. The endpoint must be configured to accept tokens from
            JWT's issuer and encoded audience.

    Returns:
        List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]] - A list of nearest neighbors for each query.
    """
    # Initialize the Vertex AI client for the requested project and region.
    aiplatform.init(project=project, location=location)

    # Build a handle to the existing index endpoint.
    endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=index_endpoint_name
    )

    # Issue the query; the JWT is handed to find_neighbors unchanged.
    return endpoint.find_neighbors(
        deployed_index_id=deployed_index_id,
        queries=queries,
        num_neighbors=num_neighbors,
        signed_jwt=signed_jwt,
    )

# [END aiplatform_sdk_vector_search_find_neighbors_jwt_sample]

Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,34 @@ def test_vector_search_find_neighbors_sample(
],
any_order=False,
)


def test_vector_search_find_neighbors_jwt_sample(
    mock_sdk_init, mock_index_endpoint_init, mock_index_endpoint_find_neighbors
):
    """Verify the JWT find-neighbors sample forwards its arguments to the SDK."""
    # Keyword arguments the sample is expected to pass to find_neighbors as-is.
    query_kwargs = {
        "deployed_index_id": constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
        "queries": constants.VECTOR_SERACH_INDEX_QUERIES,
        "num_neighbors": 10,
        "signed_jwt": constants.VECTOR_SEARCH_PRIVATE_ENDPOINT_SIGNED_JWT,
    }

    vector_search_find_neighbors_sample.vector_search_find_neighbors_jwt(
        project=constants.PROJECT,
        location=constants.LOCATION,
        index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT,
        **query_kwargs,
    )

    # The SDK must be initialized with the sample's project and location.
    mock_sdk_init.assert_called_with(
        project=constants.PROJECT, location=constants.LOCATION
    )

    # The endpoint object must be built from the given endpoint name.
    mock_index_endpoint_init.assert_called_with(
        index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT
    )

    # find_neighbors must receive exactly the query arguments, JWT included.
    mock_index_endpoint_find_neighbors.assert_called_with(**query_kwargs)
67 changes: 67 additions & 0 deletions samples/model-builder/vector_search/vector_search_match_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from google.cloud import aiplatform


# [START aiplatform_sdk_vector_search_match_jwt_sample]
def vector_search_match_jwt(
    project: str,
    location: str,
    index_endpoint_name: str,
    deployed_index_id: str,
    queries: List[List[float]],
    num_neighbors: int,
    signed_jwt: str,
) -> List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]]:
    """Run a match query on an index endpoint, passing a signed JWT.

    Args:
        project (str): Required. Project ID
        location (str): Required. The region name
        index_endpoint_name (str): Required. Index endpoint to run the query
            against. The endpoint must be a private endpoint.
        deployed_index_id (str): Required. The ID of the DeployedIndex to run
            the queries against.
        queries (List[List[float]]): Required. A list of queries. Each query
            is a list of floats, representing a single embedding.
        num_neighbors (int): Required. The number of neighbors to return.
        signed_jwt (str): Required. The signed JWT token for the private
            endpoint. The endpoint must be configured to accept tokens from
            JWT's issuer and encoded audience.

    Returns:
        List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]] - A list of nearest neighbors for each query.
    """
    # Initialize the Vertex AI client for the requested project and region.
    aiplatform.init(project=project, location=location)

    # Build a handle to the existing index endpoint.
    endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=index_endpoint_name
    )

    # Issue the match request; the JWT is handed to match unchanged.
    return endpoint.match(
        deployed_index_id=deployed_index_id,
        queries=queries,
        num_neighbors=num_neighbors,
        signed_jwt=signed_jwt,
    )

# [END aiplatform_sdk_vector_search_match_jwt_sample]

Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import test_constants as constants
from vector_search import vector_search_match_sample


def test_vector_search_match_jwt_sample(
    mock_sdk_init, mock_index_endpoint_init, mock_index_endpoint_match
):
    """Verify the JWT match sample forwards its arguments to the SDK."""
    # Keyword arguments the sample is expected to pass to match as-is.
    query_kwargs = {
        "deployed_index_id": constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
        "queries": constants.VECTOR_SERACH_INDEX_QUERIES,
        "num_neighbors": 10,
        "signed_jwt": constants.VECTOR_SEARCH_PRIVATE_ENDPOINT_SIGNED_JWT,
    }

    vector_search_match_sample.vector_search_match_jwt(
        project=constants.PROJECT,
        location=constants.LOCATION,
        index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT,
        **query_kwargs,
    )

    # The SDK must be initialized with the sample's project and location.
    mock_sdk_init.assert_called_with(
        project=constants.PROJECT, location=constants.LOCATION
    )

    # The endpoint object must be built from the given endpoint name.
    mock_index_endpoint_init.assert_called_with(
        index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT
    )

    # match must receive exactly the query arguments, JWT included.
    mock_index_endpoint_match.assert_called_with(**query_kwargs)
Loading

0 comments on commit 5025d03

Please sign in to comment.