From 6d5c7c42d1c352f161e4738c6dbbf540a032017b Mon Sep 17 00:00:00 2001 From: mkovalski Date: Mon, 29 Nov 2021 12:41:02 -0500 Subject: [PATCH] feat: Add cloud profiler to training_utils --- README.rst | 18 + .../cloud/aiplatform/tensorboard/__init__.py | 2 +- ...tensorboard.py => tensorboard_resource.py} | 0 .../aiplatform/tensorboard/uploader_utils.py | 1 - .../training_utils/cloud_profiler/README.rst | 20 + .../training_utils/cloud_profiler/__init__.py | 35 ++ .../cloud_profiler/initializer.py | 118 +++++ .../cloud_profiler/plugins/base_plugin.py | 71 +++ .../plugins/tensorflow/tensorboard_api.py | 195 ++++++++ .../plugins/tensorflow/tf_profiler.py | 358 +++++++++++++++ .../cloud_profiler/webserver.py | 114 +++++ .../cloud_profiler/wsgi_types.py | 28 ++ .../training_utils/environment_variables.py | 5 + setup.py | 7 +- tests/unit/aiplatform/test_cloud_profiler.py | 434 ++++++++++++++++++ tests/unit/aiplatform/test_training_utils.py | 11 + 16 files changed, 1414 insertions(+), 3 deletions(-) rename google/cloud/aiplatform/tensorboard/{tensorboard.py => tensorboard_resource.py} (100%) create mode 100644 google/cloud/aiplatform/training_utils/cloud_profiler/README.rst create mode 100644 google/cloud/aiplatform/training_utils/cloud_profiler/__init__.py create mode 100644 google/cloud/aiplatform/training_utils/cloud_profiler/initializer.py create mode 100644 google/cloud/aiplatform/training_utils/cloud_profiler/plugins/base_plugin.py create mode 100644 google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py create mode 100644 google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py create mode 100644 google/cloud/aiplatform/training_utils/cloud_profiler/webserver.py create mode 100644 google/cloud/aiplatform/training_utils/cloud_profiler/wsgi_types.py create mode 100644 tests/unit/aiplatform/test_cloud_profiler.py diff --git a/README.rst b/README.rst index b9e9ca4937..a59857a13b 100644 --- a/README.rst +++ b/README.rst @@ -464,6 +464,24 @@ To use Explanation Metadata in endpoint deployment and model upload: aiplatform.Model.upload(..., explanation_metadata=explanation_metadata) +Cloud Profiler +---------------------------- + +Cloud Profiler allows you to profile your remote Vertex AI Training jobs on demand and visualize the results in Vertex Tensorboard. + +To start using the profiler with TensorFlow, update your training script to include the following: + +.. code-block:: Python + + from google.cloud.aiplatform.training_utils import cloud_profiler + ... + cloud_profiler.init() + +Next, run the job with with a Vertex TensorBoard instance. For full details on how to do this, visit https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview + +Finally, visit your TensorBoard in your Google Cloud Console, navigate to the "Profile" tab, and click the `Capture Profile` button. This will allow users to capture profiling statistics for the running jobs. + + Next Steps ~~~~~~~~~~ diff --git a/google/cloud/aiplatform/tensorboard/__init__.py b/google/cloud/aiplatform/tensorboard/__init__.py index 93c48cd46c..f4b1c0b105 100644 --- a/google/cloud/aiplatform/tensorboard/__init__.py +++ b/google/cloud/aiplatform/tensorboard/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. # -from google.cloud.aiplatform.tensorboard.tensorboard import Tensorboard +from google.cloud.aiplatform.tensorboard.tensorboard_resource import Tensorboard __all__ = ("Tensorboard",) diff --git a/google/cloud/aiplatform/tensorboard/tensorboard.py b/google/cloud/aiplatform/tensorboard/tensorboard_resource.py similarity index 100% rename from google/cloud/aiplatform/tensorboard/tensorboard.py rename to google/cloud/aiplatform/tensorboard/tensorboard_resource.py diff --git a/google/cloud/aiplatform/tensorboard/uploader_utils.py b/google/cloud/aiplatform/tensorboard/uploader_utils.py index 86712a5542..1396f6cc78 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_utils.py +++ b/google/cloud/aiplatform/tensorboard/uploader_utils.py @@ -406,7 +406,6 @@ def get_or_create( filter="display_name = {}".format(json.dumps(str(tag_name))), ) ) - num = 0 time_series = None diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/README.rst b/google/cloud/aiplatform/training_utils/cloud_profiler/README.rst new file mode 100644 index 0000000000..6c6cfc1af9 --- /dev/null +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/README.rst @@ -0,0 +1,20 @@ +Cloud Profiler +================================= + +Cloud Profiler allows you to profile your remote Vertex AI Training jobs on demand and visualize the results in Vertex Tensorboard. + +Quick Start +------------ + +To start using the profiler with TensorFlow, update your training script to include the following: + +.. code-block:: Python + + from google.cloud.aiplatform.training_utils import cloud_profiler + ... + cloud_profiler.init() + + +Next, run the job with with a Vertex TensorBoard instance. For full details on how to do this, visit https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview + +Finally, visit your TensorBoard in your Google Cloud Console, navigate to the "Profile" tab, and click the `Capture Profile` button. This will allow users to capture profiling statistics for the running jobs. diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/__init__.py b/google/cloud/aiplatform/training_utils/cloud_profiler/__init__.py new file mode 100644 index 0000000000..f5aa40cc34 --- /dev/null +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +try: + import google.cloud.aiplatform.training_utils.cloud_profiler.initializer as initializer +except ImportError as err: + raise ImportError( + "Could not load the cloud profiler. To use the profiler, " + 'install the SDK using "pip install google-cloud-aiplatform[cloud-profiler]"' + ) from err + +""" +Initialize the cloud profiler for tensorflow. + +Usage: +from google.cloud.aiplatform.training_utils import cloud_profiler + +cloud_profiler.init(profiler='tensorflow') +""" + +init = initializer.initialize diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/initializer.py b/google/cloud/aiplatform/training_utils/cloud_profiler/initializer.py new file mode 100644 index 0000000000..1a098dd2a5 --- /dev/null +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/initializer.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import threading +from typing import Optional, Type +from werkzeug import serving + +from google.cloud.aiplatform.training_utils import environment_variables +from google.cloud.aiplatform.training_utils.cloud_profiler import webserver +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import ( + tf_profiler, +) + +# Mapping of available plugins to use +_AVAILABLE_PLUGINS = {"tensorflow": tf_profiler.TFProfiler} + + +class MissingEnvironmentVariableException(Exception): + pass + + +def _build_plugin( + plugin: Type[base_plugin.BasePlugin], +) -> Optional[base_plugin.BasePlugin]: + """Builds the plugin given the object. + + Args: + plugin (Type[base_plugin]): + Required. An uninitialized plugin class. + + Returns: + An initialized plugin, or None if plugin cannot be + initialized. + """ + if not plugin.can_initialize(): + logging.warning("Cannot initialize the plugin") + return + + plugin.setup() + + if not plugin.post_setup_check(): + return + + return plugin() + + +def _run_app_thread(server: webserver.WebServer, port: int): + """Run the webserver in a separate thread. + + Args: + server (webserver.WebServer): + Required. A webserver to accept requests. + port (int): + Required. The port to run the webserver on. + """ + daemon = threading.Thread( + name="profile_server", + target=serving.run_simple, + args=("0.0.0.0", port, server,), + ) + daemon.setDaemon(True) + daemon.start() + + +def initialize(plugin: str = "tensorflow"): + """Initializes the profiling SDK. + + Args: + plugin (str): + Required. Name of the plugin to initialize. + Current options are ["tensorflow"] + + Raises: + ValueError: + The plugin does not exist. + MissingEnvironmentVariableException: + An environment variable that is needed is not set. + """ + plugin_obj = _AVAILABLE_PLUGINS.get(plugin) + + if not plugin_obj: + raise ValueError( + "Plugin {} not available, must choose from {}".format( + plugin, _AVAILABLE_PLUGINS.keys() + ) + ) + + prof_plugin = _build_plugin(plugin_obj) + + if prof_plugin is None: + return + + server = webserver.WebServer([prof_plugin]) + + if not environment_variables.http_handler_port: + raise MissingEnvironmentVariableException( + "'AIP_HTTP_HANDLER_PORT' must be set." + ) + + port = int(environment_variables.http_handler_port) + + _run_app_thread(server, port) diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/base_plugin.py b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/base_plugin.py new file mode 100644 index 0000000000..67b6b40ae9 --- /dev/null +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/base_plugin.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +from typing import Callable, Dict +from werkzeug import Response + + +class BasePlugin(abc.ABC): + """Base plugin for cloud training tools endpoints. + + The plugins support registering http handlers to be used for + AI Platform training jobs. + """ + + @staticmethod + @abc.abstractmethod + def setup() -> None: + """Run any setup code for the plugin before webserver is launched.""" + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def can_initialize() -> bool: + """Check whether a plugin is able to be initialized. + + Used for checking if correct dependencies are installed, system requirements, etc. + + Returns: + Bool indicating whether the plugin can be initialized. + """ + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def post_setup_check() -> bool: + """Check if after initialization, we need to use the plugin. + + Example: Web server only needs to run for main node for training, others + just need to have 'setup()' run to start the rpc server. + + Returns: + A boolean indicating whether post setup checks pass. + """ + raise NotImplementedError + + @abc.abstractmethod + def get_routes(self) -> Dict[str, Callable[..., Response]]: + """Get the mapping from path to handler. + + This is the method in which plugins can assign different routes to + different handlers. + + Returns: + A mapping from a route to a handler. + """ + raise NotImplementedError diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py new file mode 100644 index 0000000000..4da8381b4c --- /dev/null +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Helpers for creating a profile request sender for tf profiler plugin.""" + +import os +import re +from typing import Tuple + +from tensorboard.uploader import upload_tracker +from tensorboard.uploader import util +from tensorboard.uploader.proto import server_info_pb2 +from tensorboard.util import tb_logging + +from google.api_core import exceptions +from google.cloud import aiplatform +from google.cloud import storage +from google.cloud.aiplatform.utils import TensorboardClientWithOverride +from google.cloud.aiplatform.tensorboard import uploader_utils +from google.cloud.aiplatform.compat.types import ( + tensorboard_experiment_v1beta1 as tensorboard_experiment, +) +from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader +from google.cloud.aiplatform import training_utils + +logger = tb_logging.get_logger() + + +def _get_api_client() -> TensorboardClientWithOverride: + """Creates an Tensorboard API client.""" + aiplatform.constants.API_BASE_PATH = ( + training_utils.environment_variables.tensorboard_api_uri + ) + m = re.match( + "projects/.*/locations/(.*)/tensorboards/.*", + training_utils.environment_variables.tensorboard_resource_name, + ) + region = m[1] + + api_client = aiplatform.initializer.global_config.create_client( + client_class=TensorboardClientWithOverride, location_override=region, + ) + + return api_client + + +def _get_project_id() -> str: + """Gets the project id from the tensorboard resource name. + + Returns: + Project ID for current project. + + Raises: + ValueError: Cannot parse the tensorboard resource name. + """ + m = re.match( + "projects/(.*)/locations/.*/tensorboards/.*", + training_utils.environment_variables.tensorboard_resource_name, + ) + if not m: + raise ValueError( + "Incorrect format for tensorboard resource name: %s", + training_utils.environment_variables.tensorboard_resource_name, + ) + return m[1] + + +def _make_upload_limits() -> server_info_pb2.UploadLimits: + """Creates the upload limits for tensorboard. + + Returns: + An UploadLimits object. + """ + upload_limits = server_info_pb2.UploadLimits() + upload_limits.min_blob_request_interval = 10 + upload_limits.max_blob_request_size = 4 * (2 ** 20) - 256 * (2 ** 10) + upload_limits.max_blob_size = 10 * (2 ** 30) # 10GiB + + return upload_limits + + +def _get_blob_items( + api_client: TensorboardClientWithOverride, +) -> Tuple[storage.bucket.Bucket, str]: + """Gets the blob storage items for the tensorboard resource. + + Args: + api_client (): + Required. Client go get information about the tensorboard instance. + + Returns: + A tuple of storage buckets and the blob storage folder name. + """ + project_id = _get_project_id() + tensorboard = api_client.get_tensorboard( + name=training_utils.environment_variables.tensorboard_resource_name + ) + + path_prefix = tensorboard.blob_storage_path_prefix + "/" + first_slash_index = path_prefix.find("/") + bucket_name = path_prefix[:first_slash_index] + blob_storage_bucket = storage.Client(project=project_id).bucket(bucket_name) + blob_storage_folder = path_prefix[first_slash_index + 1 :] + + return blob_storage_bucket, blob_storage_folder + + +def _get_or_create_experiment( + api: TensorboardClientWithOverride, experiment_name: str +) -> str: + """Creates a tensorboard experiment. + + Args: + api (TensorboardClientWithOverride): + Required. An api for interfacing with tensorboard resources. + experiment_name (str): + Required. The name of the experiment to get or create. + + Returns: + The name of the experiment. + """ + tb_experiment = tensorboard_experiment.TensorboardExperiment() + + try: + experiment = api.create_tensorboard_experiment( + parent=training_utils.environment_variables.tensorboard_resource_name, + tensorboard_experiment=tb_experiment, + tensorboard_experiment_id=experiment_name, + ) + except exceptions.AlreadyExists: + logger.info("Creating experiment failed. Retrieving experiment.") + experiment_name = os.path.join( + training_utils.environment_variables.tensorboard_resource_name, + "experiments", + experiment_name, + ) + experiment = api.get_tensorboard_experiment(name=experiment_name) + + return experiment.name + + +def create_profile_request_sender() -> profile_uploader.ProfileRequestSender: + """Creates the `ProfileRequestSender` for the profile plugin. + + A profile request sender is created for the plugin so that after profiling runs + have finished, data can be uploaded to the tensorboard backend. + + Returns: + A ProfileRequestSender object. + """ + api_client = _get_api_client() + + experiment_name = _get_or_create_experiment( + api_client, training_utils.environment_variables.cloud_ml_job_id + ) + + upload_limits = _make_upload_limits() + + blob_rpc_rate_limiter = util.RateLimiter( + upload_limits.min_blob_request_interval / 100 + ) + + blob_storage_bucket, blob_storage_folder = _get_blob_items(api_client,) + + source_bucket = uploader_utils.get_source_bucket( + training_utils.environment_variables.tensorboard_log_dir + ) + + profile_request_sender = profile_uploader.ProfileRequestSender( + experiment_name, + api_client, + upload_limits=upload_limits, + blob_rpc_rate_limiter=blob_rpc_rate_limiter, + blob_storage_bucket=blob_storage_bucket, + blob_storage_folder=blob_storage_folder, + source_bucket=source_bucket, + tracker=upload_tracker.UploadTracker(verbosity=1), + logdir=training_utils.environment_variables.tensorboard_log_dir, + ) + + return profile_request_sender diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py new file mode 100644 index 0000000000..514ae19368 --- /dev/null +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py @@ -0,0 +1,358 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A plugin to handle remote tensoflow profiler sessions for Vertex AI.""" + +import argparse +from collections import namedtuple +import importlib.util +import json +import logging +import tensorboard.plugins.base_plugin as tensorboard_base_plugin +from typing import Callable, Dict, Optional +from urllib import parse +from werkzeug import Response + +from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader +from google.cloud.aiplatform.training_utils import environment_variables +from google.cloud.aiplatform.training_utils.cloud_profiler import wsgi_types +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import ( + tensorboard_api, +) + + +# TF verison information. +Version = namedtuple("Version", ["major", "minor", "patch"]) + +logger = logging.Logger("tf-profiler") + +_BASE_TB_ENV_WARNING = ( + "To set this environment variable, run your training with the 'tensorboard' " + "option. For more information on how to run with training with tensorboard, visit " + "https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training" +) + + +def _get_tf_versioning() -> Optional[Version]: + """Convert version string to a Version namedtuple for ease of parsing. + + Returns: + A version object if finding the version was successful, None otherwise. + """ + import tensorflow as tf + + version = tf.__version__ + + versioning = version.split(".") + if len(versioning) != 3: + return + + return Version(int(versioning[0]), int(versioning[1]), int(versioning[2])) + + +def _is_compatible_version(version: Version) -> bool: + """Check if version is compatible with tf profiling. + + Profiling plugin is available to be used for version >= 2.4.0. + While the profiler is available in 2.2.0 >=, some additional dependencies + that are included in 2.4.0 >= are also needed for the tensorboard-plugin-profile. + + Profiler: + https://www.tensorflow.org/guide/profiler + Required commit for tensorboard-plugin-profile: + https://github.com/tensorflow/tensorflow/commit/8b9c207242db515daef033e74d69ea5d8e023dc6 + + Args: + version (Version): + Required. `Verison` of tensorflow. + + Returns: + Bool indicating wheter version is compatible with profiler. + """ + return version.major >= 2 and version.minor >= 4 + + +def _check_tf() -> bool: + """Check whether all the tensorflow prereqs are met. + + Returns: + True if all requirements met, False otherwise. + """ + # Check tf is installed + if importlib.util.find_spec("tensorflow") is None: + logger.warning("Tensorflow not installed, cannot initialize profiling plugin") + return False + + # Check tensorflow version + version = _get_tf_versioning() + if version is None: + logger.warning( + "Could not find major, minor, and patch versions of tensorflow. Version found: %s", + version, + ) + return False + + # Check compatibility, introduced in tensorflow >= 2.2.0 + if not _is_compatible_version(version): + logger.warning( + "Version %s is incompatible with tf profiler." + "To use the profiler, choose a version >= 2.2.0", + "%s.%s.%s" % (version.major, version.minor, version.patch), + ) + return False + + # Check for the tf profiler plugin + if importlib.util.find_spec("tensorboard_plugin_profile") is None: + logger.warning( + "Could not import tensorboard_plugin_profile, will not run tf profiling service" + ) + return False + + return True + + +def _create_profiling_context() -> tensorboard_base_plugin.TBContext: + """Creates the base context needed for TB Profiler. + + Returns: + An initialized `TBContext`. + """ + + context_flags = argparse.Namespace(master_tpu_unsecure_channel=None) + + context = tensorboard_base_plugin.TBContext( + logdir=environment_variables.tensorboard_log_dir, + multiplexer=None, + flags=context_flags, + ) + + return context + + +def _host_to_grpc(hostname: str) -> str: + """Format a hostname to a grpc address. + + Args: + hostname (str): + Required. Address in form: `{hostname}:{port}` + + Returns: + Address in form of: 'grpc://{hostname}:{port}' + """ + return ( + "grpc://" + + "".join(hostname.split(":")[:-1]) + + ":" + + environment_variables.tf_profiler_port + ) + + +def _get_hostnames() -> Optional[str]: + """Get the hostnames for all servers running. + + Returns: + A host formatted by `_host_to_grpc` if obtaining the cluster spec + is successful, None otherwise. + """ + cluster_spec = environment_variables.cluster_spec + if cluster_spec is None: + return + + cluster = cluster_spec.get("cluster", "") + if not cluster: + return + + hostnames = [] + for value in cluster.values(): + hostnames.extend(value) + + return ",".join([_host_to_grpc(x) for x in hostnames]) + + +def _update_environ(environ: wsgi_types.Environment) -> bool: + """Add parameters to the query that are retrieved from training side. + + Args: + environ (wsgi_types.Environment): + Required. The WSGI Environment. + + Returns: + Whether the environment was successfully updated. + """ + hosts = _get_hostnames() + + if hosts is None: + return False + + query_dict = {} + query_dict["service_addr"] = hosts + + # Update service address and worker list + # Use parse_qsl and then convert list to dictionary so we can update + # attributes + prev_query_string = dict(parse.parse_qsl(environ["QUERY_STRING"])) + prev_query_string.update(query_dict) + + environ["QUERY_STRING"] = parse.urlencode(prev_query_string) + + return True + + +def warn_tensorboard_env_var(var_name: str): + """Warns if a tensorboard related environment variable is missing. + + Args: + var_name (str): + Required. The name of the missing environment variable. + """ + logging.warning( + f"Environment variable `{var_name}` must be set. " + _BASE_TB_ENV_WARNING + ) + + +def _check_env_vars() -> bool: + """Determine whether the correct environment variables are set. + + Returns: + bool indicating all necessary variables are set. + """ + # The below are tensorboard specific environment variables. + if environment_variables.tf_profiler_port is None: + warn_tensorboard_env_var("AIP_TF_PROFILER_PORT") + return False + + if environment_variables.tensorboard_log_dir is None: + warn_tensorboard_env_var("AIP_TENSORBOARD_LOG_DIR") + return False + + if environment_variables.tensorboard_api_uri is None: + warn_tensorboard_env_var("AIP_TENSORBOARD_API_URI") + return False + + if environment_variables.tensorboard_resource_name is None: + warn_tensorboard_env_var("AIP_TENSORBOARD_RESOURCE_NAME") + return False + + # These environment variables are not tensorboard related, they are + # variables set for any Vertex training run. + cluster_spec = environment_variables.cluster_spec + if cluster_spec is None: + logger.warning("Environment variable `CLUSTER_SPEC` is not set.") + return False + + if environment_variables.cloud_ml_job_id is None: + logger.warning("Environment variable `CLOUD_ML_JOB_ID` is not set") + return False + + return True + + +class TFProfiler(base_plugin.BasePlugin): + """Handler for Tensorflow Profiling.""" + + PLUGIN_NAME = "profile" + + def __init__(self): + """Build a TFProfiler object.""" + from tensorboard_plugin_profile.profile_plugin import ProfilePlugin + + context = _create_profiling_context() + self._profile_request_sender: profile_uploader.ProfileRequestSender = tensorboard_api.create_profile_request_sender() + self._profile_plugin: ProfilePlugin = ProfilePlugin(context) + + def get_routes( + self, + ) -> Dict[str, Callable[[Dict[str, str], Callable[..., None]], Response]]: + """List of routes to serve. + + Returns: + A callable that takes an werkzeug env and start response and returns a response. + """ + return {"/capture_profile": self.capture_profile_wrapper} + + # Define routes below + def capture_profile_wrapper( + self, environ: wsgi_types.Environment, start_response: wsgi_types.StartResponse + ) -> Response: + """Take a request from tensorboard.gcp and run the profiling for the available servers. + + Args: + environ (wsgi_types.Environment): + Required. The WSGI environment. + start_response (wsgi_types.StartResponse): + Required. The response callable provided by the WSGI server. + + Returns: + A response iterable. + """ + # The service address (localhost) and worker list are populated locally + if not _update_environ(environ): + err = {"error": "Could not parse the environ: %s"} + return Response( + json.dumps(err), content_type="application/json", status=500 + ) + + response = self._profile_plugin.capture_route(environ, start_response) + + self._profile_request_sender.send_request("") + + return response + + # End routes + + @staticmethod + def setup() -> None: + """Sets up the plugin. + + Raises: + ImportError: Tensorflow could not be imported. + """ + try: + import tensorflow as tf + except ImportError as err: + raise ImportError( + "Could not import tensorflow for profile usage. " + "To use profiler, install the SDK using " + '"pip install google-cloud-aiplatform[cloud_profiler]"' + ) from err + + tf.profiler.experimental.server.start( + int(environment_variables.tf_profiler_port) + ) + + @staticmethod + def post_setup_check() -> bool: + """Only chief and task 0 should run the webserver.""" + cluster_spec = environment_variables.cluster_spec + task_type = cluster_spec.get("task", {}).get("type", "") + task_index = cluster_spec.get("task", {}).get("index", -1) + + return task_type in {"workerpool0", "chief"} and task_index == 0 + + @staticmethod + def can_initialize() -> bool: + """Check that we can use the TF Profiler plugin. + + This function checks a number of dependencies for the plugin to ensure we have the + right packages installed, the necessary versions, and the correct environment variables set. + + Returns: + True if can initialize, False otherwise. + """ + + return _check_env_vars() and _check_tf() diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/webserver.py b/google/cloud/aiplatform/training_utils/cloud_profiler/webserver.py new file mode 100644 index 0000000000..3f7706bb34 --- /dev/null +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/webserver.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""A basic webserver for hosting plugin routes.""" + +import os + +from google.cloud.aiplatform.training_utils.cloud_profiler import wsgi_types +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin +from typing import List +from werkzeug import wrappers, Response + + +class WebServer: + """A basic web server for handling requests.""" + + def __init__(self, plugins: List[base_plugin.BasePlugin]): + """Creates a web server to host plugin routes. + + Args: + plugins (List[base_plugin.BasePlugin]): + Required. A list of `BasePlugin` objects. + + Raises: + ValueError: + When there is an invalid route passed from + one of the plugins. + """ + + self._plugins = plugins + self._routes = {} + + # Routes are in form {plugin_name}/{route} + for plugin in self._plugins: + for route, handler in plugin.get_routes().items(): + if not route.startswith("/"): + raise ValueError( + 'Routes should start with a "/", ' + "invalid route for plugin %s, route %s" + % (plugin.PLUGIN_NAME, route) + ) + + app_route = os.path.join("/", plugin.PLUGIN_NAME) + + app_route += route + self._routes[app_route] = handler + + def dispatch_request( + self, environ: wsgi_types.Environment, start_response: wsgi_types.StartResponse + ) -> Response: + """Handles the routing of requests. + + Args: + environ (wsgi_types.Environment): + Required. The WSGI environment. + start_response (wsgi_types.StartResponse): + Required. The response callable provided by the WSGI server. + + Returns: + A response iterable. + """ + # Check for existince of route + request = wrappers.Request(environ) + + if request.path in self._routes: + return self._routes[request.path](environ, start_response) + + response = wrappers.Response("Not Found", status=404) + return response(environ, start_response) + + def wsgi_app( + self, environ: wsgi_types.Environment, start_response: wsgi_types.StartResponse + ) -> Response: + """Entrypoint for wsgi application. + + Args: + environ (wsgi_types.Environment): + Required. The WSGI environment. + start_response (wsgi_types.StartResponse): + Required. The response callable provided by the WSGI server. + + Returns: + A response iterable. + """ + response = self.dispatch_request(environ, start_response) + return response + + def __call__(self, environ, start_response): + """Entrypoint for wsgi application. + + Args: + environ (wsgi_types.Environment): + Required. The WSGI environment. + start_response (wsgi_types.StartResponse): + Required. The response callable provided by the WSGI server. + + Returns: + A response iterable. + """ + return self.wsgi_app(environ, start_response) diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/wsgi_types.py b/google/cloud/aiplatform/training_utils/cloud_profiler/wsgi_types.py new file mode 100644 index 0000000000..0348c5b91e --- /dev/null +++ b/google/cloud/aiplatform/training_utils/cloud_profiler/wsgi_types.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Typing description for the WSGI App callables +# For more information on WSGI, see PEP 3333 + +from typing import Any, Dict, Text, Callable + +# Contain CGI environment variables, as defined by the Common Gateway Interface +# specification. +Environment = Dict[Text, Any] + +# Used to begin the HTTP response. +StartResponse = Callable[..., Callable[[bytes], None]] diff --git a/google/cloud/aiplatform/training_utils/environment_variables.py b/google/cloud/aiplatform/training_utils/environment_variables.py index 2771c0746c..0783e78251 100644 --- a/google/cloud/aiplatform/training_utils/environment_variables.py +++ b/google/cloud/aiplatform/training_utils/environment_variables.py @@ -15,6 +15,8 @@ # limitations under the License. # +# Environment variables used in Vertex AI Training. + import json import os @@ -74,3 +76,6 @@ def _json_helper(env_var: str) -> Optional[Dict]: # The name given to the training job. cloud_ml_job_id = os.environ.get("CLOUD_ML_JOB_ID") + +# The HTTP Handler port to use to host the profiling webserver. +http_handler_port = os.environ.get("AIP_HTTP_HANDLER_PORT") diff --git a/setup.py b/setup.py index caa50df32c..d0daa259f3 100644 --- a/setup.py +++ b/setup.py @@ -36,10 +36,14 @@ tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"] metadata_extra_require = ["pandas >= 1.0.0"] xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"] +profiler_extra_require = ["tensorboard-plugin-profile", "tensorflow >=2.4.0"] + full_extra_require = list( set(tensorboard_extra_require + metadata_extra_require + xai_extra_require) ) -testing_extra_require = full_extra_require + ["grpcio-testing", "pytest-xdist"] +testing_extra_require = ( + full_extra_require + profiler_extra_require + ["grpcio-testing", "pytest-xdist"] +) setuptools.setup( @@ -80,6 +84,7 @@ "tensorboard": tensorboard_extra_require, "testing": testing_extra_require, "xai": xai_extra_require, + "cloud_profiler": profiler_extra_require, }, python_requires=">=3.6", scripts=[], diff --git a/tests/unit/aiplatform/test_cloud_profiler.py b/tests/unit/aiplatform/test_cloud_profiler.py new file mode 100644 index 0000000000..e540279bf9 --- /dev/null +++ b/tests/unit/aiplatform/test_cloud_profiler.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +import importlib.util +import json +import threading +from typing import List, Optional + +import pytest +import unittest + +from unittest import mock +from werkzeug import wrappers +from werkzeug.test import EnvironBuilder + +from google.api_core import exceptions +from google.cloud import aiplatform +from google.cloud.aiplatform import training_utils +from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import ( + tf_profiler, +) +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow.tf_profiler import ( + TFProfiler, +) +from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import ( + tensorboard_api, +) +from google.cloud.aiplatform.training_utils.cloud_profiler import webserver +from google.cloud.aiplatform.training_utils.cloud_profiler import initializer + + +# Mock cluster specs from the training environment. +_CLUSTER_SPEC_VM = { + "cluster": {"chief": ["localhost:1234"]}, + "environment": "cloud", + "task": {"type": "chief", "index": 0}, +} + + +def _create_mock_plugin( + plugin_name: str = "test_plugin", routes: Optional[List] = ["/route1"] +): + mock_plugin = mock.Mock(spec=base_plugin.BasePlugin) + mock_plugin.can_initialize.return_value = True + mock_plugin.post_setup_check.return_value = True + mock_plugin.PLUGIN_NAME = plugin_name + + # Some mock routes to test number of times each has been called. + mock_routes = {} + for route in routes: + mock_routes[route] = mock.Mock() + + mock_plugin.get_routes.return_value = mock_routes + + # A call should just return the mock object itself. + mock_plugin.return_value = mock_plugin + + return mock_plugin + + +@pytest.fixture +def tf_profile_plugin_mock(): + """Mock the tensorboard profile plugin""" + import tensorboard_plugin_profile.profile_plugin + + with mock.patch.object( + tensorboard_plugin_profile.profile_plugin.ProfilePlugin, "capture_route" + ) as profile_mock: + profile_mock.return_value = ( + wrappers.BaseResponse( + json.dumps({"error": "some error"}), + content_type="application/json", + status=200, + ), + ) + yield profile_mock + + +@pytest.fixture +def tensorboard_api_mock(): + with mock.patch.object( + tensorboard_api, "create_profile_request_sender", + ) as sender_mock: + sender_mock.return_value = mock.Mock() + yield sender_mock + + +@pytest.fixture +def mock_api_environment_variables(): + with mock.patch.object(training_utils, "environment_variables") as mock_env: + mock_env.tensorboard_api_uri = "testuri" + mock_env.tensorboard_resource_name = ( + "projects/testproj/locations/us-central1/tensorboards/123" + ) + mock_env.cloud_ml_job_id = "test_job_id" + mock_env.tensorboard_log_dir = "gs://my_log_dir" + + yield mock_env + + +def setupProfilerEnvVars(): + tf_profiler.environment_variables.tf_profiler_port = "6009" + tf_profiler.environment_variables.tensorboard_log_dir = "tmp/" + tf_profiler.environment_variables.tensorboard_api_uri = "test_api_uri" + tf_profiler.environment_variables.tensorboard_resource_name = ( + "projects/123/region/us-central1/tensorboards/mytb" + ) + tf_profiler.environment_variables.cluster_spec = _CLUSTER_SPEC_VM + tf_profiler.environment_variables.cloud_ml_job_id = "myjob" + + +class TestProfilerPlugin(unittest.TestCase): + def setUp(self): + setupProfilerEnvVars() + + # Environment variable tests + def testCanInitializeProfilerPortUnset(self): + tf_profiler.environment_variables.tf_profiler_port = None + assert not TFProfiler.can_initialize() + + def testCanInitializeTBLogDirUnset(self): + tf_profiler.environment_variables.tensorboard_log_dir = None + assert not TFProfiler.can_initialize() + + def testCanInitializeTBAPIuriUnset(self): + tf_profiler.environment_variables.tensorboard_api_uri = None + assert not TFProfiler.can_initialize() + + def testCanInitializeTBResourceNameUnset(self): + tf_profiler.environment_variables.tensorboard_resource_name = None + assert not TFProfiler.can_initialize() + + def testCanInitializeJobIdUnset(self): + tf_profiler.environment_variables.cloud_ml_job_id = None + assert not TFProfiler.can_initialize() + + def testCanInitializeNoClusterSpec(self): + tf_profiler.environment_variables.cluster_spec = None + assert not TFProfiler.can_initialize() + + # Check tensorflow dependencies + def testCanInitializeTFInstalled(self): + orig_find_spec = importlib.util.find_spec + + def tf_import_mock(name, *args, **kwargs): + if name == "tensorflow": + return None + return orig_find_spec(name, *args, **kwargs) + + with mock.patch("importlib.util.find_spec", side_effect=tf_import_mock): + assert not TFProfiler.can_initialize() + + def testCanInitializeTFVersion(self): + import tensorflow + + with mock.patch.dict(tensorflow.__dict__, {"__version__": "1.2.3.4"}): + assert not TFProfiler.can_initialize() + + def testCanInitializeOldTFVersion(self): + import tensorflow + + with mock.patch.dict(tensorflow.__dict__, {"__version__": "2.3.0"}): + assert not TFProfiler.can_initialize() + + def testCanInitializeNoProfilePlugin(self): + orig_find_spec = importlib.util.find_spec + + def plugin_import_mock(name, *args, **kwargs): + if name == "tensorboard_plugin_profile": + return None + return orig_find_spec(name, *args, **kwargs) + + with mock.patch("importlib.util.find_spec", side_effect=plugin_import_mock): + assert not TFProfiler.can_initialize() + + def testCanInitialize(self): + assert TFProfiler.can_initialize() + + def testSetup(self): + import tensorflow + + with mock.patch.object( + tensorflow.profiler.experimental.server, "start", return_value=None + ) as server_mock: + TFProfiler.setup() + + assert server_mock.call_count == 1 + + def testSetupRaiseImportError(self): + with mock.patch.dict("sys.modules", {"tensorflow": None}): + self.assertRaises(ImportError, TFProfiler.setup) + + def testPostSetupChecksFail(self): + tf_profiler.environment_variables.cluster_spec = {} + assert not TFProfiler.post_setup_check() + + def testPostSetupChecks(self): + assert TFProfiler.post_setup_check() + + # Tests for plugin + @pytest.mark.usefixtures("tf_profile_plugin_mock") + @pytest.mark.usefixtures("tensorboard_api_mock") + def testCaptureProfile(self): + profiler = TFProfiler() + environ = dict(QUERY_STRING="?service_addr=myhost1,myhost2&someotherdata=5") + start_response = None + + resp = profiler.capture_profile_wrapper(environ, start_response) + assert resp[0].status_code == 200 + + @pytest.mark.usefixtures("tf_profile_plugin_mock") + @pytest.mark.usefixtures("tensorboard_api_mock") + def testCaptureProfileNoClusterSpec(self): + profiler = TFProfiler() + + environ = dict(QUERY_STRING="?service_addr=myhost1,myhost2&someotherdata=5") + start_response = None + + tf_profiler.environment_variables.cluster_spec = None + resp = profiler.capture_profile_wrapper(environ, start_response) + + assert resp.status_code == 500 + + @pytest.mark.usefixtures("tf_profile_plugin_mock") + @pytest.mark.usefixtures("tensorboard_api_mock") + def testCaptureProfileNoCluster(self): + profiler = TFProfiler() + + environ = dict(QUERY_STRING="?service_addr=myhost1,myhost2&someotherdata=5") + start_response = None + tf_profiler.environment_variables.cluster_spec = {"cluster": {}} + + resp = profiler.capture_profile_wrapper(environ, start_response) + + assert resp.status_code == 500 + + @pytest.mark.usefixtures("tf_profile_plugin_mock") + @pytest.mark.usefixtures("tensorboard_api_mock") + def testGetRoutes(self): + profiler = TFProfiler() + + routes = profiler.get_routes() + assert isinstance(routes, dict) + + +# Tensorboard API tests +class TestTensorboardAPIBuilder(unittest.TestCase): + @pytest.mark.usefixtures("mock_api_environment_variables") + def test_get_api_client(self): + with mock.patch.object(aiplatform, "initializer") as mock_initializer: + tensorboard_api._get_api_client() + mock_initializer.global_config.create_client.assert_called_once() + + def test_get_project_id_fail(self): + with mock.patch.object(training_utils, "environment_variables") as mock_env: + mock_env.tensorboard_resource_name = "bad_resource" + self.assertRaises(ValueError, tensorboard_api._get_project_id) + + @pytest.mark.usefixtures("mock_api_environment_variables") + def test_get_project_id(self): + project_id = tensorboard_api._get_project_id() + assert project_id == "testproj" + + @pytest.mark.usefixtures("mock_api_environment_variables") + def test_get_or_create_experiment(self): + api = mock.Mock() + api.create_tensorboard_experiment.side_effect = exceptions.AlreadyExists("test") + tensorboard_api._get_or_create_experiment(api, "test") + api.get_tensorboard_experiment.assert_called_once() + + @pytest.mark.usefixtures("mock_api_environment_variables") + def test_create_profile_request_sender(self): + tensorboard_api.storage = mock.Mock() + tensorboard_api.uploader_utils = mock.Mock() + + with mock.patch.object(profile_uploader, "ProfileRequestSender") as mock_sender: + with mock.patch.object(aiplatform, "initializer"): + tensorboard_api.create_profile_request_sender() + mock_sender.assert_called_once() + + +# Webserver tests +class TestWebServer(unittest.TestCase): + def test_create_webserver_bad_route(self): + plugin = _create_mock_plugin() + plugin.get_routes.return_value = {"my_route": "some_handler"} + + self.assertRaises(ValueError, webserver.WebServer, [plugin]) + + def test_dispatch_bad_request(self): + plugin = _create_mock_plugin() + plugin.get_routes.return_value = {"/test_route": "test_handler"} + + ws = webserver.WebServer([plugin]) + + builder = EnvironBuilder(method="GET", path="/") + + env = builder.get_environ() + + # Mock a start response callable + response = [] + buff = [] + + def start_response(status, headers): + response[:] = [status, headers] + return buff.append + + ws(env, start_response) + + assert response[0] == "404 NOT FOUND" + + def test_correct_response(self): + res_dict = {"response": "OK"} + + def my_callable(var1, var2): + return wrappers.BaseResponse( + json.dumps(res_dict), content_type="application/json", status=200 + ) + + plugin = _create_mock_plugin() + plugin.get_routes.return_value = {"/my_route": my_callable} + ws = webserver.WebServer([plugin]) + + builder = EnvironBuilder(method="GET", path="/test_plugin/my_route") + + env = builder.get_environ() + + # Mock a start response callable + response = [] + buff = [] + + def start_response(status, headers): + response[:] = [status, headers] + return buff.append + + res = ws(env, start_response) + + final_response = json.loads(res.response[0].decode("utf-8")) + + assert final_response == res_dict + + +# Initializer tests +class TestInitializer(unittest.TestCase): + # Tests for building the plugin + def test_init_failed_import(self): + with mock.patch.dict( + "sys.modules", + {"google.cloud.aiplatform.training_utils.cloud_profiler.initializer": None}, + ): + self.assertRaises(ImportError, reload, training_utils.cloud_profiler) + + def test_build_plugin_fail_initialize(self): + plugin = _create_mock_plugin() + plugin.can_initialize.return_value = False + + assert not initializer._build_plugin(plugin) + + def test_build_plugin_fail_setup_check(self): + plugin = _create_mock_plugin() + plugin.can_initialize.return_value = True + plugin.post_setup_check.return_value = False + + assert not initializer._build_plugin(plugin) + + def test_build_plugin_success(self): + plugin = _create_mock_plugin() + plugin.can_initialize.return_value = True + plugin.post_setup_check.return_value = True + + initializer._build_plugin(plugin) + + assert plugin.called + + # Testing the initialize function + def test_initialize_bad_plugin(self): + with mock.patch.object(initializer, "_AVAILABLE_PLUGINS", {}): + self.assertRaises(ValueError, initializer.initialize, "bad_plugin") + + def test_initialize_build_plugin_fail(self): + plugin = _create_mock_plugin() + with mock.patch.object(initializer, "_AVAILABLE_PLUGINS", {"test": plugin}): + with mock.patch.object(initializer, "_build_plugin") as build_mock: + with mock.patch.object( + initializer, "_run_app_thread" + ) as app_thread_mock: + build_mock.return_value = None + initializer.initialize("test") + + assert not app_thread_mock.call_count + + def test_initialize_no_http_handler(self): + plugin = _create_mock_plugin() + initializer.environment_variables.http_handler_port = None + + with mock.patch.object(initializer, "_AVAILABLE_PLUGINS", {"test": plugin}): + with pytest.raises(initializer.MissingEnvironmentVariableException): + initializer.initialize("test") + + def test_initialize_build_plugin_success(self): + plugin = _create_mock_plugin() + initializer.environment_variables.http_handler_port = "1234" + + with mock.patch.object(initializer, "_AVAILABLE_PLUGINS", {"test": plugin}): + with mock.patch.object(initializer, "_run_app_thread") as app_thread_mock: + initializer.initialize("test") + + assert app_thread_mock.call_count == 1 + + def test_run_app_thread(self): + with mock.patch.object(threading, "Thread") as mock_thread: + daemon_mock = mock.Mock() + mock_thread.return_value = daemon_mock + + initializer._run_app_thread(None, 1234) + + assert daemon_mock.start.call_count == 1 diff --git a/tests/unit/aiplatform/test_training_utils.py b/tests/unit/aiplatform/test_training_utils.py index 626d5c1d2a..99d49b7ead 100644 --- a/tests/unit/aiplatform/test_training_utils.py +++ b/tests/unit/aiplatform/test_training_utils.py @@ -60,6 +60,7 @@ "projects/myproj/locations/us-central1/tensorboards/1234" ) _TEST_CLOUD_ML_JOB_ID = "myjob" +_TEST_AIP_HTTP_HANDLER_PORT = "5678" class TestTrainingUtils: @@ -73,6 +74,7 @@ def mock_environment(self): "AIP_CHECKPOINT_DIR": _TEST_CHECKPOINT_DIR, "AIP_TENSORBOARD_LOG_DIR": _TEST_TENSORBOARD_LOG_DIR, "AIP_TF_PROFILER_PORT": _TEST_AIP_TF_PROFILER_PORT, + "AIP_HTTP_HANDLER_PORT": _TEST_AIP_HTTP_HANDLER_PORT, "AIP_TENSORBOARD_API_URI": _TEST_TENSORBOARD_API_URI, "AIP_TENSORBOARD_RESOURCE_NAME": _TEST_TENSORBOARD_RESOURCE_NAME, "CLOUD_ML_JOB_ID": _TEST_CLOUD_ML_JOB_ID, @@ -192,3 +194,12 @@ def test_cloud_ml_job_id(self): def test_cloud_ml_job_id_none(self): reload(environment_variables) assert environment_variables.cloud_ml_job_id is None + + @pytest.mark.usefixtures("mock_environment") + def test_http_handler_port(self): + reload(environment_variables) + assert environment_variables.http_handler_port == _TEST_AIP_HTTP_HANDLER_PORT + + def test_http_handler_port_none(self): + reload(environment_variables) + assert environment_variables.http_handler_port is None