diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index b618c63677..385eda9d56 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -18,7 +18,7 @@ from concurrent import futures import logging -import pkg_resources # Note this is used after copybara replacement +import pkg_resources # noqa: F401 # Note this is used after copybara replacement import os from typing import List, Optional, Type, TypeVar, Union @@ -395,6 +395,7 @@ def create_client( api_base_path_override: Optional[str] = None, api_path_override: Optional[str] = None, appended_user_agent: Optional[List[str]] = None, + appended_gapic_version: Optional[str] = None, ) -> _TVertexAiServiceClientWithOverride: """Instantiates a given VertexAiServiceClient with optional overrides. @@ -411,6 +412,8 @@ def create_client( appended_user_agent (List[str]): Optional. User agent appended in the client info. If more than one, it will be separated by spaces. + appended_gapic_version (str): + Optional. GAPIC version suffix appended in the client info. Returns: client: Instantiated Vertex AI Service client with optional overrides """ @@ -422,6 +425,9 @@ def create_client( if appended_user_agent: user_agent = f"{user_agent} {' '.join(appended_user_agent)}" + if appended_gapic_version: + gapic_version = f"{gapic_version}+{appended_gapic_version}" + client_info = gapic_v1.client_info.ClientInfo( gapic_version=gapic_version, user_agent=user_agent, diff --git a/google/cloud/aiplatform/preview/vertex_ray/__init__.py b/google/cloud/aiplatform/preview/vertex_ray/__init__.py new file mode 100644 index 0000000000..adc3226a46 --- /dev/null +++ b/google/cloud/aiplatform/preview/vertex_ray/__init__.py @@ -0,0 +1,58 @@ +"""Ray on Vertex AI.""" + +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
import sys

from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource import (
    BigQueryDatasource,
)
from google.cloud.aiplatform.preview.vertex_ray.client_builder import (
    VertexRayClientBuilder as ClientBuilder,
)

from google.cloud.aiplatform.preview.vertex_ray.cluster_init import (
    create_ray_cluster,
    delete_ray_cluster,
    get_ray_cluster,
    list_ray_clusters,
    update_ray_cluster,
)
from google.cloud.aiplatform.preview.vertex_ray.util.resources import (
    Resources,
)

from google.cloud.aiplatform.preview.vertex_ray.dashboard_sdk import (
    get_job_submission_client_cluster_info,
)

# Warn (do not fail) when the client interpreter is not Python 3.10.
# Both the major and the minor version are compared: checking only the minor
# component would wrongly accept any "X.10" interpreter.
if sys.version_info[:2] != (3, 10):
    print(
        "[Ray on Vertex]: The client environment with Python version 3.10 is required."
    )

# Public API of the vertex_ray package.
__all__ = (
    "BigQueryDatasource",
    "ClientBuilder",
    "get_job_submission_client_cluster_info",
    "create_ray_cluster",
    "delete_ray_cluster",
    "get_ray_cluster",
    "list_ray_clusters",
    "update_ray_cluster",
    "Resources",
)
import logging
import os
import tempfile
import time
from typing import Any, Dict, List, Optional
import uuid

from google.api_core import client_info
from google.api_core import exceptions
from google.api_core.gapic_v1 import client_info as v1_client_info
from google.cloud import bigquery
from google.cloud import bigquery_storage
from google.cloud.aiplatform import initializer
from google.cloud.bigquery_storage import types
import pyarrow.parquet as pq
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import Block
from ray.data.block import BlockAccessor
from ray.data.block import BlockMetadata
from ray.data.datasource.datasource import Datasource
from ray.data.datasource.datasource import Reader
from ray.data.datasource.datasource import ReadTask
from ray.data.datasource.datasource import WriteResult
from ray.types import ObjectRef


# Client-info objects that tag BigQuery / BigQuery Storage requests as coming
# from Ray on Vertex.
_BQ_GAPIC_VERSION = bigquery.__version__ + "+vertex_ray"
_BQS_GAPIC_VERSION = bigquery_storage.__version__ + "+vertex_ray"
bq_info = client_info.ClientInfo(
    gapic_version=_BQ_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}"
)
bqstorage_info = v1_client_info.ClientInfo(
    gapic_version=_BQS_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQS_GAPIC_VERSION}"
)

# Load-job retry policy used when BigQuery answers with 403 (Forbidden).
_MAX_RETRY_CNT = 10
_RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11  # seconds


class _BigQueryDatasourceReader(Reader):
    """Reader that turns a BigQuery table or query result into Ray read tasks."""

    def __init__(
        self,
        project_id: Optional[str] = None,
        dataset: Optional[str] = None,
        query: Optional[str] = None,
        parallelism: Optional[int] = -1,
        **kwargs: Optional[Dict[str, Any]],
    ):
        """Stores the read configuration.

        Args:
            project_id: Project to read from. Falls back to the project set in
                `initializer.global_config` when empty.
            dataset: Table to read, in the form "dataset_id.table_id".
            query: SQL query whose result should be read.
            parallelism: Accepted but not used by the constructor; the value
                passed to `get_read_tasks` is what takes effect.
            **kwargs: Extra arguments, forwarded unchanged to the per-stream
                read function.

        Raises:
            ValueError: If both `query` and `dataset` are provided.
        """
        self._project_id = project_id or initializer.global_config.project
        self._dataset = dataset
        self._query = query
        self._kwargs = kwargs

        if query is not None and dataset is not None:
            raise ValueError(
                "[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)."
            )

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Creates one ReadTask per BigQuery Storage read stream.

        Args:
            parallelism: Requested number of streams; -1 means "no preference".

        Returns:
            A list of ReadTask objects, one for each stream of the read session.

        Raises:
            ValueError: If neither a query nor a dataset was configured, or if
                the configured dataset / table does not exist.
        """

        # Executed by a worker node
        def _read_single_partition(stream, kwargs) -> Block:
            client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
            reader = client.read_rows(stream.name)
            return reader.to_arrow()

        if self._query:
            # Run the query first, then read from its destination table.
            query_client = bigquery.Client(
                project=self._project_id, client_info=bq_info
            )
            query_job = query_client.query(self._query)
            query_job.result()
            destination = str(query_job.destination)
            dataset_id = destination.split(".")[-2]
            table_id = destination.split(".")[-1]
        else:
            if not self._dataset:
                # Fail with a clear message instead of an AttributeError on
                # `None.split(...)` further down.
                raise ValueError(
                    "[Ray on Vertex AI]: Either query or dataset must be provided."
                )
            self._validate_dataset_table_exist(self._project_id, self._dataset)
            dataset_id = self._dataset.split(".")[0]
            table_id = self._dataset.split(".")[1]

        bqs_client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
        table = f"projects/{self._project_id}/datasets/{dataset_id}/tables/{table_id}"

        # -1 means the caller has no preference: send no stream-count hint.
        if parallelism == -1:
            parallelism = None
        requested_session = types.ReadSession(
            table=table,
            data_format=types.DataFormat.ARROW,
        )
        read_session = bqs_client.create_read_session(
            parent=f"projects/{self._project_id}",
            read_session=requested_session,
            max_stream_count=parallelism,
        )

        read_tasks = []
        print("[Ray on Vertex AI]: Created streams:", len(read_session.streams))
        # Only compare when a concrete parallelism was requested; comparing an
        # int against None raises TypeError.
        if parallelism is not None and len(read_session.streams) < parallelism:
            print(
                "[Ray on Vertex AI]: The number of streams created by the "
                + "BigQuery Storage Read API is less than the requested "
                + "parallelism due to the size of the dataset."
            )

        for stream in read_session.streams:
            # Create a metadata block object to store schema, etc.
            metadata = BlockMetadata(
                num_rows=None,
                size_bytes=None,
                schema=None,
                input_files=None,
                exec_stats=None,
            )

            # Create a no-arg wrapper read function which returns a block.
            # `stream` and `kwargs` are bound as defaults so each task keeps
            # its own stream instead of the loop's last value.
            read_single_partition = (
                lambda stream=stream, kwargs=self._kwargs: [  # noqa: F731
                    _read_single_partition(stream, kwargs)
                ]
            )

            # Create the read task and pass the wrapper and metadata in
            read_task = ReadTask(read_single_partition, metadata)
            read_tasks.append(read_task)

        return read_tasks

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Returns None: no size estimate is implemented yet."""
        # TODO(b/281891467): Implement this method
        return None

    def _validate_dataset_table_exist(self, project_id: str, dataset: str) -> None:
        """Checks that both the dataset and the table exist.

        Args:
            project_id: Project used to build the BigQuery client.
            dataset: Table reference in the form "dataset_id.table_id".

        Raises:
            ValueError: If the dataset or the table is not found.
        """
        client = bigquery.Client(project=project_id, client_info=bq_info)
        dataset_id = dataset.split(".")[0]
        try:
            client.get_dataset(dataset_id)
        except exceptions.NotFound as e:
            raise ValueError(
                "[Ray on Vertex AI]: Dataset {} is not found. Please ensure that it exists.".format(
                    dataset_id
                )
            ) from e

        try:
            client.get_table(dataset)
        except exceptions.NotFound as e:
            raise ValueError(
                "[Ray on Vertex AI]: Table {} is not found. Please ensure that it exists.".format(
                    dataset
                )
            ) from e


class BigQueryDatasource(Datasource):
    """Ray Datasource that reads from and writes to BigQuery."""

    def create_reader(self, **kwargs) -> Reader:
        """Builds a reader; all keyword arguments go to the reader constructor."""
        return _BigQueryDatasourceReader(**kwargs)

    def do_write(
        self,
        blocks: List[ObjectRef[Block]],
        metadata: List[BlockMetadata],
        ray_remote_args: Optional[Dict[str, Any]],
        project_id: Optional[str] = None,
        dataset: Optional[str] = None,
    ) -> List[ObjectRef[WriteResult]]:
        """Launches one remote write task per block.

        The target table is deleted first if it exists, then every block is
        appended to it through a Parquet load job.

        Args:
            blocks: Object references of the blocks to write.
            metadata: Metadata of each block, in the same order as `blocks`.
            ray_remote_args: Options applied to the remote write function.
            project_id: Target project. Falls back to the project set in
                `initializer.global_config` when empty.
            dataset: Target table in the form "dataset_id.table_id".

        Returns:
            Object references of the launched write tasks.

        Raises:
            ValueError: If `dataset` is None.
        """

        def _write_single_block(
            block: Block, metadata: BlockMetadata, project_id: str, dataset: str
        ):
            print("[Ray on Vertex AI]: Starting to write", metadata.num_rows, "rows")
            block = BlockAccessor.for_block(block).to_arrow()

            client = bigquery.Client(project=project_id, client_info=bq_info)
            job_config = bigquery.LoadJobConfig(autodetect=True)
            job_config.source_format = bigquery.SourceFormat.PARQUET
            job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND

            with tempfile.TemporaryDirectory() as temp_dir:
                fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
                pq.write_table(block, fp, compression="SNAPPY")

                for _ in range(_MAX_RETRY_CNT):
                    with open(fp, "rb") as source_file:
                        job = client.load_table_from_file(
                            source_file, dataset, job_config=job_config
                        )
                    try:
                        logging.info(job.result())
                        break
                    except exceptions.Forbidden as e:
                        print(
                            "[Ray on Vertex AI]: Rate limit exceeded... Sleeping to try again"
                        )
                        logging.debug(e)
                        time.sleep(_RATE_LIMIT_EXCEEDED_SLEEP_TIME)
                else:
                    # Every attempt was rejected: surface the failure instead
                    # of reporting a successful write for data never loaded.
                    raise RuntimeError(
                        f"[Ray on Vertex AI]: Write failed due to {_MAX_RETRY_CNT}"
                        " repeated API rate limit exceeded responses."
                    )
            print("[Ray on Vertex AI]: Finished writing", metadata.num_rows, "rows")

        project_id = project_id or initializer.global_config.project

        if dataset is None:
            raise ValueError(
                "[Ray on Vertex AI]: Dataset is required when writing to BigQuery."
            )

        if ray_remote_args is None:
            ray_remote_args = {}

        _write_single_block = cached_remote_fn(_write_single_block).options(
            **ray_remote_args
        )

        # Set up datasets to write
        client = bigquery.Client(project=project_id, client_info=bq_info)
        dataset_id = dataset.split(".", 1)[0]
        try:
            client.create_dataset(f"{project_id}.{dataset_id}", timeout=30)
            print("[Ray on Vertex AI]: Created dataset", dataset_id)
        except exceptions.Conflict:
            print(
                "[Ray on Vertex AI]: Dataset",
                dataset_id,
                "already exists. The table will be overwritten if it already exists.",
            )

        # Delete table if it already exists
        client.delete_table(f"{project_id}.{dataset}", not_found_ok=True)

        print("[Ray on Vertex AI]: Writing", len(blocks), "blocks")
        return [
            _write_single_block.remote(block, block_metadata, project_id, dataset)
            for block, block_metadata in zip(blocks, metadata)
        ]
from typing import Dict
from typing import Optional
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from ray import client_builder
from .render import VertexRayTemplate
from .util import _validation_utils
from .util import _gapic_utils


VERTEX_SDK_VERSION = aiplatform.__version__
_LOGGER = base.Logger(__name__)


class _VertexRayClientContext(client_builder.ClientContext):
    """Custom ClientContext.

    Copies the fields of an existing Ray client context and adds the
    persistent resource id, the Vertex SDK version and the interactive shell
    URI of the head node.
    """

    def __init__(
        self,
        persistent_resource_id: str,
        ray_head_uris: Dict[str, str],
        ray_client_context: client_builder.ClientContext,
    ) -> None:
        """Builds the context from the cluster's access URIs.

        Args:
            persistent_resource_id: Id of the persistent resource (cluster).
            ray_head_uris: Access URIs of the head node; must contain
                "RAY_DASHBOARD_URI".
            ray_client_context: Context returned by the underlying Ray
                client connection.

        Raises:
            ValueError: If the dashboard URI is missing.
        """
        dashboard_uri = ray_head_uris.get("RAY_DASHBOARD_URI")
        if dashboard_uri is None:
            # A single formatted message: passing several positional args to
            # ValueError would render the error as a tuple.
            raise ValueError(
                f"Ray Cluster {persistent_resource_id}"
                " failed to start Head node properly."
            )

        super().__init__(
            dashboard_url=dashboard_uri,
            python_version=ray_client_context.python_version,
            ray_version=ray_client_context.ray_version,
            ray_commit=ray_client_context.ray_commit,
            protocol_version=ray_client_context.protocol_version,
            _num_clients=ray_client_context._num_clients,
            _context_to_restore=ray_client_context._context_to_restore,
        )
        self.persistent_resource_id = persistent_resource_id
        self.vertex_sdk_version = str(VERTEX_SDK_VERSION)
        self.shell_uri = ray_head_uris.get("RAY_HEAD_NODE_INTERACTIVE_SHELL_URI")

    def _repr_html_(self):
        """Renders the context as HTML using the package templates."""
        shell_uri_row = None
        if self.shell_uri is not None:
            shell_uri_row = VertexRayTemplate("context_shellurirow.html.j2").render(
                shell_uri=self.shell_uri
            )

        return VertexRayTemplate("context.html.j2").render(
            python_version=self.python_version,
            ray_version=self.ray_version,
            vertex_sdk_version=self.vertex_sdk_version,
            dashboard_url=self.dashboard_url,
            persistent_resource_id=self.persistent_resource_id,
            shell_uri_row=shell_uri_row,
        )


class VertexRayClientBuilder(client_builder.ClientBuilder):
    """Class to initialize a Ray client with vertex on ray capabilities."""

    def __init__(self, address: Optional[str]) -> None:
        """Resolves the head node address of a cluster.

        Args:
            address: Cluster address; it is passed through
                `maybe_reconstruct_resource_name` and then validated as a
                persistent resource name.

        Raises:
            ValueError: If the head node internal IP is not available, or if
                the cluster was created with a custom service account.
        """
        address = _validation_utils.maybe_reconstruct_resource_name(address)
        _validation_utils.valid_resource_name(address)

        self.vertex_address = address
        _LOGGER.info(
            "[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
        )

        self.resource_name = address

        self.response = _gapic_utils.get_persistent_resource(self.resource_name)
        address = self.response.resource_runtime.access_uris.get(
            "RAY_HEAD_NODE_INTERNAL_IP"
        )
        if address is None:
            persistent_resource_id = self.resource_name.split("/")[5]
            raise ValueError(
                f"[Ray on Vertex AI]: Ray Cluster {persistent_resource_id}"
                " failed to start Head node properly."
            )
        # Handling service_account
        service_account = (
            self.response.resource_runtime_spec.service_account_spec.service_account
        )

        if service_account:
            raise ValueError(
                f"[Ray on Vertex AI]: Cluster {address}"
                " failed to start Head node properly because custom service account isn't supported."
            )
        _LOGGER.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
        super().__init__(address)

    def connect(self) -> _VertexRayClientContext:
        """Connects to the cluster and wraps the resulting context.

        Returns:
            A _VertexRayClientContext built from the Ray client context and
            the access URIs of the cluster.
        """
        # Can send any other params to ray cluster here
        _LOGGER.info("[Ray on Vertex AI]: Connecting...")
        ray_client_context = super().connect()
        ray_head_uris = self.response.resource_runtime.access_uris

        # Valid resource name (reference public doc for public release):
        # "projects//locations//persistentResources/"
        persistent_resource_id = self.resource_name.split("/")[5]

        return _VertexRayClientContext(
            persistent_resource_id=persistent_resource_id,
            ray_head_uris=ray_head_uris,
            ray_client_context=ray_client_context,
        )
import copy
from typing import List, Optional

from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import resource_manager_utils
from google.cloud.aiplatform_v1beta1.types import persistent_resource_service

from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
    PersistentResource,
    RaySpec,
    ResourcePool,
    ResourceRuntimeSpec,
)

from google.cloud.aiplatform.preview.vertex_ray.util import (
    _gapic_utils,
    _validation_utils,
    resources,
)

from google.protobuf import field_mask_pb2  # type: ignore


def _make_resource_pool(pool_id: str, node_type: resources.Resources) -> ResourcePool:
    """Builds a ResourcePool named `pool_id` from the fields of `node_type`."""
    pool = ResourcePool()
    pool.id = pool_id
    pool.replica_count = node_type.node_count
    pool.machine_spec.machine_type = node_type.machine_type
    pool.machine_spec.accelerator_count = node_type.accelerator_count
    pool.machine_spec.accelerator_type = node_type.accelerator_type
    pool.disk_spec.boot_disk_type = node_type.boot_disk_type
    pool.disk_spec.boot_disk_size_gb = node_type.boot_disk_size_gb
    return pool


def create_ray_cluster(
    head_node_type: Optional[resources.Resources] = resources.Resources(),
    python_version: Optional[str] = "3_10",
    ray_version: Optional[str] = "2_4",
    network: Optional[str] = None,
    cluster_name: Optional[str] = None,
    worker_node_types: Optional[List[resources.Resources]] = None,
) -> str:
    """Create a ray cluster on the Vertex AI.

    Sample usage:

    from vertex_ray import Resources

    head_node_type = Resources(
        machine_type="n1-standard-4",
        node_count=1,
        accelerator_type="NVIDIA_TESLA_K80",
        accelerator_count=1,
    )

    worker_node_types = [Resources(
        machine_type="n1-standard-4",
        node_count=2,
        accelerator_type="NVIDIA_TESLA_K80",
        accelerator_count=1,
    )]

    cluster_resource_name = vertex_ray.create_ray_cluster(
        head_node_type=head_node_type,
        network="my-vpc",
        worker_node_types=worker_node_types,
    )

    After a ray cluster is set up, you can call
    `ray.init(vertex_ray://{cluster_resource_name}, runtime_env=...)` without
    specifying ray cluster address to connect to the cluster. To shut down the
    cluster you can call `vertex_ray.delete_ray_cluster()`.
    Note: If the active ray cluster hasn't shut down, you cannot create a new ray
    cluster with the same cluster_name.

    Args:
        head_node_type: The head node resource. Resources.node_count must be 1.
            If not set, default value of Resources() class will be used.
        python_version: Python version for the ray cluster.
        ray_version: Ray version for the ray cluster.
        network: Virtual private cloud (VPC) network. For Ray Client, VPC
            peering is required to connect to the Ray Cluster managed in the
            Vertex API service. For Ray Job API, VPC network is not required
            because Ray Cluster connection can be accessed through dashboard
            address.
        cluster_name: This value may be up to 63 characters, and valid
            characters are `[a-z0-9_-]`. The first character cannot be a number
            or hyphen.
        worker_node_types: The list of Resources of the worker nodes. The same
            Resources object should not appear multiple times in the list.

    Returns:
        The cluster_resource_name of the initiated Ray cluster on Vertex.

    Raises:
        ValueError: If `network` is missing, if the head node count is not 1,
            or if the create request fails.
    """

    if network is None:
        raise ValueError(
            "[Ray on Vertex]: VPC network is required for client connection."
        )

    if cluster_name is None:
        cluster_name = "ray-cluster-" + utils.timestamped_unique_name()

    if head_node_type is None:
        # An explicit None means "not set": use the documented default rather
        # than failing on attribute access below.
        head_node_type = resources.Resources()

    if head_node_type.node_count != 1:
        raise ValueError(
            "[Ray on Vertex AI]: For head_node_type, "
            + "Resources.node_count must be 1."
        )

    resource_pool_images = {}

    # head node
    resource_pool_0 = _make_resource_pool("head-node", head_node_type)
    resource_pool_images[resource_pool_0.id] = _validation_utils.get_image_uri(
        ray_version, python_version, head_node_type.accelerator_count > 0
    )

    worker_pools = []
    for i, worker_node_type in enumerate(worker_node_types or []):
        # Worker and head share the same MachineSpec, merge them into the
        # same ResourcePool. A result of 0 is treated as "not identical".
        additional_replica_count = resources._check_machine_spec_identical(
            head_node_type, worker_node_type
        )
        resource_pool_0.replica_count = (
            resource_pool_0.replica_count + additional_replica_count
        )
        if additional_replica_count == 0:
            # The pool id is derived from the position in worker_node_types,
            # so ids can skip a number when a worker was merged into the head.
            resource_pool = _make_resource_pool(f"worker-pool{i+1}", worker_node_type)
            worker_pools.append(resource_pool)
            resource_pool_images[resource_pool.id] = _validation_utils.get_image_uri(
                ray_version, python_version, worker_node_type.accelerator_count > 0
            )

    resource_pools = [resource_pool_0] + worker_pools

    ray_spec = RaySpec(resource_pool_images=resource_pool_images)
    resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
    persistent_resource = PersistentResource(
        resource_pools=resource_pools,
        network=network,
        resource_runtime_spec=resource_runtime_spec,
    )

    location = initializer.global_config.location
    project_id = initializer.global_config.project
    project_number = resource_manager_utils.get_project_number(project_id)

    parent = f"projects/{project_number}/locations/{location}"
    request = persistent_resource_service.CreatePersistentResourceRequest(
        parent=parent,
        persistent_resource=persistent_resource,
        persistent_resource_id=cluster_name,
    )

    client = _gapic_utils.create_persistent_resource_client()
    try:
        client.create_persistent_resource(request)
    except Exception as e:
        raise ValueError(f"Failed in cluster creation due to: {e}") from e

    # Get persistent resource
    cluster_resource_name = f"{parent}/persistentResources/{cluster_name}"
    response = _gapic_utils.get_persistent_resource(
        persistent_resource_name=cluster_resource_name,
        tolerance=1,  # allow 1 retry to avoid get request before creation
    )
    return response.name


def delete_ray_cluster(cluster_resource_name: str) -> None:
    """Delete Ray Cluster.

    Args:
        cluster_resource_name: Cluster resource name.
    Raises:
        ValueError: If the delete request fails for any reason (for example
            when the cluster is deleted already).
    """
    client = _gapic_utils.create_persistent_resource_client()
    request = persistent_resource_service.DeletePersistentResourceRequest(
        name=cluster_resource_name
    )

    try:
        client.delete_persistent_resource(request)
        print("[Ray on Vertex AI]: Successfully deleted the cluster.")
    except Exception as e:
        raise ValueError(
            f"[Ray on Vertex AI]: Failed in cluster deletion due to: {e}"
        ) from e


def get_ray_cluster(cluster_resource_name: str) -> resources.Cluster:
    """Get Ray Cluster.

    Args:
        cluster_resource_name: Cluster resource name.
    Returns:
        A Cluster object.
    Raises:
        ValueError: If the get request fails, or if the persistent resource
            cannot be converted into a Ray cluster.
    """
    client = _gapic_utils.create_persistent_resource_client()
    request = persistent_resource_service.GetPersistentResourceRequest(
        name=cluster_resource_name
    )
    try:
        response = client.get_persistent_resource(request)
    except Exception as e:
        raise ValueError(
            f"[Ray on Vertex AI]: Failed in getting the cluster due to: {e}"
        ) from e

    cluster = _gapic_utils.persistent_resource_to_cluster(persistent_resource=response)
    if cluster:
        return cluster
    raise ValueError("[Ray on Vertex AI]: The cluster is not a Ray cluster.")


def list_ray_clusters() -> List[resources.Cluster]:
    """List Ray Clusters under the currently authenticated project.

    Returns:
        List of Cluster objects that exist in the current authorized project.
    Raises:
        ValueError: If the list request fails.
    """
    location = initializer.global_config.location
    project_id = initializer.global_config.project
    project_number = resource_manager_utils.get_project_number(project_id)
    parent = f"projects/{project_number}/locations/{location}"
    request = persistent_resource_service.ListPersistentResourcesRequest(
        parent=parent,
    )
    client = _gapic_utils.create_persistent_resource_client()
    try:
        response = client.list_persistent_resources(request)
    except Exception as e:
        raise ValueError(
            f"[Ray on Vertex AI]: Failed in listing the clusters due to: {e}"
        ) from e

    ray_clusters = []
    for persistent_resource in response:
        ray_cluster = _gapic_utils.persistent_resource_to_cluster(
            persistent_resource=persistent_resource
        )
        # Persistent resources with a falsy conversion result are left out.
        if ray_cluster:
            ray_clusters.append(ray_cluster)

    return ray_clusters


def update_ray_cluster(
    cluster_resource_name: str, worker_node_types: List[resources.Resources]
) -> str:
    """Update Ray Cluster (currently support resizing node counts for worker nodes).

    Sample usage:

    my_cluster = vertex_ray.get_ray_cluster(
        cluster_resource_name=my_existing_cluster_resource_name,
    )

    # Declaration to resize all the worker_node_type to node_count=1
    new_worker_node_types = []
    for worker_node_type in my_cluster.worker_node_types:
        worker_node_type.node_count = 1
        new_worker_node_types.append(worker_node_type)

    # Execution to update new node_count (block until complete)
    vertex_ray.update_ray_cluster(
        cluster_resource_name=my_cluster.cluster_resource_name,
        worker_node_types=new_worker_node_types,
    )

    Args:
        cluster_resource_name: Resource name of the cluster to update.
        worker_node_types: The list of Resources of the resized worker nodes.
            The same Resources object should not appear multiple times in the list.
    Returns:
        The cluster_resource_name of the Ray cluster on Vertex.
    Raises:
        ValueError: If the update request fails.
    """
    persistent_resource = _gapic_utils.get_persistent_resource(
        persistent_resource_name=cluster_resource_name
    )

    current_persistent_resource = copy.deepcopy(persistent_resource)
    head_node_type = get_ray_cluster(cluster_resource_name).head_node_type
    # Pool 0 is the head pool: start from the head replica alone and add any
    # merged worker replicas back below.
    current_persistent_resource.resource_pools[0].replica_count = 1
    # TODO(b/300146407): Raise ValueError for duplicate resource pools
    # Offset from a worker's list index to its resource pool index: 1 while no
    # worker has been merged into the head pool, 0 afterwards.
    not_merged = 1
    for i, worker_node_type in enumerate(worker_node_types):
        additional_replica_count = resources._check_machine_spec_identical(
            head_node_type, worker_node_type
        )
        if additional_replica_count != 0:
            # merge the 1st duplicated worker with head
            current_persistent_resource.resource_pools[0].replica_count = (
                1 + additional_replica_count
            )
            # reset not_merged
            not_merged = 0
        else:
            # No duplication w/ head node, write the 2nd worker node to the 2nd resource pool.
            current_persistent_resource.resource_pools[
                i + not_merged
            ].replica_count = worker_node_type.node_count

    request = persistent_resource_service.UpdatePersistentResourceRequest(
        persistent_resource=current_persistent_resource,
        update_mask=field_mask_pb2.FieldMask(paths=["resource_pools.replica_count"]),
    )
    client = _gapic_utils.create_persistent_resource_client()
    try:
        operation_future = client.update_persistent_resource(request)
    except Exception as e:
        raise ValueError(
            f"[Ray on Vertex AI]: Failed in updating the cluster due to: {e}"
        ) from e

    # block before returning
    response = operation_future.result()
    print("[Ray on Vertex AI]: Successfully updated the cluster.")
    return response.name
"""Utility to interact with Ray-on-Vertex dashboard."""

from .util import _gapic_utils
from .util import _validation_utils
from ray.dashboard.modules import dashboard_sdk as oss_dashboard_sdk


def get_job_submission_client_cluster_info(
    address: str, *args, **kwargs
) -> oss_dashboard_sdk.ClusterInfo:
    """A vertex_ray implementation of get_job_submission_client_cluster_info().

    Implements
    https://github.com/ray-project/ray/blob/ray-2.3.1/dashboard/modules/dashboard_sdk.py#L82
    This will be called in from Ray Job API Python client.

    Args:
        address: Address without the module prefix `vertex_ray` but otherwise
            the same format as passed to ray.init(address="vertex_ray://...").
        *args: Remainder of positional args that might be passed down from
            the framework.
        **kwargs: Remainder of keyword args that might be passed down from
            the framework.

    Returns:
        An instance of ClusterInfo that contains address, cookies and
        metadata for SubmissionClient to use.

    Raises:
        RuntimeError: If the head node internal IP is not available.
    """
    address = _validation_utils.maybe_reconstruct_resource_name(address)
    _validation_utils.valid_resource_name(address)

    resource_name = address
    response = _gapic_utils.get_persistent_resource(resource_name)
    head_address = response.resource_runtime.access_uris.get(
        "RAY_HEAD_NODE_INTERNAL_IP", None
    )
    if head_address is None:
        raise RuntimeError(
            "[Ray on Vertex AI]: Unable to obtain a response from the backend."
        )

    # Assume that head node internal IP in a form of xxx.xxx.xxx.xxx:10001.
    # Ray-on-Vertex cluster serves the Dashboard at port 8888 instead of
    # the default 8265.
    head_host = head_address.split(":")[0]
    head_address = f"{head_host}:8888"

    # Pass the address positionally: `f(address=x, *args)` binds the first
    # element of args to `address` as well and fails with "multiple values
    # for argument 'address'" whenever args is non-empty.
    return oss_dashboard_sdk.get_job_submission_client_cluster_info(
        head_address, *args, **kwargs
    )
+# diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/__init__.py b/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/__init__.py new file mode 100644 index 0000000000..856fc73fe7 --- /dev/null +++ b/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/__init__.py @@ -0,0 +1,22 @@ +"""Ray on Vertex AI Prediction Tensorflow.""" + +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .register import register_sklearn + +__all__ = ("register_sklearn",) diff --git a/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/register.py b/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/register.py new file mode 100644 index 0000000000..57f637049a --- /dev/null +++ b/google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/register.py @@ -0,0 +1,125 @@ +"""Regsiter Scikit Learn for Ray on Vertex AI.""" + +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def register_sklearn(
    checkpoint: ray_sklearn.SklearnCheckpoint,
    artifact_uri: Optional[str] = None,
    **kwargs,
) -> aiplatform.Model:
    """Uploads a Ray Sklearn Checkpoint as Sklearn Model to Model Registry.

    Example usage:
        from vertex_ray.predict import sklearn
        from ray.train.sklearn import SklearnCheckpoint

        trainer = SklearnTrainer(estimator=RandomForestClassifier, ...)
        result = trainer.fit()
        sklearn_checkpoint = SklearnCheckpoint.from_checkpoint(result.checkpoint)

        my_model = sklearn.register_sklearn(
            checkpoint=sklearn_checkpoint,
            artifact_uri="gs://{gcs-bucket-name}/path/to/store"
        )

    Args:
        checkpoint: SklearnCheckpoint instance.
        artifact_uri (str):
            The path to the directory where Model Artifacts will be saved. If
            not set, will use staging bucket set in aiplatform.init().
        **kwargs:
            Any kwargs will be passed to aiplatform.Model registration.

    Returns:
        model (aiplatform.Model):
            Instantiated representation of the uploaded model resource.

    Raises:
        ValueError: Invalid Argument.
    """
    artifact_uri = artifact_uri or initializer.global_config.staging_bucket
    predict_utils.validate_artifact_uri(artifact_uri)
    display_model_name = (
        f"ray-on-vertex-registered-sklearn-model-{utils.timestamped_unique_name()}"
    )
    estimator = _get_estimator_from(checkpoint)

    # Each registration gets its own directory named after the unique
    # display name.
    model_dir = os.path.join(artifact_uri, display_model_name)
    file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)

    with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
        pickle.dump(estimator, temp_file)
        # The uploads below re-open the file by name, so buffered bytes must
        # reach disk first; otherwise a truncated pickle can be uploaded.
        temp_file.flush()
        gcs_utils.upload_to_gcs(temp_file.name, file_path)
        return aiplatform.Model.upload_scikit_learn_model_file(
            model_file_path=temp_file.name, display_name=display_model_name, **kwargs
        )


def _get_estimator_from(
    checkpoint: ray_sklearn.SklearnCheckpoint,
) -> "sklearn.base.BaseEstimator":
    """Converts a SklearnCheckpoint to sklearn estimator.

    Args:
        checkpoint: SklearnCheckpoint instance.

    Returns:
        A Sklearn BaseEstimator

    Raises:
        ValueError: If `checkpoint` is not a SklearnCheckpoint.
    """
    if not isinstance(checkpoint, ray_sklearn.SklearnCheckpoint):
        raise ValueError(
            "[Ray on Vertex AI]: arg checkpoint should be a"
            " ray.train.sklearn.SklearnCheckpoint instance"
        )
    # The preprocessor is not exported; warn so the omission is visible.
    if checkpoint.get_preprocessor() is not None:
        logging.warning(
            "Checkpoint contains preprocessor. However, converting from a Ray"
            " Checkpoint to framework specific model does NOT support"
            " preprocessing. The model will be exported without preprocessors."
        )
    return checkpoint.get_estimator()
def register_tensorflow(
    checkpoint: ray_tensorflow.TensorflowCheckpoint,
    artifact_uri: Optional[str] = None,
    _model: Optional[Union["tf.keras.Model", Callable[[], "tf.keras.Model"]]] = None,
    **kwargs,
) -> aiplatform.Model:
    """Uploads a Ray Tensorflow Checkpoint as Tensorflow Model to Model Registry.

    Example usage:
        from vertex_ray.predict import tensorflow

        def create_model():
            model = tf.keras.Sequential(...)
            ...
            return model

        result = trainer.fit()
        my_model = tensorflow.register_tensorflow(
            checkpoint=result.checkpoint,
            _model=create_model,
            artifact_uri="gs://{gcs-bucket-name}/path/to/store",
            use_gpu=True
        )

        1. `use_gpu` will be passed to aiplatform.Model.upload_tensorflow_saved_model()
        2. The `create_model` provides the model_definition which is required if
        you create the TensorflowCheckpoint using `from_model` method.
        More here, https://docs.ray.io/en/latest/train/api/doc/ray.train.tensorflow.TensorflowCheckpoint.get_model.html#ray.train.tensorflow.TensorflowCheckpoint.get_model

    Args:
        checkpoint: TensorflowCheckpoint instance.
        artifact_uri (str):
            The path to the directory where Model Artifacts will be saved. If
            not set, will use staging bucket set in aiplatform.init().
        _model: Tensorflow Model Definition. Refer
            https://docs.ray.io/en/latest/train/api/doc/ray.train.tensorflow.TensorflowCheckpoint.get_model.html#ray.train.tensorflow.TensorflowCheckpoint.get_model
        **kwargs:
            Any kwargs will be passed to aiplatform.Model registration.

    Returns:
        model (aiplatform.Model):
            Instantiated representation of the uploaded model resource.

    Raises:
        ValueError: Invalid Argument.
    """
    artifact_uri = artifact_uri or initializer.global_config.staging_bucket
    predict_utils.validate_artifact_uri(artifact_uri)
    prefix = "ray-on-vertex-registered-tensorflow-model"
    display_model_name = f"{prefix}-{utils.timestamped_unique_name()}"
    tf_model = _get_tensorflow_model_from(checkpoint, model=_model)
    # Save under the unique display name rather than the bare prefix, so a
    # later registration does not overwrite an earlier model's artifacts.
    model_dir = os.path.join(artifact_uri, display_model_name)
    tf_model.save(model_dir)
    return aiplatform.Model.upload_tensorflow_saved_model(
        saved_model_dir=model_dir,
        display_name=display_model_name,
        **kwargs,
    )


def _get_tensorflow_model_from(
    checkpoint: ray_tensorflow.TensorflowCheckpoint,
    model: Optional[Union["tf.keras.Model", Callable[[], "tf.keras.Model"]]] = None,
) -> "tf.keras.Model":
    """Converts a TensorflowCheckpoint to Tensorflow Model.

    Args:
        checkpoint: TensorflowCheckpoint instance.
        model: Tensorflow Model Definition, forwarded to
            `checkpoint.get_model()`.

    Returns:
        A Tensorflow Native Framework Model.

    Raises:
        ValueError: If `checkpoint` is not a TensorflowCheckpoint.
    """
    if not isinstance(checkpoint, ray_tensorflow.TensorflowCheckpoint):
        raise ValueError(
            "[Ray on Vertex AI]: arg checkpoint should be a"
            " ray.train.tensorflow.TensorflowCheckpoint instance"
        )
    # The preprocessor is not exported; warn so the omission is visible.
    if checkpoint.get_preprocessor() is not None:
        logging.warning(
            "Checkpoint contains preprocessor. However, converting from a Ray"
            " Checkpoint to framework specific model does NOT support"
            " preprocessing. The model will be exported without preprocessors."
        )
    return checkpoint.get_model(model)
def get_pytorch_model_from(
    checkpoint: ray_torch.TorchCheckpoint,
    model: Optional[torch.nn.Module] = None,
) -> torch.nn.Module:
    """Extracts the native PyTorch model held by a Ray TorchCheckpoint.

    Example:
        from vertex_ray.predict import torch
        result = TorchTrainer.fit(...)

        pytorch_model = torch.get_pytorch_model_from(
            checkpoint=result.checkpoint
        )

    Args:
        checkpoint: TorchCheckpoint instance.
        model: If the checkpoint contains a model state dict, and not the model
            itself, then the state dict will be loaded to this `model`. Otherwise,
            the model will be discarded.

    Returns:
        A Pytorch Native Framework Model.

    Raises:
        ValueError: If `checkpoint` is not a TorchCheckpoint.
    """
    is_torch_checkpoint = isinstance(checkpoint, ray_torch.TorchCheckpoint)
    if not is_torch_checkpoint:
        wrong_type_message = (
            "[Ray on Vertex AI]: arg checkpoint should be a"
            " ray.train.torch.TorchCheckpoint instance"
        )
        raise ValueError(wrong_type_message)

    # Only the model is exported; a bundled preprocessor is dropped with a
    # warning.
    preprocessor = checkpoint.get_preprocessor()
    if preprocessor is not None:
        dropped_preprocessor_message = (
            "Checkpoint contains preprocessor. However, converting from a Ray"
            " Checkpoint to framework specific model does NOT support"
            " preprocessing. The model will be exported without preprocessors."
        )
        logging.warning(dropped_preprocessor_message)

    native_model = checkpoint.get_model(model=model)
    return native_model
def validate_artifact_uri(artifact_uri: str) -> None:
    """Checks that `artifact_uri` is a string starting with 'gs://'.

    Args:
        artifact_uri: Location where model artifacts will be stored.

    Raises:
        ValueError: If `artifact_uri` is None, is not a string, or does not
            start with 'gs://'.
    """
    # The isinstance check covers None as well as other non-string values,
    # which would otherwise fail with AttributeError on `.startswith`.
    if not isinstance(artifact_uri, str) or not artifact_uri.startswith("gs://"):
        raise ValueError("Argument 'artifact_uri' should start with 'gs://'.")
def register_xgboost(
    checkpoint: "ray_xgboost.XGBoostCheckpoint",
    artifact_uri: Optional[str] = None,
    **kwargs,
) -> aiplatform.Model:
    """Uploads a Ray XGBoost Checkpoint as XGBoost Model to Model Registry.

    Example usage:
        from vertex_ray.predict import xgboost
        from ray.train.xgboost import XGBoostCheckpoint

        trainer = XGBoostTrainer(...)
        result = trainer.fit()
        xgboost_checkpoint = XGBoostCheckpoint.from_checkpoint(result.checkpoint)

        my_model = xgboost.register_xgboost(
            checkpoint=xgboost_checkpoint,
            artifact_uri="gs://{gcs-bucket-name}/path/to/store"
        )

    Args:
        checkpoint: XGBoostCheckpoint instance.
        artifact_uri (str):
            The path to the directory where Model Artifacts will be saved. If
            not set, will use staging bucket set in aiplatform.init().
        **kwargs:
            Any kwargs will be passed to aiplatform.Model registration.

    Returns:
        model (aiplatform.Model):
            Instantiated representation of the uploaded model resource.

    Raises:
        ValueError: Invalid Argument.
    """
    artifact_uri = artifact_uri or initializer.global_config.staging_bucket
    predict_utils.validate_artifact_uri(artifact_uri)
    display_model_name = (
        f"ray-on-vertex-registered-xgboost-model-{utils.timestamped_unique_name()}"
    )
    model = _get_xgboost_model_from(checkpoint)

    # Each registration gets its own directory named after the unique
    # display name.
    model_dir = os.path.join(artifact_uri, display_model_name)
    file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)

    with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
        pickle.dump(model, temp_file)
        # The uploads below re-open the file by name, so buffered bytes must
        # reach disk first; otherwise a truncated pickle can be uploaded.
        temp_file.flush()
        gcs_utils.upload_to_gcs(temp_file.name, file_path)
        return aiplatform.Model.upload_xgboost_model_file(
            model_file_path=temp_file.name,
            display_name=display_model_name,
            xgboost_version=constants._XGBOOST_VERSION,
            **kwargs,
        )


def _get_xgboost_model_from(
    checkpoint: "ray_xgboost.XGBoostCheckpoint",
) -> "xgboost.Booster":
    """Converts a XGBoostCheckpoint to XGBoost model.

    Args:
        checkpoint: XGBoostCheckpoint instance.

    Returns:
        A XGBoost core Booster

    Raises:
        ValueError: If `checkpoint` is not a XGBoostCheckpoint.
    """
    if not isinstance(checkpoint, ray_xgboost.XGBoostCheckpoint):
        raise ValueError(
            "[Ray on Vertex AI]: arg checkpoint should be a"
            " ray.train.xgboost.XGBoostCheckpoint instance"
        )
    # The preprocessor is not exported; warn so the omission is visible.
    if checkpoint.get_preprocessor() is not None:
        logging.warning(
            "Checkpoint contains preprocessor. However, converting from a Ray"
            " Checkpoint to framework specific model does NOT support"
            " preprocessing. The model will be exported without preprocessors."
        )
    return checkpoint.get_model()


class VertexRayTemplate(Template):
    """Class which provides basic HTML templating.

    Templates are read from the `templates` directory that sits next to this
    module.
    """

    def __init__(self, file: str):
        """Loads the template text for `file` into `self.template`.

        Args:
            file: File name of the template inside the `templates` directory.
        """
        # NOTE(review): the base-class initializer is intentionally not
        # called here, as in the original; presumably it would resolve `file`
        # against Ray's own template directory - confirm.
        template_path = pathlib.Path(__file__).parent / "templates" / file
        # Explicit encoding so the result does not depend on the locale.
        with open(template_path, "r", encoding="utf-8") as f:
            self.template = f.read()
+
+

Ray

+ + + + + + + + + + + + + + + + + + + + + + + {{ shell_uri_row }} + + + + +
Python version:{{ python_version }}
Ray version: {{ ray_version }}
Vertex SDK version: {{ vertex_sdk_version }}
Dashboard:{{ dashboard_url }}
Cluster Name: {{ persistent_resource_id }}
+
+
diff --git a/google/cloud/aiplatform/preview/vertex_ray/templates/context_shellurirow.html.j2 b/google/cloud/aiplatform/preview/vertex_ray/templates/context_shellurirow.html.j2 new file mode 100644 index 0000000000..748ffe9a09 --- /dev/null +++ b/google/cloud/aiplatform/preview/vertex_ray/templates/context_shellurirow.html.j2 @@ -0,0 +1,4 @@ + + Interactive Terminal Uri: + {{ shell_uri }} + \ No newline at end of file diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py b/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py new file mode 100644 index 0000000000..251f915b21 --- /dev/null +++ b/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def create_persistent_resource_client():
    """Creates a PersistentResource service client from the global config."""
    # location is inherited from the global configuration at aiplatform.init().
    return initializer.global_config.create_client(
        client_class=PersistentResourceClientWithOverride,
        appended_gapic_version="vertex_ray",
    )


def polling_delay(num_attempts: int, time_scale: float) -> datetime.timedelta:
    """Computes a delay to the next attempt to poll the Vertex service.

    The delay starts at max(time_scale, 30) seconds and is multiplied by
    0.765 for every attempt already made, up to six attempts. Polling
    therefore starts slow and speeds up; from the sixth attempt on the
    delay stays constant.

    Args:
        num_attempts: The number of times we have polled and found that the
            desired result was not yet available.
        time_scale: The initial (longest) polling interval, in seconds.
            Values below 30 are raised to 30.

    Returns:
        The recommended delay as a datetime.timedelta.
    """
    minimum_scale = 30.0  # Seconds
    interval = max(time_scale, minimum_scale) * 0.765 ** min(num_attempts, 6)
    return datetime.timedelta(seconds=interval)


def get_persistent_resource(
    persistent_resource_name: str, tolerance: Optional[int] = 0
):
    """Get persistent resource, polling until it reaches a decisive state.

    Args:
        persistent_resource_name:
            "projects/<project_num>/locations/<region>/persistentResources/<pr_id>".
        tolerance: number of "not found" responses to tolerate before giving
            up. None is treated as 0.

    Returns:
        aiplatform_v1beta1.PersistentResource if state is RUNNING.

    Raises:
        ValueError: Invalid cluster resource name.
        RuntimeError: Service returns error.
        RuntimeError: Cluster resource state is STOPPING.
        RuntimeError: Cluster resource state is ERROR.
    """
    # `tolerance` is declared Optional; comparing an int against None would
    # raise TypeError, so None is normalized to "no tolerance".
    not_found_tolerance = tolerance or 0

    client = create_persistent_resource_client()
    request = GetPersistentResourceRequest(name=persistent_resource_name)

    # TODO(b/277117901): Add test cases for polling and error handling
    num_attempts = 0
    while True:
        try:
            response = client.get_persistent_resource(request)
        except exceptions.NotFound as not_found:
            response = None
            if num_attempts >= not_found_tolerance:
                raise ValueError(
                    "[Ray on Vertex AI]: Invalid cluster_resource_name (404 not found)."
                ) from not_found
        if response:
            if response.error.message:
                logging.error("[Ray on Vertex AI]: %s", response.error.message)
                raise RuntimeError("[Ray on Vertex AI]: Cluster returned an error.")

            print("[Ray on Vertex AI]: Cluster State =", response.state)
            if response.state == PersistentResource.State.RUNNING:
                return response
            elif response.state == PersistentResource.State.STOPPING:
                raise RuntimeError("[Ray on Vertex AI]: The cluster is stopping.")
            elif response.state == PersistentResource.State.ERROR:
                raise RuntimeError(
                    "[Ray on Vertex AI]: The cluster encountered an error."
                )
        # Polling decay: any other state (or a tolerated 404) waits and retries.
        sleep_time = polling_delay(num_attempts=num_attempts, time_scale=150.0)
        num_attempts += 1
        print(
            "Waiting for cluster provisioning; attempt {}; sleeping for {} seconds".format(
                num_attempts, sleep_time
            )
        )
        time.sleep(sleep_time.total_seconds())


def _resources_from_pool(resource_pool, node_count: int) -> "Resources":
    """Builds a vertex_ray Resources description from one resource pool."""
    accelerator_type = resource_pool.machine_spec.accelerator_type
    # An accelerator enum whose value is 0 is reported as "no accelerator".
    accelerator_name = accelerator_type.name if accelerator_type.value != 0 else None
    return Resources(
        machine_type=resource_pool.machine_spec.machine_type,
        accelerator_type=accelerator_name,
        accelerator_count=resource_pool.machine_spec.accelerator_count,
        boot_disk_type=resource_pool.disk_spec.boot_disk_type,
        boot_disk_size_gb=resource_pool.disk_spec.boot_disk_size_gb,
        node_count=node_count,
    )


def persistent_resource_to_cluster(
    persistent_resource: "PersistentResource",
) -> Optional["Cluster"]:
    """Format a PersistentResource to a Cluster.

    Args:
        persistent_resource: PersistentResource.

    Returns:
        Cluster, or None when the persistent resource has no RaySpec.
    """
    ray_spec = persistent_resource.resource_runtime_spec.ray_spec
    if not ray_spec:
        # skip PersistentResource without RaySpec
        logging.info(
            "[Ray on Vertex AI]: Cluster %s does not have Ray installed.",
            persistent_resource.name,
        )
        return None

    cluster = Cluster(
        cluster_resource_name=persistent_resource.name,
        network=persistent_resource.network,
        state=persistent_resource.state.name,
    )

    # Prefer the head node's own image and fall back to the cluster-wide
    # image when the map has no usable "head-node" entry. `.get()` plus a
    # falsy check covers both a missing key and an empty string; indexing
    # the map followed by an `is None` test could never take the fallback.
    image_uri = ray_spec.resource_pool_images.get("head-node") or ray_spec.image_uri
    python_version, ray_version = _validation_utils.get_versions_from_image_uri(
        image_uri
    )
    cluster.python_version = python_version
    cluster.ray_version = ray_version

    resource_pools = persistent_resource.resource_pools
    head_resource_pool = resource_pools[0]
    cluster.head_node_type = _resources_from_pool(head_resource_pool, node_count=1)

    worker_node_types = []
    if head_resource_pool.replica_count > 1:
        # head_node_type.node_count must be 1. If the head_resource_pool (the
        # first resource pool) has replica_count > 1, the remaining replicas
        # are worker nodes.
        worker_node_types.append(
            _resources_from_pool(
                head_resource_pool,
                node_count=head_resource_pool.replica_count - 1,
            )
        )
    # The second and later resource pools are all worker pools.
    for pool_index in range(1, len(resource_pools)):
        worker_pool = resource_pools[pool_index]
        worker_node_types.append(
            _resources_from_pool(worker_pool, node_count=worker_pool.replica_count)
        )
    cluster.worker_node_types = worker_node_types

    return cluster
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import re
from typing import Optional, Tuple

# NOTE: ``google.cloud.aiplatform.initializer`` and
# ``google.cloud.aiplatform.utils.resource_manager_utils`` are imported inside
# the two functions that read the SDK's global configuration, so the pure
# string helpers in this module can be imported without them.

# Artifact Repository available regions.
_AVAILABLE_REGIONS = ["us", "europe", "asia"]
# If region is not available, assume using the default region.
_DEFAULT_REGION = "us"

# Versions accepted by get_image_uri, in the underscore form used by callers.
_SUPPORTED_RAY_VERSIONS = ("2_4",)
_SUPPORTED_PYTHON_VERSIONS = ("3_10",)

_PERSISTENT_RESOURCE_NAME_PATTERN = "projects/{}/locations/{}/persistentResources/{}"
_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
_VALID_RESOURCE_NAME_RE = re.compile(_VALID_RESOURCE_NAME_REGEX)

# Version suffix of an image label such as "ray-cpu.2-4.py310":
#   group 1 -> dash-separated Ray version ("2-4")
#   group 2 -> Python major version, assumed to be a single digit ("3")
#   group 3 -> Python minor version ("10")
_IMAGE_LABEL_VERSIONS_RE = re.compile(r"\.(\d+(?:-\d+)*)\.py(\d)(\d+)$")


def valid_resource_name(resource_name: str) -> None:
    """Validate that ``resource_name`` is a full persistent resource name.

    Despite its name, this function returns nothing: it either passes
    silently or raises.

    Args:
        resource_name: Candidate name, expected to look like
            "projects/{project_number}/locations/{location}/persistentResources/{cluster_name}".

    Raises:
        ValueError: If the name does not have exactly six "/"-separated
            segments with the expected collection names, or if any of the
            three identifier segments is empty.
    """
    segments = resource_name.split("/")
    is_valid = (
        len(segments) == 6
        and segments[0] == "projects"
        and segments[2] == "locations"
        and segments[4] == "persistentResources"
        # An empty project, location or cluster id can never name a resource.
        and all(segments[index] for index in (1, 3, 5))
    )
    if not is_valid:
        raise ValueError(
            "[Ray on Vertex AI]: Address must be in the following "
            "format: vertex_ray://projects/{project_number}/locations/{location}"
            "/persistentResources/{cluster_name} "
            "or vertex_ray://{cluster_name}."
        )


def maybe_reconstruct_resource_name(address) -> str:
    """Expand a bare cluster name into a full persistent resource name.

    Args:
        address: Either a full resource name or only the cluster name
            (persistent resource id).

    Returns:
        ``address`` unchanged unless it is a bare cluster name; in that case
        the full resource name built from the project number and location of
        ``initializer.global_config``.
    """
    # fullmatch (rather than match with "^...$") so that a trailing newline is
    # not silently accepted as part of a cluster name.
    if not _VALID_RESOURCE_NAME_RE.fullmatch(address):
        return address

    from google.cloud.aiplatform import initializer
    from google.cloud.aiplatform.utils import resource_manager_utils

    # Assume only cluster name (persistent resource id) was given.
    logging.info(
        "[Ray on Vertex AI]: Cluster name was given as address, reconstructing full resource name"
    )
    return _PERSISTENT_RESOURCE_NAME_PATTERN.format(
        resource_manager_utils.get_project_number(initializer.global_config.project),
        initializer.global_config.location,
        address,
    )


def get_image_uri(ray_version, python_version, enable_cuda):
    """Return the image uri for a given Ray version and Python version.

    Args:
        ray_version: Ray version in underscore form, e.g. "2_4".
        python_version: Python version in underscore form, e.g. "3_10".
        enable_cuda: If truthy, return the GPU image; otherwise the CPU image.

    Returns:
        The image uri, hosted in the multi-region derived from
        ``initializer.global_config.location`` (falling back to "us").

    Raises:
        ValueError: If the Ray or Python version is not supported.
    """
    # Validate before touching the global config so bad arguments fail fast.
    if ray_version not in _SUPPORTED_RAY_VERSIONS:
        raise ValueError(
            "[Ray on Vertex AI]: The supported Ray version is {}.".format(
                ", ".join(_SUPPORTED_RAY_VERSIONS)
            )
        )
    if python_version not in _SUPPORTED_PYTHON_VERSIONS:
        raise ValueError(
            "[Ray on Vertex AI]: The supported Python version is {}.".format(
                ", ".join(_SUPPORTED_PYTHON_VERSIONS)
            )
        )

    from google.cloud.aiplatform import initializer

    location = initializer.global_config.location
    region = location.split("-")[0]
    if region not in _AVAILABLE_REGIONS:
        region = _DEFAULT_REGION

    # TODO(b/292003337) update eligible image uris
    device = "gpu" if enable_cuda else "cpu"
    # Derive the label from the validated arguments ("2_4" -> "2-4",
    # "3_10" -> "py310") instead of repeating the versions as literals.
    ray_label = ray_version.replace("_", "-")
    python_label = "py" + python_version.replace("_", "")
    return (
        f"{region}-docker.pkg.dev/vertex-ai/training/"
        f"ray-{device}.{ray_label}.{python_label}:latest"
    )


def get_versions_from_image_uri(image_uri) -> Tuple[Optional[str], Optional[str]]:
    """Get the Python version and Ray version encoded in an image uri.

    The image label (last path segment without tag or digest) is expected to
    end in ".{ray}.py{python}", for example "ray-cpu.2-4.py310".

    Args:
        image_uri: Image uri, optionally with a ":tag" or "@digest" suffix.

    Returns:
        A ``(python_version, ray_version)`` tuple in underscore form, e.g.
        ``("3_10", "2_4")``. ``(None, None)`` is returned, with a warning
        logged, when the uri is empty or its label does not follow the
        expected layout (for example a custom image).
    """
    logging.info("[Ray on Vertex AI]: Getting versions from image uri: %s", image_uri)
    if image_uri:
        # Strip the registry path, then a digest ("@sha256:...") or a tag.
        image_label = image_uri.split("/")[-1].split("@")[0].split(":")[0]
        parsed = _IMAGE_LABEL_VERSIONS_RE.search(image_label)
    else:
        parsed = None

    if parsed is None:
        logging.warning(
            "[Ray on Vertex AI]: Could not parse Ray and Python versions from image uri: %s",
            image_uri,
        )
        return None, None

    ray_label, py_major, py_minor = parsed.groups()
    return f"{py_major}_{py_minor}", ray_label.replace("-", "_")
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
    # Only referenced by the ``Cluster.state`` annotation, so it is not
    # needed at runtime.
    from google.cloud.aiplatform_v1beta1.types import PersistentResource


class Resources:
    """Resources for a ray cluster node.

    Attributes:
        machine_type: See the list of machine types:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
        node_count: This argument represents how many nodes to start for the
            ray cluster.
        accelerator_type: e.g. "NVIDIA_TESLA_P4".
            Vertex AI supports the following types of GPU:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
        accelerator_count: The number of accelerators to attach to the machine.
            ``None`` is treated like 0 when validating the arguments.
        boot_disk_type: Type of the boot disk (default is "pd-ssd").
            Valid values: "pd-ssd" (Persistent Disk Solid State Drive) or
            "pd-standard" (Persistent Disk Hard Disk Drive).
        boot_disk_size_gb: Size in GB of the boot disk (default is 100GB). Must
            be either unspecified or within the range of [100, 64000].

    Raises:
        ValueError: If ``accelerator_count`` is positive while
            ``accelerator_type`` is None.
    """

    def __init__(
        self,
        machine_type: Optional[str] = "n1-standard-4",
        node_count: Optional[int] = 1,
        accelerator_type: Optional[str] = None,
        accelerator_count: Optional[int] = 0,
        boot_disk_type: Optional[str] = "pd-ssd",
        boot_disk_size_gb: Optional[int] = 100,
    ):
        # ``accelerator_count`` is declared Optional, so compare a None-safe
        # value: ``None > 0`` would raise TypeError instead of validating.
        if accelerator_type is None and (accelerator_count or 0) > 0:
            raise ValueError(
                "[Ray on Vertex]: accelerator_type must be specified when"
                + " accelerator_count is set to a value other than 0."
            )

        self.machine_type = machine_type
        self.node_count = node_count
        self.accelerator_type = accelerator_type
        self.accelerator_count = accelerator_count
        self.boot_disk_type = boot_disk_type
        self.boot_disk_size_gb = boot_disk_size_gb

    def __repr__(self) -> str:
        """Return a constructor-like representation for logs and debugging."""
        return (
            f"{type(self).__name__}("
            f"machine_type={self.machine_type!r}, "
            f"node_count={self.node_count!r}, "
            f"accelerator_type={self.accelerator_type!r}, "
            f"accelerator_count={self.accelerator_count!r}, "
            f"boot_disk_type={self.boot_disk_type!r}, "
            f"boot_disk_size_gb={self.boot_disk_size_gb!r})"
        )


@dataclasses.dataclass
class Cluster:
    """Ray cluster (output only).

    Every field defaults to None, meaning "not populated".

    Attributes:
        cluster_resource_name: It has a format:
            "projects/{project_number}/locations/{location}/persistentResources/{cluster_name}".
        network: Virtual private cloud (VPC) network. It has a format:
            "projects/{project_number}/global/networks/{network_name}".
            For Ray Client, VPC peering is required to connect to the cluster
            managed in the Vertex API service. For Ray Job API, VPC network is
            not required because cluster connection can be accessed through
            dashboard address.
        state: Describes the cluster state (defined in PersistentResource.State).
            NOTE(review): the cluster converter in this change passes
            ``persistent_resource.state.name`` (a str) here - confirm whether
            the annotation should be ``str``.
        python_version: Python version for the ray cluster (e.g. "3_10").
        ray_version: Ray version for the ray cluster (e.g. "2_4").
        head_node_type: The head node resource. Resources.node_count must be 1.
            If not set, by default it is a CPU node with machine_type of n1-standard-4.
        worker_node_types: The list of Resources of the worker nodes. Should not
            duplicate the elements in the list.
    """

    cluster_resource_name: Optional[str] = None
    network: Optional[str] = None
    state: Optional["PersistentResource.State"] = None
    python_version: Optional[str] = None
    ray_version: Optional[str] = None
    head_node_type: Optional[Resources] = None
    worker_node_types: Optional[List[Resources]] = None


def _check_machine_spec_identical(
    node_type_1: Resources,
    node_type_2: Resources,
) -> int:
    """Check if node_type_1 and node_type_2 have the same machine_spec.

    Only machine_type, accelerator_type and accelerator_count are compared;
    disk settings and node counts do not take part in the comparison.

    Args:
        node_type_1: The reference node type.
        node_type_2: The node type compared against ``node_type_1``.

    Returns:
        ``node_type_2.node_count`` (the additional replica count) when the
        machine specs are identical, otherwise 0.
    """
    same_machine_spec = (
        node_type_1.machine_type == node_type_2.machine_type
        and node_type_1.accelerator_type == node_type_2.accelerator_type
        and node_type_1.accelerator_count == node_type_2.accelerator_count
    )
    return node_type_2.node_count if same_machine_spec else 0
+ "ray[default] >= 2.5, < 2.5.1; python_version>='3.11'", + "google-cloud-bigquery-storage", + "google-cloud-bigquery", + "pandas >= 1.0.0", + "pyarrow >= 6.0.1", + # Workaround for https://github.com/ray-project/ray/issues/36990. + # TODO(b/295406381): Remove this pin when we drop support of ray<=2.5. + "pydantic < 2", +] + full_extra_require = list( set( tensorboard_extra_require @@ -106,6 +127,7 @@ + private_endpoints_extra_require + autologging_extra_require + preview_extra_require + + ray_extra_require ) ) testing_extra_require = ( @@ -123,6 +145,7 @@ "torch >= 2.0.0; python_version>='3.8'", "torch; python_version<'3.8'", "xgboost", + "xgboost_ray", ] ) @@ -133,6 +156,8 @@ description=description, long_description=readme, packages=packages, + package_dir={"vertex_ray": "google/cloud/aiplatform/preview/vertex_ray"}, + package_data={"": ["*.html.j2"]}, entry_points={ "console_scripts": [ "tb-gcp-uploader=google.cloud.aiplatform.tensorboard.uploader_main:run_main" @@ -171,6 +196,7 @@ "private_endpoints": private_endpoints_extra_require, "autologging": autologging_extra_require, "preview": preview_extra_require, + "ray": ray_extra_require, }, python_requires=">=3.7", classifiers=[ diff --git a/testing/constraints-3.10.txt b/testing/constraints-3.10.txt index 6c3e6c5bbc..e51a4c5b2b 100644 --- a/testing/constraints-3.10.txt +++ b/testing/constraints-3.10.txt @@ -3,7 +3,7 @@ # List all library dependencies and extras in this file. 
google-api-core==1.32.0 proto-plus==1.22.0 -protobuf==3.19.5 +protobuf==3.19.6 mock==4.0.2 google-cloud-storage==2.0.0 packaging==20.0 # Increased for compatibility with MLFlow diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index 9f1b48e7ae..4e4013f5bc 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -7,7 +7,7 @@ # Then this file should have foo==1.14.0 google-api-core==1.32.0 proto-plus==1.22.0 -protobuf==3.19.5 +protobuf==3.19.6 mock==4.0.2 google-cloud-storage==2.0.0 packaging==20.0 # Increased for compatibility with MLFlow diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index 6c3e6c5bbc..e51a4c5b2b 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -3,7 +3,7 @@ # List all library dependencies and extras in this file. google-api-core==1.32.0 proto-plus==1.22.0 -protobuf==3.19.5 +protobuf==3.19.6 mock==4.0.2 google-cloud-storage==2.0.0 packaging==20.0 # Increased for compatibility with MLFlow diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index 6c3e6c5bbc..e51a4c5b2b 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -3,7 +3,7 @@ # List all library dependencies and extras in this file. google-api-core==1.32.0 proto-plus==1.22.0 -protobuf==3.19.5 +protobuf==3.19.6 mock==4.0.2 google-cloud-storage==2.0.0 packaging==20.0 # Increased for compatibility with MLFlow diff --git a/tests/unit/vertex_ray/conftest.py b/tests/unit/vertex_ray/conftest.py new file mode 100644 index 0000000000..26666deea6 --- /dev/null +++ b/tests/unit/vertex_ray/conftest.py @@ -0,0 +1,131 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google import auth +from google.api_core import exceptions +from google.api_core import operation as ga_operation +from google.auth import credentials as auth_credentials +from google.cloud import resourcemanager +from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( + PersistentResourceServiceClient, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + PersistentResource, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + ResourceRuntime, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import ( + DeletePersistentResourceRequest, +) +import test_constants as tc +import mock +import pytest + + +# -*- coding: utf-8 -*- + +# STOPPING +_TEST_RESPONSE_STOPPING = PersistentResource() +_TEST_RESPONSE_STOPPING.name = tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS +resource_runtime = ResourceRuntime() +_TEST_RESPONSE_STOPPING.resource_runtime = resource_runtime +_TEST_RESPONSE_STOPPING.state = "STOPPING" + +# ERROR +_TEST_RESPONSE_ERROR = PersistentResource() +_TEST_RESPONSE_ERROR.name = tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS +resource_runtime = ResourceRuntime() +_TEST_RESPONSE_ERROR.resource_runtime = resource_runtime +_TEST_RESPONSE_ERROR.state = "ERROR" + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as auth_mock: + auth_mock.return_value = ( + auth_credentials.AnonymousCredentials(), + 
tc.ProjectConstants._TEST_GCP_PROJECT_ID, + ) + yield auth_mock + + +@pytest.fixture +def get_project_number_mock(): + with mock.patch.object( + resourcemanager.ProjectsClient, "get_project" + ) as get_project_number_mock: + test_project = resourcemanager.Project( + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID + ) + test_project.name = f"projects/{tc.ProjectConstants._TEST_GCP_PROJECT_NUMBER}" + get_project_number_mock.return_value = test_project + yield get_project_number_mock + + +@pytest.fixture +def api_client_mock(): + yield mock.create_autospec( + PersistentResourceServiceClient, spec_set=True, instance=True + ) + + +@pytest.fixture +def persistent_client_mock(api_client_mock): + with mock.patch.object( + vertex_ray.util._gapic_utils, + "create_persistent_resource_client", + ) as persistent_client_mock: + + # get_persistent_resource + api_client_mock.get_persistent_resource.return_value = ( + tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL + ) + # delete_persistent_resource + delete_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation) + delete_persistent_resource_lro_mock.result.return_value = ( + DeletePersistentResourceRequest() + ) + api_client_mock.delete_persistent_resource.return_value = ( + delete_persistent_resource_lro_mock + ) + + persistent_client_mock.return_value = api_client_mock + yield persistent_client_mock + + +@pytest.fixture +def persistent_client_stopping_mock(api_client_mock): + with mock.patch.object( + vertex_ray.util._gapic_utils, "create_persistent_resource_client" + ) as persistent_client_stopping_mock: + api_client_mock.get_persistent_resource.return_value = _TEST_RESPONSE_STOPPING + persistent_client_stopping_mock.return_value = api_client_mock + yield persistent_client_stopping_mock + + +@pytest.fixture +def persistent_client_error_mock(api_client_mock): + with mock.patch.object( + vertex_ray.util._gapic_utils, "create_persistent_resource_client" + ) as persistent_client_error_mock: + # 
get_persistent_resource + api_client_mock.get_persistent_resource.return_value = _TEST_RESPONSE_ERROR + # delete_persistent_resource + api_client_mock.delete_persistent_resource.side_effect = exceptions.NotFound + + persistent_client_error_mock.return_value = api_client_mock + yield persistent_client_error_mock diff --git a/tests/unit/vertex_ray/test_bigquery.py b/tests/unit/vertex_ray/test_bigquery.py new file mode 100644 index 0000000000..6120920483 --- /dev/null +++ b/tests/unit/vertex_ray/test_bigquery.py @@ -0,0 +1,294 @@ +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib + +from google.api_core import exceptions +from google.api_core import operation +from google.cloud import bigquery +from google.cloud import bigquery_storage +from google.cloud import aiplatform +from google.cloud.aiplatform.preview.vertex_ray import bigquery_datasource +import test_constants as tc +from google.cloud.bigquery import job +from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream +import mock +import pytest +import ray + + +_TEST_BQ_DATASET_ID = "mockdataset" +_TEST_BQ_TABLE_ID = "mocktable" +_TEST_BQ_DATASET = _TEST_BQ_DATASET_ID + "." 
+ _TEST_BQ_TABLE_ID +_TEST_BQ_TEMP_DESTINATION = ( + tc.ProjectConstants._TEST_GCP_PROJECT_ID + ".tempdataset.temptable" +) +_TEST_DISPLAY_NAME = "display_name" + + +@pytest.fixture(autouse=True) +def bq_client_full_mock(monkeypatch): + client_mock = mock.create_autospec(bigquery.Client) + client_mock.return_value = client_mock + + def bq_get_dataset_mock(dataset_id): + if dataset_id != _TEST_BQ_DATASET_ID: + raise exceptions.NotFound( + "[Ray on Vertex AI]: Dataset {} is not found. Please ensure that it exists.".format( + _TEST_BQ_DATASET + ) + ) + + def bq_get_table_mock(table_id): + if table_id != _TEST_BQ_DATASET: + raise exceptions.NotFound( + "[Ray on Vertex AI]: Table {} is not found. Please ensure that it exists.".format( + _TEST_BQ_DATASET + ) + ) + + def bq_create_dataset_mock(dataset_id, **kwargs): + if dataset_id == "existingdataset": + raise exceptions.Conflict("Dataset already exists") + return mock.Mock(operation.Operation) + + def bq_delete_table_mock(table, **kwargs): + return None + + def bq_query_mock(query): + fake_job_ref = job._JobReference( + "fake_job_id", + tc.ProjectConstants._TEST_GCP_PROJECT_ID, + "us-central1", + ) + fake_query_job = job.QueryJob(fake_job_ref, query, None) + fake_query_job.configuration.destination = _TEST_BQ_TEMP_DESTINATION + return fake_query_job + + client_mock.get_dataset = bq_get_dataset_mock + client_mock.get_table = bq_get_table_mock + client_mock.create_dataset = bq_create_dataset_mock + client_mock.delete_table = bq_delete_table_mock + client_mock.query = bq_query_mock + + monkeypatch.setattr(bigquery, "Client", client_mock) + client_mock.reset_mock() + return client_mock + + +@pytest.fixture(autouse=True) +def bqs_client_full_mock(monkeypatch): + client_mock = mock.create_autospec(bigquery_storage.BigQueryReadClient) + client_mock.return_value = client_mock + + def bqs_create_read_session(max_stream_count=0, **kwargs): + read_session_proto = gcbqs_stream.ReadSession() + read_session_proto.streams = [ + 
gcbqs_stream.ReadStream() for _ in range(max_stream_count) + ] + return read_session_proto + + client_mock.create_read_session = bqs_create_read_session + + monkeypatch.setattr(bigquery_storage, "BigQueryReadClient", client_mock) + client_mock.reset_mock() + return client_mock + + +@pytest.fixture +def bq_query_result_mock(): + with mock.patch.object(bigquery.job.QueryJob, "result") as query_result_mock: + yield query_result_mock + + +@pytest.fixture +def bq_query_result_mock_fail(): + with mock.patch.object(bigquery.job.QueryJob, "result") as query_result_mock_fail: + query_result_mock_fail.side_effect = exceptions.BadRequest("400 Syntax error") + yield query_result_mock_fail + + +@pytest.fixture +def ray_remote_function_mock(): + with mock.patch.object(ray.remote_function.RemoteFunction, "_remote") as remote_fn: + remote_fn.return_value = 1 + yield remote_fn + + +class TestReadBigQuery: + """Tests for BigQuery Read.""" + + def setup_method(self): + importlib.reload(aiplatform.initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize( + "parallelism", + [1, 2, 3, 4, 10, 100], + ) + def test_create_reader(self, parallelism): + bq_ds = bigquery_datasource.BigQueryDatasource() + reader = bq_ds.create_reader( + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + dataset=_TEST_BQ_DATASET, + parallelism=parallelism, + ) + read_tasks_list = reader.get_read_tasks(parallelism) + assert len(read_tasks_list) == parallelism + + @pytest.mark.parametrize( + "parallelism", + [1, 2, 3, 4, 10, 100], + ) + def test_create_reader_initialized(self, parallelism): + """If initialized, create_reader doesn't need to specify project_id.""" + aiplatform.init( + project=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + bq_ds = bigquery_datasource.BigQueryDatasource() + reader = bq_ds.create_reader( + dataset=_TEST_BQ_DATASET, + 
parallelism=parallelism, + ) + read_tasks_list = reader.get_read_tasks(parallelism) + assert len(read_tasks_list) == parallelism + + @pytest.mark.parametrize( + "parallelism", + [1, 2, 3, 4, 10, 100], + ) + def test_create_reader_query(self, parallelism, bq_query_result_mock): + bq_ds = bigquery_datasource.BigQueryDatasource() + reader = bq_ds.create_reader( + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + parallelism=parallelism, + query="SELECT * FROM mockdataset.mocktable", + ) + read_tasks_list = reader.get_read_tasks(parallelism) + bq_query_result_mock.assert_called_once() + assert len(read_tasks_list) == parallelism + + @pytest.mark.parametrize( + "parallelism", + [1, 2, 3, 4, 10, 100], + ) + def test_create_reader_query_bad_request( + self, + parallelism, + bq_query_result_mock_fail, + ): + bq_ds = bigquery_datasource.BigQueryDatasource() + reader = bq_ds.create_reader( + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + parallelism=parallelism, + query="SELECT * FROM mockdataset.mocktable", + ) + with pytest.raises(exceptions.BadRequest): + reader.get_read_tasks(parallelism) + bq_query_result_mock_fail.assert_called() + + def test_dataset_query_kwargs_provided(self): + parallelism = 4 + bq_ds = bigquery_datasource.BigQueryDatasource() + with pytest.raises(ValueError) as exception: + bq_ds.create_reader( + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + dataset=_TEST_BQ_DATASET, + query="SELECT * FROM mockdataset.mocktable", + parallelism=parallelism, + ) + expected_message = "[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)." 
+ assert str(exception.value) == expected_message + + def test_create_reader_dataset_not_found(self): + parallelism = 4 + bq_ds = bigquery_datasource.BigQueryDatasource() + reader = bq_ds.create_reader( + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + dataset="nonexistentdataset.mocktable", + parallelism=parallelism, + ) + with pytest.raises(ValueError) as exception: + reader.get_read_tasks(parallelism) + expected_message = "[Ray on Vertex AI]: Dataset nonexistentdataset is not found. Please ensure that it exists." + assert str(exception.value) == expected_message + + def test_create_reader_table_not_found(self): + parallelism = 4 + bq_ds = bigquery_datasource.BigQueryDatasource() + reader = bq_ds.create_reader( + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + dataset="mockdataset.nonexistenttable", + parallelism=parallelism, + ) + with pytest.raises(ValueError) as exception: + reader.get_read_tasks(parallelism) + expected_message = "[Ray on Vertex AI]: Table mockdataset.nonexistenttable is not found. Please ensure that it exists." 
+ assert str(exception.value) == expected_message + + +@pytest.mark.usefixtures("google_auth_mock") +class TestWriteBigQuery: + """Tests for BigQuery Write.""" + + def setup_method(self): + importlib.reload(aiplatform.initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + def test_do_write(self, ray_remote_function_mock): + bq_ds = bigquery_datasource.BigQueryDatasource() + write_tasks_list = bq_ds.do_write( + blocks=[1, 2, 3, 4], + metadata=[1, 2, 3, 4], + ray_remote_args={}, + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + dataset=_TEST_BQ_DATASET, + ) + assert len(write_tasks_list) == 4 + + def test_do_write_initialized(self, ray_remote_function_mock): + """If initialized, do_write doesn't need to specify project_id.""" + aiplatform.init( + project=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + bq_ds = bigquery_datasource.BigQueryDatasource() + write_tasks_list = bq_ds.do_write( + blocks=[1, 2, 3, 4], + metadata=[1, 2, 3, 4], + ray_remote_args={}, + dataset=_TEST_BQ_DATASET, + ) + assert len(write_tasks_list) == 4 + + def test_do_write_dataset_exists(self, ray_remote_function_mock): + bq_ds = bigquery_datasource.BigQueryDatasource() + write_tasks_list = bq_ds.do_write( + blocks=[1, 2, 3, 4], + metadata=[1, 2, 3, 4], + ray_remote_args={}, + project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + dataset="existingdataset" + "." + _TEST_BQ_TABLE_ID, + ) + assert len(write_tasks_list) == 4 diff --git a/tests/unit/vertex_ray/test_cluster_init.py b/tests/unit/vertex_ray/test_cluster_init.py new file mode 100644 index 0000000000..1556c6e86d --- /dev/null +++ b/tests/unit/vertex_ray/test_cluster_init.py @@ -0,0 +1,470 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import copy +import importlib + +from google.api_core import operation as ga_operation +from google.cloud import aiplatform +from google.cloud.aiplatform.preview import vertex_ray +from google.cloud.aiplatform.preview.vertex_ray.util.resources import ( + Resources, +) +from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( + PersistentResourceServiceClient, +) +from google.cloud.aiplatform_v1beta1.types import persistent_resource_service +import test_constants as tc +import mock +import pytest + +from google.protobuf import field_mask_pb2 # type: ignore + + +# -*- coding: utf-8 -*- + +_EXPECTED_MASK = field_mask_pb2.FieldMask(paths=["resource_pools.replica_count"]) + +# for manual scaling +_TEST_RESPONSE_RUNNING_1_POOL_RESIZE = copy.deepcopy( + tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL +) +_TEST_RESPONSE_RUNNING_1_POOL_RESIZE.resource_pools[0].replica_count = 2 +_TEST_RESPONSE_RUNNING_2_POOLS_RESIZE = copy.deepcopy( + tc.ClusterConstants._TEST_RESPONSE_RUNNING_2_POOLS +) +_TEST_RESPONSE_RUNNING_2_POOLS_RESIZE.resource_pools[1].replica_count = 1 + + +@pytest.fixture +def create_persistent_resource_1_pool_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "create_persistent_resource", + ) as create_persistent_resource_1_pool_mock: + create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation) + create_persistent_resource_lro_mock.result.return_value = ( + tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL + ) + create_persistent_resource_1_pool_mock.return_value = ( + 
create_persistent_resource_lro_mock + ) + yield create_persistent_resource_1_pool_mock + + +@pytest.fixture +def get_persistent_resource_1_pool_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as get_persistent_resource_1_pool_mock: + get_persistent_resource_1_pool_mock.return_value = ( + tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL + ) + yield get_persistent_resource_1_pool_mock + + +@pytest.fixture +def create_persistent_resource_2_pools_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "create_persistent_resource", + ) as create_persistent_resource_2_pools_mock: + create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation) + create_persistent_resource_lro_mock.result.return_value = ( + tc.ClusterConstants._TEST_RESPONSE_RUNNING_2_POOLS + ) + create_persistent_resource_2_pools_mock.return_value = ( + create_persistent_resource_lro_mock + ) + yield create_persistent_resource_2_pools_mock + + +@pytest.fixture +def get_persistent_resource_2_pools_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as get_persistent_resource_2_pools_mock: + get_persistent_resource_2_pools_mock.return_value = ( + tc.ClusterConstants._TEST_RESPONSE_RUNNING_2_POOLS + ) + yield get_persistent_resource_2_pools_mock + + +@pytest.fixture +def list_persistent_resources_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "list_persistent_resources", + ) as list_persistent_resources_mock: + list_persistent_resources_mock.return_value = [ + tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL, + tc.ClusterConstants._TEST_RESPONSE_NO_RAY_RUNNING, # should be ignored + tc.ClusterConstants._TEST_RESPONSE_RUNNING_2_POOLS, + ] + yield list_persistent_resources_mock + + +@pytest.fixture +def create_persistent_resource_exception_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "create_persistent_resource", + ) as 
create_persistent_resource_exception_mock: + create_persistent_resource_exception_mock.side_effect = Exception + yield create_persistent_resource_exception_mock + + +@pytest.fixture +def get_persistent_resource_exception_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "get_persistent_resource", + ) as get_persistent_resource_exception_mock: + get_persistent_resource_exception_mock.side_effect = Exception + yield get_persistent_resource_exception_mock + + +@pytest.fixture +def list_persistent_resources_exception_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "list_persistent_resources", + ) as list_persistent_resources_exception_mock: + list_persistent_resources_exception_mock.side_effect = Exception + yield list_persistent_resources_exception_mock + + +@pytest.fixture +def update_persistent_resource_1_pool_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "update_persistent_resource", + ) as update_persistent_resource_1_pool_mock: + update_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation) + update_persistent_resource_lro_mock.result.return_value = ( + _TEST_RESPONSE_RUNNING_1_POOL_RESIZE + ) + update_persistent_resource_1_pool_mock.return_value = ( + update_persistent_resource_lro_mock + ) + yield update_persistent_resource_1_pool_mock + + +@pytest.fixture +def update_persistent_resource_2_pools_mock(): + with mock.patch.object( + PersistentResourceServiceClient, + "update_persistent_resource", + ) as update_persistent_resource_2_pools_mock: + update_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation) + update_persistent_resource_lro_mock.result.return_value = ( + _TEST_RESPONSE_RUNNING_2_POOLS_RESIZE + ) + update_persistent_resource_2_pools_mock.return_value = ( + update_persistent_resource_lro_mock + ) + yield update_persistent_resource_2_pools_mock + + +@pytest.mark.usefixtures("google_auth_mock", "get_project_number_mock") +class TestClusterManagement: + def 
setup_method(self): + importlib.reload(aiplatform.initializer) + importlib.reload(aiplatform) + aiplatform.init() + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + @pytest.mark.usefixtures("get_persistent_resource_1_pool_mock") + def test_create_ray_cluster_1_pool_gpu_success( + self, create_persistent_resource_1_pool_mock + ): + """If head and worker nodes are duplicate, merge to head pool.""" + cluster_name = vertex_ray.create_ray_cluster( + head_node_type=tc.ClusterConstants._TEST_HEAD_NODE_TYPE_1_POOL, + worker_node_types=tc.ClusterConstants._TEST_WORKER_NODE_TYPES_1_POOL, + network=tc.ProjectConstants._TEST_VPC_NETWORK, + cluster_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID, + ) + + assert tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS == cluster_name + + request = persistent_resource_service.CreatePersistentResourceRequest( + parent=tc.ProjectConstants._TEST_PARENT, + persistent_resource=tc.ClusterConstants._TEST_REQUEST_RUNNING_1_POOL, + persistent_resource_id=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID, + ) + + create_persistent_resource_1_pool_mock.assert_called_with( + request, + ) + + @pytest.mark.usefixtures("get_persistent_resource_2_pools_mock") + def test_create_ray_cluster_2_pools_success( + self, create_persistent_resource_2_pools_mock + ): + """If head and worker nodes are not duplicate, create separate resource_pools.""" + cluster_name = vertex_ray.create_ray_cluster( + head_node_type=tc.ClusterConstants._TEST_HEAD_NODE_TYPE_2_POOLS, + worker_node_types=tc.ClusterConstants._TEST_WORKER_NODE_TYPES_2_POOLS, + network=tc.ProjectConstants._TEST_VPC_NETWORK, + cluster_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID, + ) + + assert tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS == cluster_name + request = persistent_resource_service.CreatePersistentResourceRequest( + parent=tc.ProjectConstants._TEST_PARENT, + persistent_resource=tc.ClusterConstants._TEST_REQUEST_RUNNING_2_POOLS, + 
persistent_resource_id=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID, + ) + + create_persistent_resource_2_pools_mock.assert_called_with( + request, + ) + + @pytest.mark.usefixtures("persistent_client_mock") + def test_create_ray_cluster_initialized_success( + self, get_project_number_mock, api_client_mock + ): + """If initialized, create_ray_cluster doesn't need many call args.""" + aiplatform.init( + project=tc.ProjectConstants._TEST_GCP_PROJECT_ID_OVERRIDE, + location=tc.ProjectConstants._TEST_GCP_REGION_OVERRIDE, + staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + + _ = vertex_ray.create_ray_cluster( + network=tc.ProjectConstants._TEST_VPC_NETWORK, + ) + + create_method_mock = api_client_mock.create_persistent_resource + + # Assert that project override took effect. + get_project_number_mock.assert_called_once_with( + name="projects/{}".format(tc.ProjectConstants._TEST_GCP_PROJECT_ID_OVERRIDE) + ) + # Assert that location override took effect. + assert ( + tc.ProjectConstants._TEST_GCP_REGION_OVERRIDE + in create_method_mock.call_args.args[0].parent + ) + assert ( + "asia-docker" + in create_method_mock.call_args.args[ + 0 + ].persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images[ + "head-node" + ] + ) + + def test_create_ray_cluster_head_multinode_error(self): + with pytest.raises(ValueError) as e: + vertex_ray.create_ray_cluster( + head_node_type=Resources(node_count=3), + network=tc.ProjectConstants._TEST_VPC_NETWORK, + ) + e.match(regexp=r"Resources.node_count must be 1.") + + def test_create_ray_cluster_python_version_error(self): + with pytest.raises(ValueError) as e: + vertex_ray.create_ray_cluster( + network=tc.ProjectConstants._TEST_VPC_NETWORK, + python_version="3_8", + ) + e.match(regexp=r"The supported Python version is 3_10.") + + def test_create_ray_cluster_ray_version_error(self): + with pytest.raises(ValueError) as e: + vertex_ray.create_ray_cluster( + network=tc.ProjectConstants._TEST_VPC_NETWORK, + ray_version="2_1", 
+ ) + e.match(regexp=r"The supported Ray version is 2_4.") + + @pytest.mark.usefixtures("create_persistent_resource_exception_mock") + def test_create_ray_cluster_state_error(self): + with pytest.raises(ValueError) as e: + vertex_ray.create_ray_cluster( + network=tc.ProjectConstants._TEST_VPC_NETWORK, + ) + + e.match(regexp=r"Failed in cluster creation due to: ") + + def test_delete_ray_cluster_success(self, persistent_client_mock): + vertex_ray.delete_ray_cluster( + cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS + ) + + persistent_client_mock.assert_called_once() + + @pytest.mark.usefixtures("persistent_client_error_mock") + def test_delete_ray_cluster_error(self): + with pytest.raises(ValueError) as e: + vertex_ray.delete_ray_cluster( + cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS + ) + + e.match(regexp=r"Failed in cluster deletion due to: ") + + def test_get_ray_cluster_success(self, get_persistent_resource_1_pool_mock): + cluster = vertex_ray.get_ray_cluster( + cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS + ) + + get_persistent_resource_1_pool_mock.assert_called_once() + + assert vars(cluster.head_node_type) == vars( + tc.ClusterConstants._TEST_CLUSTER.head_node_type + ) + assert vars(cluster.worker_node_types[0]) == vars( + tc.ClusterConstants._TEST_CLUSTER.worker_node_types[0] + ) + assert ( + cluster.cluster_resource_name + == tc.ClusterConstants._TEST_CLUSTER.cluster_resource_name + ) + assert ( + cluster.python_version == tc.ClusterConstants._TEST_CLUSTER.python_version + ) + assert cluster.ray_version == tc.ClusterConstants._TEST_CLUSTER.ray_version + assert cluster.network == tc.ClusterConstants._TEST_CLUSTER.network + assert cluster.state == tc.ClusterConstants._TEST_CLUSTER.state + + @pytest.mark.usefixtures("get_persistent_resource_exception_mock") + def test_get_ray_cluster_error(self): + with pytest.raises(ValueError) as e: + vertex_ray.get_ray_cluster( + 
cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS + ) + + e.match(regexp=r"Failed in getting the cluster due to: ") + + def test_list_ray_clusters_success(self, list_persistent_resources_mock): + clusters = vertex_ray.list_ray_clusters() + + list_persistent_resources_mock.assert_called_once() + + # first ray cluster + assert vars(clusters[0].head_node_type) == vars( + tc.ClusterConstants._TEST_CLUSTER.head_node_type + ) + assert vars(clusters[0].worker_node_types[0]) == vars( + tc.ClusterConstants._TEST_CLUSTER.worker_node_types[0] + ) + assert ( + clusters[0].cluster_resource_name + == tc.ClusterConstants._TEST_CLUSTER.cluster_resource_name + ) + assert ( + clusters[0].python_version + == tc.ClusterConstants._TEST_CLUSTER.python_version + ) + assert clusters[0].ray_version == tc.ClusterConstants._TEST_CLUSTER.ray_version + assert clusters[0].network == tc.ClusterConstants._TEST_CLUSTER.network + assert clusters[0].state == tc.ClusterConstants._TEST_CLUSTER.state + + # second ray cluster + assert vars(clusters[1].head_node_type) == vars( + tc.ClusterConstants._TEST_CLUSTER_2.head_node_type + ) + assert vars(clusters[1].worker_node_types[0]) == vars( + tc.ClusterConstants._TEST_CLUSTER_2.worker_node_types[0] + ) + assert ( + clusters[1].cluster_resource_name + == tc.ClusterConstants._TEST_CLUSTER_2.cluster_resource_name + ) + assert ( + clusters[1].python_version + == tc.ClusterConstants._TEST_CLUSTER_2.python_version + ) + assert ( + clusters[1].ray_version == tc.ClusterConstants._TEST_CLUSTER_2.ray_version + ) + assert clusters[1].network == tc.ClusterConstants._TEST_CLUSTER_2.network + assert clusters[1].state == tc.ClusterConstants._TEST_CLUSTER_2.state + + def test_list_ray_clusters_initialized_success( + self, get_project_number_mock, list_persistent_resources_mock + ): + aiplatform.init( + project=tc.ProjectConstants._TEST_GCP_PROJECT_ID_OVERRIDE, + location=tc.ProjectConstants._TEST_GCP_REGION_OVERRIDE, + 
staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + _ = vertex_ray.list_ray_clusters() + + # Assert that project override took effect. + get_project_number_mock.assert_called_once_with( + name="projects/{}".format(tc.ProjectConstants._TEST_GCP_PROJECT_ID_OVERRIDE) + ) + # Assert that location override took effect. + assert ( + tc.ProjectConstants._TEST_GCP_REGION_OVERRIDE + in list_persistent_resources_mock.call_args.args[0].parent + ) + + @pytest.mark.usefixtures("list_persistent_resources_exception_mock") + def test_list_ray_clusters_error(self): + with pytest.raises(ValueError) as e: + vertex_ray.list_ray_clusters() + + e.match(regexp=r"Failed in listing the clusters due to: ") + + @pytest.mark.usefixtures("get_persistent_resource_1_pool_mock") + def test_update_ray_cluster_1_pool(self, update_persistent_resource_1_pool_mock): + + new_worker_node_types = [] + for worker_node_type in tc.ClusterConstants._TEST_CLUSTER.worker_node_types: + # resize worker node to node_count = 1 + worker_node_type.node_count = 1 + new_worker_node_types.append(worker_node_type) + + returned_name = vertex_ray.update_ray_cluster( + cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS, + worker_node_types=new_worker_node_types, + ) + + request = persistent_resource_service.UpdatePersistentResourceRequest( + persistent_resource=_TEST_RESPONSE_RUNNING_1_POOL_RESIZE, + update_mask=_EXPECTED_MASK, + ) + update_persistent_resource_1_pool_mock.assert_called_once_with(request) + + assert returned_name == tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS + + @pytest.mark.usefixtures("get_persistent_resource_2_pools_mock") + def test_update_ray_cluster_2_pools(self, update_persistent_resource_2_pools_mock): + + new_worker_node_types = [] + for worker_node_type in tc.ClusterConstants._TEST_CLUSTER_2.worker_node_types: + # resize worker node to node_count = 1 + worker_node_type.node_count = 1 + new_worker_node_types.append(worker_node_type) + + returned_name = 
vertex_ray.update_ray_cluster( + cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS, + worker_node_types=new_worker_node_types, + ) + + request = persistent_resource_service.UpdatePersistentResourceRequest( + persistent_resource=_TEST_RESPONSE_RUNNING_2_POOLS_RESIZE, + update_mask=_EXPECTED_MASK, + ) + update_persistent_resource_2_pools_mock.assert_called_once_with(request) + + assert returned_name == tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS diff --git a/tests/unit/vertex_ray/test_constants.py b/tests/unit/vertex_ray/test_constants.py new file mode 100644 index 0000000000..8d972c08ea --- /dev/null +++ b/tests/unit/vertex_ray/test_constants.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import dataclasses + +from google.cloud.aiplatform.preview.vertex_ray.util.resources import Cluster +from google.cloud.aiplatform.preview.vertex_ray.util.resources import ( + Resources, +) +from google.cloud.aiplatform_v1beta1.types.machine_resources import DiskSpec +from google.cloud.aiplatform_v1beta1.types.machine_resources import ( + MachineSpec, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + PersistentResource, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + ResourcePool, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + ResourceRuntime, +) +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + ResourceRuntimeSpec, +) + + +@dataclasses.dataclass(frozen=True) +class ProjectConstants: + """Defines project-specific constants used by tests.""" + + _TEST_VPC_NETWORK = "mock-vpc-network" + _TEST_GCP_PROJECT_ID = "mock-test-project-id" + _TEST_GCP_PROJECT_ID_OVERRIDE = "mock-test-project-id-2" + _TEST_GCP_REGION = "us-central1" + _TEST_GCP_REGION_OVERRIDE = "asia-east1" + _TEST_GCP_PROJECT_NUMBER = "12345" + _TEST_PARENT = f"projects/{_TEST_GCP_PROJECT_NUMBER}/locations/{_TEST_GCP_REGION}" + _TEST_ARTIFACT_URI = "gs://path/to/artifact/uri" + _TEST_BAD_ARTIFACT_URI = "/path/to/artifact/uri" + _TEST_MODEL_GCS_URI = "gs://test_model_dir" + _TEST_MODEL_ID = ( + f"projects/{_TEST_GCP_PROJECT_NUMBER}/locations/{_TEST_GCP_REGION}/models/456" + ) + + +@dataclasses.dataclass(frozen=True) +class ClusterConstants: + """Defines cluster constants used by tests.""" + + _TEST_VERTEX_RAY_HEAD_NODE_IP = "1.2.3.4:10001" + _TEST_VERTEX_RAY_JOB_CLIENT_IP = "1.2.3.4:8888" + _TEST_VERTEX_RAY_DASHBOARD_URL = ( + "48b400ad90b8dd3c-dot-us-central1.aiplatform-training.googleusercontent.com" + ) + _TEST_VERTEX_RAY_PR_ID = "user-persistent-resource-1234567890" + _TEST_VERTEX_RAY_PR_ADDRESS = ( + 
f"{ProjectConstants._TEST_PARENT}/persistentResources/" + _TEST_VERTEX_RAY_PR_ID + ) + _TEST_CPU_IMAGE = "us-docker.pkg.dev/vertex-ai/training/ray-cpu.2-4.py310:latest" + _TEST_GPU_IMAGE = "us-docker.pkg.dev/vertex-ai/training/ray-gpu.2-4.py310:latest" + # RUNNING Persistent Cluster w/o Ray + _TEST_RESPONSE_NO_RAY_RUNNING = PersistentResource( + name=_TEST_VERTEX_RAY_PR_ADDRESS, + resource_runtime_spec=ResourceRuntimeSpec(), + resource_runtime=ResourceRuntime(), + state="RUNNING", + ) + # RUNNING + # 1_POOL: merged worker_node_types and head_node_type with duplicate MachineSpec + _TEST_HEAD_NODE_TYPE_1_POOL = Resources( + accelerator_type="NVIDIA_TESLA_P100", accelerator_count=1 + ) + _TEST_WORKER_NODE_TYPES_1_POOL = [ + Resources( + accelerator_type="NVIDIA_TESLA_P100", accelerator_count=1, node_count=2 + ) + ] + _TEST_RESOURCE_POOL_0 = ResourcePool( + id="head-node", + machine_spec=MachineSpec( + machine_type="n1-standard-4", + accelerator_type="NVIDIA_TESLA_P100", + accelerator_count=1, + ), + disk_spec=DiskSpec( + boot_disk_type="pd-ssd", + boot_disk_size_gb=100, + ), + replica_count=3, + ) + _TEST_REQUEST_RUNNING_1_POOL = PersistentResource( + resource_pools=[_TEST_RESOURCE_POOL_0], + resource_runtime_spec=ResourceRuntimeSpec( + ray_spec=RaySpec(resource_pool_images={"head-node": _TEST_GPU_IMAGE}), + ), + network=ProjectConstants._TEST_VPC_NETWORK, + ) + # Get response has generated name, and URIs + _TEST_RESPONSE_RUNNING_1_POOL = PersistentResource( + name=_TEST_VERTEX_RAY_PR_ADDRESS, + resource_pools=[_TEST_RESOURCE_POOL_0], + resource_runtime_spec=ResourceRuntimeSpec( + ray_spec=RaySpec(resource_pool_images={"head-node": _TEST_GPU_IMAGE}), + ), + network=ProjectConstants._TEST_VPC_NETWORK, + resource_runtime=ResourceRuntime( + access_uris={ + "RAY_DASHBOARD_URI": _TEST_VERTEX_RAY_DASHBOARD_URL, + "RAY_HEAD_NODE_INTERNAL_IP": _TEST_VERTEX_RAY_HEAD_NODE_IP, + } + ), + state="RUNNING", + ) + # 2_POOL: worker_node_types and head_node_type have different 
MachineSpecs + _TEST_HEAD_NODE_TYPE_2_POOLS = Resources() + _TEST_WORKER_NODE_TYPES_2_POOLS = [ + Resources( + machine_type="n1-standard-16", + node_count=4, + accelerator_type="NVIDIA_TESLA_P100", + accelerator_count=1, + ) + ] + _TEST_RESOURCE_POOL_1 = ResourcePool( + id="head-node", + machine_spec=MachineSpec( + machine_type="n1-standard-4", + ), + disk_spec=DiskSpec( + boot_disk_type="pd-ssd", + boot_disk_size_gb=100, + ), + replica_count=1, + ) + _TEST_RESOURCE_POOL_2 = ResourcePool( + id="worker-pool1", + machine_spec=MachineSpec( + machine_type="n1-standard-16", + accelerator_type="NVIDIA_TESLA_P100", + accelerator_count=1, + ), + disk_spec=DiskSpec( + boot_disk_type="pd-ssd", + boot_disk_size_gb=100, + ), + replica_count=4, + ) + _TEST_REQUEST_RUNNING_2_POOLS = PersistentResource( + resource_pools=[_TEST_RESOURCE_POOL_1, _TEST_RESOURCE_POOL_2], + resource_runtime_spec=ResourceRuntimeSpec( + ray_spec=RaySpec( + resource_pool_images={ + "head-node": _TEST_CPU_IMAGE, + "worker-pool1": _TEST_GPU_IMAGE, + } + ), + ), + network=ProjectConstants._TEST_VPC_NETWORK, + ) + _TEST_RESPONSE_RUNNING_2_POOLS = PersistentResource( + name=_TEST_VERTEX_RAY_PR_ADDRESS, + resource_pools=[_TEST_RESOURCE_POOL_1, _TEST_RESOURCE_POOL_2], + resource_runtime_spec=ResourceRuntimeSpec( + ray_spec=RaySpec(resource_pool_images={"head-node": _TEST_GPU_IMAGE}), + ), + network=ProjectConstants._TEST_VPC_NETWORK, + resource_runtime=ResourceRuntime( + access_uris={ + "RAY_DASHBOARD_URI": _TEST_VERTEX_RAY_DASHBOARD_URL, + "RAY_HEAD_NODE_INTERNAL_IP": _TEST_VERTEX_RAY_HEAD_NODE_IP, + } + ), + state="RUNNING", + ) + _TEST_CLUSTER = Cluster( + cluster_resource_name=_TEST_VERTEX_RAY_PR_ADDRESS, + python_version="3_10", + ray_version="2_4", + network=ProjectConstants._TEST_VPC_NETWORK, + state="RUNNING", + head_node_type=_TEST_HEAD_NODE_TYPE_1_POOL, + worker_node_types=_TEST_WORKER_NODE_TYPES_1_POOL, + ) + _TEST_CLUSTER_2 = Cluster( + cluster_resource_name=_TEST_VERTEX_RAY_PR_ADDRESS, + 
python_version="3_10", + ray_version="2_4", + network=ProjectConstants._TEST_VPC_NETWORK, + state="RUNNING", + head_node_type=_TEST_HEAD_NODE_TYPE_2_POOLS, + worker_node_types=_TEST_WORKER_NODE_TYPES_2_POOLS, + ) diff --git a/tests/unit/vertex_ray/test_dashboard_sdk.py b/tests/unit/vertex_ray/test_dashboard_sdk.py new file mode 100644 index 0000000000..05ffd35dbb --- /dev/null +++ b/tests/unit/vertex_ray/test_dashboard_sdk.py @@ -0,0 +1,85 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import importlib + +from google.cloud import aiplatform +from google.cloud.aiplatform.preview import vertex_ray +import test_constants as tc +import mock +import pytest +from ray.dashboard.modules import dashboard_sdk as oss_dashboard_sdk + + +# -*- coding: utf-8 -*- + + +@pytest.fixture +def ray_get_job_submission_client_cluster_info_mock(): + with mock.patch.object( + oss_dashboard_sdk, "get_job_submission_client_cluster_info" + ) as ray_get_job_submission_client_cluster_info_mock: + yield ray_get_job_submission_client_cluster_info_mock + + +@pytest.fixture +def get_persistent_resource_status_running_mock(): + with mock.patch.object( + vertex_ray.util._gapic_utils, "get_persistent_resource" + ) as get_persistent_resource: + get_persistent_resource.return_value = ( + tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL + ) + yield get_persistent_resource + + +class TestGetJobSubmissionClientClusterInfo: + def setup_method(self): + importlib.reload(aiplatform.initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + @pytest.mark.usefixtures("get_persistent_resource_status_running_mock") + def test_job_submission_client_cluster_info_with_full_resource_name( + self, + ray_get_job_submission_client_cluster_info_mock, + ): + vertex_ray.get_job_submission_client_cluster_info( + tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS + ) + ray_get_job_submission_client_cluster_info_mock.assert_called_once_with( + address=tc.ClusterConstants._TEST_VERTEX_RAY_JOB_CLIENT_IP + ) + + @pytest.mark.usefixtures( + "get_persistent_resource_status_running_mock", "google_auth_mock" + ) + def test_job_submission_client_cluster_info_with_cluster_name( + self, + ray_get_job_submission_client_cluster_info_mock, + get_project_number_mock, + ): + aiplatform.init(project=tc.ProjectConstants._TEST_GCP_PROJECT_ID) + + vertex_ray.get_job_submission_client_cluster_info( + tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID + ) + 
get_project_number_mock.assert_called_once_with( + name="projects/{}".format(tc.ProjectConstants._TEST_GCP_PROJECT_ID) + ) + ray_get_job_submission_client_cluster_info_mock.assert_called_once_with( + address=tc.ClusterConstants._TEST_VERTEX_RAY_JOB_CLIENT_IP + ) diff --git a/tests/unit/vertex_ray/test_prediction_utils.py b/tests/unit/vertex_ray/test_prediction_utils.py new file mode 100644 index 0000000000..873f1c252d --- /dev/null +++ b/tests/unit/vertex_ray/test_prediction_utils.py @@ -0,0 +1,70 @@ +"""Test utils for Prediction Tests. +""" + +import numpy as np +import sklearn +from sklearn import linear_model +import tensorflow as tf +import torch +import xgboost + + +def create_tf_model() -> tf.keras.Model: + """Create toy neural network : 1-layer.""" + model = tf.keras.Sequential( + [tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))] + ) + model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"]) + return model + + +def train_tf_model(model: tf.keras.Model) -> None: + """Trains a Keras Model.""" + n = 1 + train_x = np.random.normal(0, 1, size=(n, 4)) + train_y = np.random.uniform(0, 1, size=(n, 1)) + model.fit(train_x, train_y, epochs=1) + + +def get_tensorflow_trained_model() -> tf.keras.Model: + """Returns a tensorflow trained model.""" + model = create_tf_model() + train_tf_model(model) + return model + + +def get_sklearn_estimator() -> sklearn.base.BaseEstimator: + """Returns a sklearn estimator.""" + estimator = linear_model.LinearRegression() + x = [[1, 2], [3, 4], [5, 6]] + y = [7, 8, 9] + estimator.fit(x, y) + return estimator + + +def get_xgboost_model() -> xgboost.XGBClassifier: + train_x = np.array([[1, 2], [3, 4]]) + train_y = np.array([0, 1]) + return xgboost.XGBClassifier().fit(train_x, train_y) + + +input_size = 1 +layer_size = 1 +output_size = 1 +num_epochs = 1 + + +class TorchModel(torch.nn.Module): + def __init__(self): + super(TorchModel, self).__init__() + self.layer1 = torch.nn.Linear(input_size, layer_size) 
+ self.relu = torch.nn.ReLU() + self.layer2 = torch.nn.Linear(layer_size, output_size) + + def forward(self, input_data): + return self.layer2(self.relu(self.layer1(input_data))) + + +def get_pytorch_trained_model() -> torch.nn.Module: + """Returns a pytorch trained model.""" + return TorchModel() diff --git a/tests/unit/vertex_ray/test_ray_prediction.py b/tests/unit/vertex_ray/test_ray_prediction.py new file mode 100644 index 0000000000..a927f91586 --- /dev/null +++ b/tests/unit/vertex_ray/test_ray_prediction.py @@ -0,0 +1,492 @@ +"""Tests for prediction.""" + +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib +import pickle +import tempfile + +from google.cloud import aiplatform +from google.cloud.aiplatform.preview.vertex_ray.predict import ( + sklearn as prediction_sklearn, +) +from google.cloud.aiplatform.preview.vertex_ray.predict import ( + tensorflow as prediction_tensorflow, +) +from google.cloud.aiplatform.preview.vertex_ray.predict import ( + torch as prediction_torch, +) +from google.cloud.aiplatform.preview.vertex_ray.predict import ( + xgboost as prediction_xgboost, +) +from google.cloud.aiplatform.utils import gcs_utils +import test_constants as tc +import test_prediction_utils + +import mock +import numpy as np +import pytest +import ray +from ray.train import xgboost as ray_xgboost +import tensorflow as tf +import torch +import xgboost + + +@pytest.fixture() +def upload_tensorflow_saved_model_mock(): + with mock.patch.object( + aiplatform.Model, "upload_tensorflow_saved_model" + ) as upload_tensorflow_saved_model_mock: + upload_tensorflow_saved_model_mock.return_value = None + yield upload_tensorflow_saved_model_mock + + +@pytest.fixture() +def ray_tensorflow_checkpoint(): + defined_model = test_prediction_utils.get_tensorflow_trained_model() + checkpoint = ray.train.tensorflow.TensorflowCheckpoint.from_model(defined_model) + return checkpoint + + +@pytest.fixture() +def ray_checkpoint_from_dict(): + checkpoint_data = {"data": 123} + checkpoint = ray.air.checkpoint.Checkpoint.from_dict(checkpoint_data) + return checkpoint + + +@pytest.fixture() +def save_tf_model(): + with mock.patch.object(tf.keras.Model, "save") as save_tf_model_mock: + save_tf_model_mock.return_value = None + yield save_tf_model_mock + + +@pytest.fixture() +def ray_sklearn_checkpoint(): + estimator = test_prediction_utils.get_sklearn_estimator() + temp_dir = tempfile.mkdtemp() + checkpoint = ray.train.sklearn.SklearnCheckpoint.from_estimator( + estimator, path=temp_dir + ) + return checkpoint + + +@pytest.fixture() +def ray_xgboost_checkpoint(): + model = 
test_prediction_utils.get_xgboost_model() + checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster()) + return checkpoint + + +@pytest.fixture() +def pickle_dump(): + with mock.patch.object(pickle, "dump") as pickle_dump: + pickle_dump.return_value = None + yield pickle_dump + + +@pytest.fixture +def mock_vertex_model(): + model = mock.MagicMock(aiplatform.Model) + model.uri = tc.ProjectConstants._TEST_MODEL_GCS_URI + model.container_spec.image_uri = "us-docker.xxx/sklearn-cpu.1-0:latest" + model.labels = {"registered_by_vertex_ai": "true"} + yield model + + +@pytest.fixture() +def upload_sklearn_mock(mock_vertex_model): + with mock.patch.object( + aiplatform.Model, "upload_scikit_learn_model_file" + ) as upload_sklearn_mock: + upload_sklearn_mock.return_value = mock_vertex_model + yield upload_sklearn_mock + + +@pytest.fixture +def mock_xgboost_vertex_model(): + model = mock.MagicMock(aiplatform.Model) + model.uri = tc.ProjectConstants._TEST_MODEL_GCS_URI + model.container_spec.image_uri = "us-docker.xxx/xgboost-cpu.1-6:latest" + model.labels = {"registered_by_vertex_ai": "true"} + yield model + + +@pytest.fixture() +def upload_xgboost_mock(mock_xgboost_vertex_model): + with mock.patch.object( + aiplatform.Model, "upload_xgboost_model_file" + ) as upload_xgboost_mock: + upload_xgboost_mock.return_value = mock_xgboost_vertex_model + yield upload_xgboost_mock + + +@pytest.fixture() +def gcs_utils_upload_to_gcs(): + with mock.patch.object(gcs_utils, "upload_to_gcs") as gcs_utils_upload_to_gcs: + gcs_utils_upload_to_gcs.return_value = None # configure the mock itself + yield gcs_utils_upload_to_gcs + + +@pytest.fixture() +def ray_torch_checkpoint(): + defined_model = test_prediction_utils.get_pytorch_trained_model() + checkpoint = ray.train.torch.TorchCheckpoint.from_model(defined_model) + return checkpoint + + +@pytest.mark.usefixtures("google_auth_mock") +class TestPredictionFunctionality: + """Tests for Prediction.""" + + def setup_method(self): + importlib.reload(aiplatform.initializer) +
importlib.reload(aiplatform) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + # Tensorflow Tests + def test_convert_checkpoint_to_tf_model_raise_exception( + self, ray_checkpoint_from_dict + ) -> None: + """Test if a checkpoint is not an instance of TensorflowCheckpoint should + fail with exception ValueError.""" + with pytest.raises(ValueError) as ve: + prediction_tensorflow.register._get_tensorflow_model_from( + ray_checkpoint_from_dict + ) + + assert ve.match( + regexp=r".* arg checkpoint should be a " + "ray.train.tensorflow.TensorflowCheckpoint .*" + ) + + def test_convert_checkpoint_to_tensorflow_model_succeed( + self, ray_tensorflow_checkpoint + ) -> None: + """Test if a TensorflowCheckpoint conversion is successful.""" + # Act + model = prediction_tensorflow.register._get_tensorflow_model_from( + ray_tensorflow_checkpoint, model=test_prediction_utils.create_tf_model + ) + + # Assert + assert model is not None + values = model.predict([[1, 1, 1, 1]]) + assert values[0] is not None + + def test_register_tensorflow_succeed( + self, + ray_tensorflow_checkpoint, + upload_tensorflow_saved_model_mock, + save_tf_model, + ) -> None: + """Test if a TensorflowCheckpoint upload is successful.""" + # Act + prediction_tensorflow.register_tensorflow( + ray_tensorflow_checkpoint, + artifact_uri=tc.ProjectConstants._TEST_ARTIFACT_URI, + _model=test_prediction_utils.create_tf_model, + use_gpu=False, + ) + + # Assert + upload_tensorflow_saved_model_mock.assert_called_once() + save_tf_model.assert_called_once_with( + f"{tc.ProjectConstants._TEST_ARTIFACT_URI}/ray-on-vertex-registered-tensorflow-model" + ) + + def test_register_tensorflow_initialized_succeed( + self, + ray_tensorflow_checkpoint, + upload_tensorflow_saved_model_mock, + save_tf_model, + ) -> None: + """Test if a TensorflowCheckpoint upload is successful when artifact_uri is None but initialized.""" + aiplatform.init( + project=tc.ProjectConstants._TEST_GCP_PROJECT_ID, +
staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + # Act + prediction_tensorflow.register_tensorflow( + ray_tensorflow_checkpoint, + _model=test_prediction_utils.create_tf_model, + use_gpu=False, + ) + + # Assert + upload_tensorflow_saved_model_mock.assert_called_once() + save_tf_model.assert_called_once_with( + f"{tc.ProjectConstants._TEST_ARTIFACT_URI}/ray-on-vertex-registered-tensorflow-model" + ) + + def test_register_tensorflowartifact_uri_is_none_raise_error( + self, ray_tensorflow_checkpoint + ) -> None: + """Test if a TensorflowCheckpoint upload gives ValueError.""" + # Act and Assert + with pytest.raises(ValueError) as ve: + prediction_tensorflow.register_tensorflow( + checkpoint=ray_tensorflow_checkpoint, + artifact_uri=None, + _model=test_prediction_utils.create_tf_model, + ) + assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*") + + def test_register_tensorflowartifact_uri_not_gcs_uri_raise_error( + self, ray_tensorflow_checkpoint + ) -> None: + """Test if a TensorflowCheckpoint upload gives ValueError.""" + # Act and Assert + with pytest.raises(ValueError) as ve: + prediction_tensorflow.register_tensorflow( + checkpoint=ray_tensorflow_checkpoint, + artifact_uri=tc.ProjectConstants._TEST_BAD_ARTIFACT_URI, + _model=test_prediction_utils.create_tf_model, + ) + assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*") + + # Sklearn Tests + def test_convert_checkpoint_to_sklearn_raise_exception( + self, ray_checkpoint_from_dict + ) -> None: + """Test if a checkpoint is not an instance of SklearnCheckpoint should + fail with exception ValueError.""" + + with pytest.raises(ValueError) as ve: + prediction_sklearn.register._get_estimator_from(ray_checkpoint_from_dict) + assert ve.match( + regexp=r".* arg checkpoint should be a " + "ray.train.sklearn.SklearnCheckpoint .*" + ) + + def test_convert_checkpoint_to_sklearn_model_succeed( + self, ray_sklearn_checkpoint + ) -> None: + """Test if a SklearnCheckpoint 
conversion is successful.""" + # Act + estimator = prediction_sklearn.register._get_estimator_from( + ray_sklearn_checkpoint + ) + + # Assert + assert estimator is not None + y_pred = estimator.predict([[10, 11]]) + assert y_pred[0] is not None + + def test_register_sklearn_succeed( + self, + ray_sklearn_checkpoint, + upload_sklearn_mock, + pickle_dump, + gcs_utils_upload_to_gcs, + ) -> None: + """Test if a SklearnCheckpoint upload is successful.""" + # Act + vertex_ai_model = prediction_sklearn.register_sklearn( + ray_sklearn_checkpoint, + artifact_uri=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + + # Assert + assert vertex_ai_model.uri == tc.ProjectConstants._TEST_MODEL_GCS_URI + assert vertex_ai_model.container_spec.image_uri == ( + "us-docker.xxx/sklearn-cpu.1-0:latest" + ) + upload_sklearn_mock.assert_called_once() + pickle_dump.assert_called_once() + gcs_utils_upload_to_gcs.assert_called_once() + + def test_register_sklearn_initialized_succeed( + self, + ray_sklearn_checkpoint, + upload_sklearn_mock, + pickle_dump, + gcs_utils_upload_to_gcs, + ) -> None: + """Test if a SklearnCheckpoint upload is successful when artifact_uri is None but initialized.""" + aiplatform.init( + project=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + # Act + vertex_ai_model = prediction_sklearn.register_sklearn( + ray_sklearn_checkpoint, + ) + + # Assert + assert vertex_ai_model.uri == tc.ProjectConstants._TEST_MODEL_GCS_URI + assert vertex_ai_model.container_spec.image_uri == ( + "us-docker.xxx/sklearn-cpu.1-0:latest" + ) + upload_sklearn_mock.assert_called_once() + pickle_dump.assert_called_once() + gcs_utils_upload_to_gcs.assert_called_once() + + def test_register_sklearnartifact_uri_is_none_raise_error( + self, ray_sklearn_checkpoint + ) -> None: + """Test if a SklearnCheckpoint upload gives ValueError.""" + # Act and Assert + with pytest.raises(ValueError) as ve: + prediction_sklearn.register_sklearn( + checkpoint=ray_sklearn_checkpoint, +
artifact_uri=None, + ) + assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*") + + def test_register_sklearnartifact_uri_not_gcs_uri_raise_error( + self, ray_sklearn_checkpoint + ) -> None: + """Test if a SklearnCheckpoint upload gives ValueError.""" + # Act and Assert + with pytest.raises(ValueError) as ve: + prediction_sklearn.register_sklearn( + checkpoint=ray_sklearn_checkpoint, + artifact_uri=tc.ProjectConstants._TEST_BAD_ARTIFACT_URI, + ) + assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*") + + # XGBoost Tests + def test_convert_checkpoint_to_xgboost_raise_exception( + self, ray_checkpoint_from_dict + ) -> None: + """Test if a checkpoint is not an instance of XGBoostCheckpoint should + + fail with exception ValueError. + """ + + with pytest.raises(ValueError) as ve: + prediction_xgboost.register._get_xgboost_model_from( + ray_checkpoint_from_dict + ) + assert ve.match( + regexp=r".* arg checkpoint should be a " + "ray.train.xgboost.XGBoostCheckpoint .*" + ) + + def test_convert_checkpoint_to_xgboost_model_succeed( + self, ray_xgboost_checkpoint + ) -> None: + """Test if a XGBoostCheckpoint conversion is successful.""" + # Act + model = prediction_xgboost.register._get_xgboost_model_from( + ray_xgboost_checkpoint + ) + + # Assert + assert model is not None + y_pred = model.predict(xgboost.DMatrix(np.array([[1, 2]]))) + assert y_pred[0] is not None + + def test_register_xgboost_succeed( + self, + ray_xgboost_checkpoint, + upload_xgboost_mock, + pickle_dump, + gcs_utils_upload_to_gcs, + ) -> None: + """Test if a XGBoostCheckpoint upload is successful.""" + # Act + vertex_ai_model = prediction_xgboost.register_xgboost( + ray_xgboost_checkpoint, + artifact_uri=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + + # Assert + vertex_ai_model.uri = tc.ProjectConstants._TEST_MODEL_GCS_URI + vertex_ai_model.container_spec.image_uri = ( + "us-docker.xxx/xgboost-cpu.1-6:latest" + ) + upload_xgboost_mock.assert_called_once() + 
pickle_dump.assert_called_once() + gcs_utils_upload_to_gcs.assert_called_once() + + def test_register_xgboost_initialized_succeed( + self, + ray_xgboost_checkpoint, + upload_xgboost_mock, + pickle_dump, + gcs_utils_upload_to_gcs, + ) -> None: + """Test if a XGBoostCheckpoint upload is successful when artifact_uri is None but initialized.""" + aiplatform.init( + project=tc.ProjectConstants._TEST_GCP_PROJECT_ID, + staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI, + ) + # Act + vertex_ai_model = prediction_xgboost.register_xgboost( + ray_xgboost_checkpoint, + ) + + # Assert + vertex_ai_model.uri = tc.ProjectConstants._TEST_MODEL_GCS_URI + vertex_ai_model.container_spec.image_uri = ( + "us-docker.xxx/xgboost-cpu.1-6:latest" + ) + upload_xgboost_mock.assert_called_once() + pickle_dump.assert_called_once() + gcs_utils_upload_to_gcs.assert_called_once() + + def test_register_xgboostartifact_uri_is_none_raise_error( + self, ray_xgboost_checkpoint + ) -> None: + """Test if a XGBoostCheckpoint upload gives ValueError.""" + # Act and Assert + with pytest.raises(ValueError) as ve: + prediction_xgboost.register_xgboost( + checkpoint=ray_xgboost_checkpoint, + artifact_uri=None, + ) + assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*") + + def test_register_xgboostartifact_uri_not_gcs_uri_raise_error( + self, ray_xgboost_checkpoint + ) -> None: + """Test if a XGBoostCheckpoint upload gives ValueError.""" + # Act and Assert + with pytest.raises(ValueError) as ve: + prediction_xgboost.register_xgboost( + checkpoint=ray_xgboost_checkpoint, + artifact_uri=tc.ProjectConstants._TEST_BAD_ARTIFACT_URI, + ) + assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*") + + # Pytorch Tests + def test_convert_checkpoint_to_torch_model_raises_exception( + self, ray_checkpoint_from_dict + ) -> None: + """Test if a checkpoint is not an instance of TorchCheckpoint should + fail with exception ValueError.""" + with pytest.raises(ValueError): + 
prediction_torch.register.get_pytorch_model_from(ray_checkpoint_from_dict) + + def test_convert_checkpoint_to_pytorch_model_succeed( + self, ray_torch_checkpoint + ) -> None: + """Test if a TorchCheckpoint conversion is successful.""" + # Act + model = prediction_torch.register.get_pytorch_model_from(ray_torch_checkpoint) + + # Assert + assert model is not None + values = model(torch.tensor([10000], dtype=torch.float)) + print(values[0]) + assert values[0] is not None diff --git a/tests/unit/vertex_ray/test_ray_utils.py b/tests/unit/vertex_ray/test_ray_utils.py new file mode 100644 index 0000000000..602314a1c9 --- /dev/null +++ b/tests/unit/vertex_ray/test_ray_utils.py @@ -0,0 +1,46 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from google.cloud.aiplatform.preview import vertex_ray
+import test_constants as tc
+import pytest
+
+
+class TestUtils:
+    """Tests for ``vertex_ray.util._gapic_utils.get_persistent_resource``."""
+
+    def test_get_persistent_resource_success(self, persistent_client_mock):
+        """Result must equal _TEST_RESPONSE_RUNNING_1_POOL; the mock is hit once."""
+        resource = vertex_ray.util._gapic_utils.get_persistent_resource(
+            tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS
+        )
+        persistent_client_mock.assert_called_once()
+        assert resource == tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL
+
+    def test_get_persistent_resource_stopping(self, persistent_client_stopping_mock):
+        """Must raise RuntimeError matching the 'stopping' text; mock is hit once."""
+        with pytest.raises(RuntimeError, match=r"The cluster is stopping."):
+            vertex_ray.util._gapic_utils.get_persistent_resource(
+                tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS
+            )
+        persistent_client_stopping_mock.assert_called_once()
+
+    def test_get_persistent_resource_error(self, persistent_client_error_mock):
+        """Must raise RuntimeError matching the 'error' text; mock is hit once."""
+        with pytest.raises(RuntimeError, match=r"The cluster encountered an error."):
+            vertex_ray.util._gapic_utils.get_persistent_resource(
+                tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS
+            )
+        persistent_client_error_mock.assert_called_once()
diff --git a/tests/unit/vertex_ray/test_vertex_ray_client.py b/tests/unit/vertex_ray/test_vertex_ray_client.py
new file mode 100644
index 0000000000..1ea55e35c5
--- /dev/null
+++ b/tests/unit/vertex_ray/test_vertex_ray_client.py
@@ -0,0 +1,161 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import importlib
+
+from google.cloud import aiplatform
+from google.cloud.aiplatform.preview import vertex_ray
+import test_constants as tc
+import mock
+import pytest
+import ray
+
+
+# Canned contexts: what the patched ray.ClientBuilder.connect returns, and the
+# Vertex wrapper that test_connect_running expects to get back.
+_TEST_CLIENT_CONTEXT = ray.client_builder.ClientContext(
+    dashboard_url=tc.ClusterConstants._TEST_VERTEX_RAY_DASHBOARD_URL,
+    python_version="MOCK_PYTHON_VERSION",
+    ray_version="MOCK_RAY_VERSION",
+    ray_commit="MOCK_RAY_COMMIT",
+    protocol_version="MOCK_PROTOCOL_VERSION",
+    _num_clients=1,
+    _context_to_restore=None,
+)
+
+_TEST_VERTEX_RAY_CLIENT_CONTEXT = vertex_ray.client_builder._VertexRayClientContext(
+    persistent_resource_id="MOCK_PERSISTENT_RESOURCE_ID",
+    ray_head_uris={
+        "RAY_DASHBOARD_URI": tc.ClusterConstants._TEST_VERTEX_RAY_DASHBOARD_URL,
+        "RAY_HEAD_NODE_INTERNAL_IP": tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP,
+    },
+    ray_client_context=_TEST_CLIENT_CONTEXT,
+)
+
+
+@pytest.fixture
+def ray_client_init_mock():
+    with mock.patch.object(ray.ClientBuilder, "__init__") as ray_client_init:
+        ray_client_init.return_value = None  # __init__ must return None
+        yield ray_client_init
+
+
+@pytest.fixture
+def ray_client_connect_mock():
+    with mock.patch.object(ray.ClientBuilder, "connect") as ray_client_connect:
+        ray_client_connect.return_value = _TEST_CLIENT_CONTEXT
+        yield ray_client_connect
+
+
+@pytest.fixture
+def get_persistent_resource_status_running_mock():
+    with mock.patch.object(
+        vertex_ray.util._gapic_utils, "get_persistent_resource"
+    ) as resolve_head_ip:
+        resolve_head_ip.return_value = tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL
+        yield resolve_head_ip
+
+
+@pytest.fixture
+def get_persistent_resource_status_running_no_ray_mock():
+    with mock.patch.object(
+        vertex_ray.util._gapic_utils, "get_persistent_resource"
+    ) as resolve_head_ip:
+        resolve_head_ip.return_value = tc.ClusterConstants._TEST_RESPONSE_NO_RAY_RUNNING
+        yield resolve_head_ip
+
+
+class TestClientBuilder:
+    """Tests for ``vertex_ray.ClientBuilder`` construction and ``connect()``."""
+
+    def setup_method(self):
+        importlib.reload(aiplatform.initializer)
+        importlib.reload(aiplatform)
+
+    def teardown_method(self):
+        aiplatform.initializer.global_pool.shutdown(wait=True)
+
+    @pytest.mark.usefixtures("get_persistent_resource_status_running_mock")
+    def test_init_with_full_resource_name(self, ray_client_init_mock):
+        """A full address must reach ray.ClientBuilder.__init__ as the head IP."""
+        vertex_ray.ClientBuilder(tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS)
+        ray_client_init_mock.assert_called_once_with(
+            tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP,
+        )
+
+    @pytest.mark.usefixtures(
+        "get_persistent_resource_status_running_mock", "google_auth_mock"
+    )
+    def test_init_with_cluster_name(
+        self,
+        ray_client_init_mock,
+        get_project_number_mock,
+    ):
+        """A bare cluster id must query the project mock, then pass the head IP."""
+        aiplatform.init(project=tc.ProjectConstants._TEST_GCP_PROJECT_ID)
+        vertex_ray.ClientBuilder(tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID)
+        get_project_number_mock.assert_called_once_with(
+            name="projects/{}".format(tc.ProjectConstants._TEST_GCP_PROJECT_ID)
+        )
+        ray_client_init_mock.assert_called_once_with(
+            tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP,
+        )
+
+    @pytest.mark.usefixtures("get_persistent_resource_status_running_mock")
+    def test_connect_running(self, ray_client_connect_mock):
+        """connect() must return a context equal to the canned Vertex context."""
+        connect_result = vertex_ray.ClientBuilder(
+            tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS
+        ).connect()
+        ray_client_connect_mock.assert_called_once_with()
+        assert connect_result == _TEST_VERTEX_RAY_CLIENT_CONTEXT
+        assert (
+            connect_result.persistent_resource_id
+            == tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID
+        )
+
+    @pytest.mark.usefixtures("get_persistent_resource_status_running_no_ray_mock")
+    def test_connect_running_no_ray(self, ray_client_connect_mock):
+        """The ValueError must contain each fragment; a tuple never equals a str."""
+        expected_parts = (
+            "Ray Cluster ",
+            tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID,
+            " failed to start Head node properly.",
+        )
+        with pytest.raises(ValueError) as exception:
+            vertex_ray.ClientBuilder(
+                tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS
+            ).connect()
+            ray_client_connect_mock.assert_called_once_with()  # NOTE(review): unreached
+        assert all(part in str(exception.value) for part in expected_parts)
+
+    @pytest.mark.parametrize(
+        "address",
+        [
+            "bad/format/address",
+            "must/have/exactly/five/backslashes/no/more/or/less",
+            "do/not/append/a/trailing/backslash/",
+            tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP,  # cannot input raw head node ip
+        ],
+    )
+    def test_bad_format_address(self, address):
+        """Each malformed address must be rejected with the exact format hint."""
+        expected_message = (
+            "[Ray on Vertex AI]: Address must be in the following format: "
+            "vertex_ray://projects//locations//"
+            "persistentResources/ or vertex_ray://."
+        )
+        with pytest.raises(ValueError) as exception:
+            vertex_ray.ClientBuilder(address)
+        assert str(exception.value) == expected_message
diff --git a/tests/unit/vertexai/test_remote_training.py b/tests/unit/vertexai/test_remote_training.py
index b8b0874078..218ae74ef1 100644
--- a/tests/unit/vertexai/test_remote_training.py
+++ b/tests/unit/vertexai/test_remote_training.py
@@ -1533,7 +1533,9 @@ def test_remote_training_keras_distributed_no_cuda_no_worker_pool_specs(
 
     # TODO(b/300116902) Remove this once we find better solution.
     @pytest.mark.xfail(
-        sys.version_info.minor == 11, raises=ValueError, reason="Flaky in python 3.11"
+        sys.version_info.minor >= 8,
+        raises=ValueError,
+        reason="Flaky in python 3.8, 3.10, 3.11",
     )
     @pytest.mark.usefixtures(
         "list_default_tensorboard_mock",
@@ -1615,7 +1617,9 @@ def test_remote_training_sklearn_with_experiment(
 
     # TODO(b/300116902) Remove this once we find better solution
     @pytest.mark.xfail(
-        sys.version_info.minor == 11, raises=ValueError, reason="Flaky in python 3.11"
+        sys.version_info.minor >= 8,
+        raises=ValueError,
+        reason="Flaky in python 3.8, 3.10, 3.11",
     )
     @pytest.mark.usefixtures(
         "list_default_tensorboard_mock",