Skip to content

Commit

Permalink
Merge branch 'main' into owl-bot-copy
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha-gitg authored Dec 9, 2023
2 parents 920e459 + c9f7119 commit ece01ed
Show file tree
Hide file tree
Showing 23 changed files with 796 additions and 37 deletions.
3 changes: 2 additions & 1 deletion google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,10 @@ def project(self) -> str:
project_not_found_exception_str = (
"Unable to find your project. Please provide a project ID by:"
"\n- Passing a constructor argument"
"\n- Using aiplatform.init()"
"\n- Using vertexai.init()"
"\n- Setting project using 'gcloud config set project my-project'"
"\n- Setting a GCP environment variable"
"\n- To create a Google Cloud project, please follow guidance at https://developers.google.com/workspace/guides/create-project"
)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ message MatchRequest {
// The list of restricts.
repeated Namespace restricts = 4;

// The list of numeric restricts.
repeated NumericNamespace numeric_restricts = 11;

// Crowding is a constraint on a neighbor list produced by nearest neighbor
// search requiring that no more than some value k' of the k neighbors
// returned have the same value of crowding_attribute.
Expand Down Expand Up @@ -88,6 +91,9 @@ message Embedding {
// The list of restricts.
repeated Namespace restricts = 3;

// The list of numeric restricts.
repeated NumericNamespace numeric_restricts = 5;

// The attribute value used for crowding. The maximum number of neighbors
// to return per crowding attribute value
// (per_crowding_attribute_num_neighbors) is configured per-query.
Expand Down Expand Up @@ -175,6 +181,7 @@ message BatchMatchResponse {

// Namespace specifies the rules for determining the datapoints that are
// eligible for each matching query, overall query is an AND across namespaces.
// This uses categorical tokens.
message Namespace {
// The string name of the namespace that this proto is specifying,
// such as "color", "shape", "geo", or "tags".
Expand All @@ -192,4 +199,53 @@ message Namespace {
// query will match datapoints that are red or blue, but if those points are
// also purple, then they will be excluded even if they are red/blue.
repeated string deny_tokens = 3;
}
}

// NumericNamespace specifies the rules for determining the datapoints that are
// eligible for each matching query, overall query is an AND across namespaces.
// This uses numeric comparisons.
message NumericNamespace {

  // The string name of the namespace that this proto is specifying,
  // such as "size" or "cost".
  string name = 1;

  // The type of Value must be consistent for all datapoints with a given
  // namespace name. This is verified at runtime.
  //
  // NOTE(review): the oneof name "Value" is not lower_snake_case as proto
  // style recommends, but renaming it now would break generated code
  // (e.g. WhichOneof("Value")), so it is left unchanged.
  oneof Value {
    // Represents 64 bit integer.
    int64 value_int = 2;
    // Represents 32 bit float.
    float value_float = 3;
    // Represents 64 bit float.
    double value_double = 4;
  }

  // Which comparison operator to use. Should be specified for queries only;
  // specifying this for a datapoint is an error.
  //
  // Datapoints for which Operator is true relative to the query's Value
  // field will be allowlisted.
  //
  // NOTE(review): values are not prefixed with the enum name (e.g.
  // OPERATOR_LESS) as style recommends; renaming them would break generated
  // code, so they are kept as-is.
  enum Operator {
    // Default value of the enum.
    OPERATOR_UNSPECIFIED = 0;

    // Datapoints are eligible iff their value is < the query's.
    LESS = 1;

    // Datapoints are eligible iff their value is <= the query's.
    LESS_EQUAL = 2;

    // Datapoints are eligible iff their value is == the query's.
    EQUAL = 3;

    // Datapoints are eligible iff their value is >= the query's.
    GREATER_EQUAL = 4;

    // Datapoints are eligible iff their value is > the query's.
    GREATER = 5;
  }

  // Which comparison operator to use when evaluating this namespace against
  // a query's Value. Queries must set this; datapoints must not (see the
  // Operator comment above).
  Operator op = 5;
}
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def __init__(
)
self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name)

self._public_match_client = None
if self.public_endpoint_domain_name:
self._public_match_client = self._instantiate_public_match_client()

Expand Down Expand Up @@ -518,6 +519,36 @@ def _instantiate_public_match_client(
api_path_override=self.public_endpoint_domain_name,
)

def _instantiate_private_match_service_stub(
    self,
    deployed_index_id: str,
    port: int = 10000,
) -> match_service_pb2_grpc.MatchServiceStub:
    """Helper method to instantiate a private match service stub.

    Args:
        deployed_index_id (str):
            Required. The user specified ID of the DeployedIndex.
        port (int):
            Optional. The gRPC port of the private match service.
            Defaults to 10000, the port the service listens on.

    Returns:
        stub (match_service_pb2_grpc.MatchServiceStub):
            Initialized match service stub.

    Raises:
        RuntimeError: If no deployed index with the given ID is found on
            this endpoint.
    """
    # Find the deployed index by id; stop at the first match instead of
    # materializing the whole list.
    deployed_index = next(
        (
            index
            for index in self.deployed_indexes
            if index.id == deployed_index_id
        ),
        None,
    )
    if deployed_index is None:
        raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")

    # Retrieve server ip from the deployed index.
    server_ip = deployed_index.private_endpoints.match_grpc_address

    # NOTE(review): the channel is unauthenticated and unencrypted; this is
    # only appropriate because the private endpoint is reachable solely
    # inside the peered VPC.
    channel = grpc.insecure_channel(f"{server_ip}:{port}")
    return match_service_pb2_grpc.MatchServiceStub(channel)

@property
def public_endpoint_domain_name(self) -> Optional[str]:
"""Public endpoint DNS name."""
Expand Down Expand Up @@ -1233,7 +1264,8 @@ def read_index_datapoints(
deployed_index_id: str,
ids: List[str] = [],
) -> List[gca_index_v1beta1.IndexDatapoint]:
"""Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint.
"""Reads the datapoints/vectors of the given IDs on the specified
deployed index which is deployed to public or private endpoint.
```
Example Usage:
Expand All @@ -1252,9 +1284,25 @@ def read_index_datapoints(
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
"""
if not self._public_match_client:
raise ValueError(
"Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
# Call private match service stub with BatchGetEmbeddings request
response = self._batch_get_embeddings(
deployed_index_id=deployed_index_id, ids=ids
)
return [
gca_index_v1beta1.IndexDatapoint(
datapoint_id=embedding.id,
feature_vector=embedding.float_val,
restricts=gca_index_v1beta1.IndexDatapoint.Restriction(
namespace=embedding.restricts.name,
allow_list=embedding.restricts.allow_tokens,
),
deny_list=embedding.restricts.deny_tokens,
crowding_attributes=gca_index_v1beta1.CrowdingEmbedding(
str(embedding.crowding_tag)
),
)
for embedding in response.embeddings
]

# Create the ReadIndexDatapoints request
read_index_datapoints_request = (
Expand All @@ -1273,6 +1321,38 @@ def read_index_datapoints(
# Wrap the results and return
return response.datapoints

def _batch_get_embeddings(
    self,
    *,
    deployed_index_id: str,
    ids: List[str] = [],
) -> List[match_service_pb2.Embedding]:
    """Reads the datapoints/vectors of the given IDs on the specified index
    which is deployed to a private endpoint.

    Args:
        deployed_index_id (str):
            Required. The ID of the DeployedIndex to match the queries
            against.
        ids (List[str]):
            Required. IDs of the datapoints to be searched for. Never
            mutated, so the mutable default is harmless; callers should
            still pass their own list.

    Returns:
        List[match_service_pb2.Embedding] - A list of datapoints/vectors
        of the given IDs. (The previous annotation of a nested list was
        incorrect: the stub returns a flat repeated field.)
    """
    stub = self._instantiate_private_match_service_stub(
        deployed_index_id=deployed_index_id
    )

    # Create the batch get embeddings request.
    batch_request = match_service_pb2.BatchGetEmbeddingsRequest()
    batch_request.deployed_index_id = deployed_index_id
    # extend() instead of an append loop; this also avoids shadowing the
    # builtin `id`.
    batch_request.id.extend(ids)

    response = stub.BatchGetEmbeddings(batch_request)
    return response.embeddings

def match(
self,
deployed_index_id: str,
Expand Down Expand Up @@ -1310,23 +1390,9 @@ def match(
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""

# Find the deployed index by id
deployed_indexes = [
deployed_index
for deployed_index in self.deployed_indexes
if deployed_index.id == deployed_index_id
]

if not deployed_indexes:
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")

# Retrieve server ip from deployed index
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address

# Set up channel and stub
channel = grpc.insecure_channel("{}:10000".format(server_ip))
stub = match_service_pb2_grpc.MatchServiceStub(channel)
stub = self._instantiate_private_match_service_stub(
deployed_index_id=deployed_index_id
)

# Create the batch match request
batch_request = match_service_pb2.BatchMatchRequest()
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/aiplatform/preview/vertex_ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ def __init__(self, address: Optional[str]) -> None:
" failed to start Head node properly because custom service account isn't supported.",
)
logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
cluster = _gapic_utils.persistent_resource_to_cluster(
persistent_resource=self.response
)
if cluster is None:
raise ValueError(
"[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
)
local_ray_verion = _validation_utils.get_local_ray_version()
if cluster.ray_version != local_ray_verion:
raise ValueError(
f"[Ray on Vertex AI]: Local runtime has Ray version {local_ray_verion}, but the cluster runtime has {cluster.ray_version}. Please ensure that the Ray versions match."
)
super().__init__(address)

def connect(self) -> _VertexRayClientContext:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import google.auth
import google.auth.transport.requests
import logging
import ray
import re

from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -68,6 +69,13 @@ def maybe_reconstruct_resource_name(address) -> str:
return address


def get_local_ray_version(version=None):
    """Returns the Ray version formatted as "MAJOR_MINOR", e.g. "2_4".

    Args:
        version: Optional. A dotted version string to convert. Defaults to
            the locally installed ``ray.__version__``.

    Returns:
        str: The first two version components joined by "_". A bare
        single-component version string is returned unchanged.
    """
    if version is None:
        version = ray.__version__
    # Keep only MAJOR.MINOR. The previous `len == 3` check mishandled
    # versions with more than three components (e.g. "2.9.3.post1" became
    # "2_9_3_post1" instead of "2_9"); slicing handles every length.
    return "_".join(version.split(".")[:2])


def get_image_uri(ray_version, python_version, enable_cuda):
"""Image uri for a given ray version and python version."""
if ray_version not in ["2_4"]:
Expand Down
49 changes: 49 additions & 0 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,3 +1204,52 @@ def mock_autolog():
with patch.object(aiplatform, "autolog") as mock_autolog_method:
mock_autolog_method.return_value = None
yield mock_autolog_method


"""
----------------------------------------------------------------------------
Vector Search Fixtures
----------------------------------------------------------------------------
"""


@pytest.fixture
def mock_index():
    """Yields a MagicMock standing in for aiplatform.MatchingEngineIndex."""
    yield MagicMock(aiplatform.MatchingEngineIndex)


@pytest.fixture
def mock_index_endpoint():
    """Yields a MagicMock standing in for aiplatform.MatchingEngineIndexEndpoint."""
    yield MagicMock(aiplatform.MatchingEngineIndexEndpoint)


@pytest.fixture
def mock_index_init(mock_index):
    """Patches the MatchingEngineIndex constructor to return the mock index."""
    with patch.object(aiplatform, "MatchingEngineIndex") as init_mock:
        init_mock.return_value = mock_index
        yield init_mock


@pytest.fixture
def mock_index_upsert_datapoints(mock_index):
    """Patches mock_index.upsert_datapoints to record calls and return None."""
    with patch.object(mock_index, "upsert_datapoints") as upsert_mock:
        upsert_mock.return_value = None
        yield upsert_mock


@pytest.fixture
def mock_index_endpoint_init(mock_index_endpoint):
    """Patches the MatchingEngineIndexEndpoint constructor to return the mock endpoint."""
    with patch.object(aiplatform, "MatchingEngineIndexEndpoint") as init_mock:
        init_mock.return_value = mock_index_endpoint
        yield init_mock


@pytest.fixture
def mock_index_endpoint_find_neighbors(mock_index_endpoint):
    """Patches mock_index_endpoint.find_neighbors to record calls and return None."""
    with patch.object(mock_index_endpoint, "find_neighbors") as find_mock:
        find_mock.return_value = None
        yield find_mock
9 changes: 9 additions & 0 deletions samples/model-builder/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,12 @@
)
TENSORBOARD_EXPERIMENT_NAME = "my-tensorboard-experiment"
TENSORBOARD_PLUGIN_PROFILE_NAME = "profile"

# Vector Search
VECTOR_SEARCH_INDEX = "123"
VECTOR_SERACH_INDEX_DATAPOINTS = [
    {"datapoint_id": "datapoint_id_1", "feature_vector": [0.1]}
]
# Correctly spelled alias; the misspelled name above ("SERACH") is kept for
# backward compatibility with existing tests. Prefer the alias in new code.
VECTOR_SEARCH_INDEX_DATAPOINTS = VECTOR_SERACH_INDEX_DATAPOINTS
VECTOR_SEARCH_INDEX_ENDPOINT = "456"
VECTOR_SEARCH_DEPLOYED_INDEX_ID = "789"
VECTOR_SERACH_INDEX_QUERIES = [[0.1]]
# Correctly spelled alias; the misspelled name above ("SERACH") is kept for
# backward compatibility with existing tests. Prefer the alias in new code.
VECTOR_SEARCH_INDEX_QUERIES = VECTOR_SERACH_INDEX_QUERIES
Loading

0 comments on commit ece01ed

Please sign in to comment.