-
Notifications
You must be signed in to change notification settings - Fork 354
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feat--Support-uploading-local-models
- Loading branch information
Showing
16 changed files
with
1,414 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
20 changes: 20 additions & 0 deletions
20
google/cloud/aiplatform/training_utils/cloud_profiler/README.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
Cloud Profiler | ||
================================= | ||
|
||
Cloud Profiler allows you to profile your remote Vertex AI Training jobs on demand and visualize the results in Vertex Tensorboard. | ||
|
||
Quick Start | ||
------------ | ||
|
||
To start using the profiler with TensorFlow, update your training script to include the following: | ||
|
||
.. code-block:: Python | ||
from google.cloud.aiplatform.training_utils import cloud_profiler | ||
... | ||
cloud_profiler.init() | ||
Next, run the job with with a Vertex TensorBoard instance. For full details on how to do this, visit https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview | ||
|
||
Finally, visit your TensorBoard in your Google Cloud Console, navigate to the "Profile" tab, and click the `Capture Profile` button. This will allow users to capture profiling statistics for the running jobs. |
35 changes: 35 additions & 0 deletions
35
google/cloud/aiplatform/training_utils/cloud_profiler/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright 2021 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
try: | ||
import google.cloud.aiplatform.training_utils.cloud_profiler.initializer as initializer | ||
except ImportError as err: | ||
raise ImportError( | ||
"Could not load the cloud profiler. To use the profiler, " | ||
'install the SDK using "pip install google-cloud-aiplatform[cloud-profiler]"' | ||
) from err | ||
|
||
""" | ||
Initialize the cloud profiler for tensorflow. | ||
Usage: | ||
from google.cloud.aiplatform.training_utils import cloud_profiler | ||
cloud_profiler.init(profiler='tensorflow') | ||
""" | ||
|
||
init = initializer.initialize |
118 changes: 118 additions & 0 deletions
118
google/cloud/aiplatform/training_utils/cloud_profiler/initializer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright 2021 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import logging | ||
import threading | ||
from typing import Optional, Type | ||
from werkzeug import serving | ||
|
||
from google.cloud.aiplatform.training_utils import environment_variables | ||
from google.cloud.aiplatform.training_utils.cloud_profiler import webserver | ||
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin | ||
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import ( | ||
tf_profiler, | ||
) | ||
|
||
# Mapping of available plugins to use | ||
_AVAILABLE_PLUGINS = {"tensorflow": tf_profiler.TFProfiler} | ||
|
||
|
||
class MissingEnvironmentVariableException(Exception): | ||
pass | ||
|
||
|
||
def _build_plugin( | ||
plugin: Type[base_plugin.BasePlugin], | ||
) -> Optional[base_plugin.BasePlugin]: | ||
"""Builds the plugin given the object. | ||
Args: | ||
plugin (Type[base_plugin]): | ||
Required. An uninitialized plugin class. | ||
Returns: | ||
An initialized plugin, or None if plugin cannot be | ||
initialized. | ||
""" | ||
if not plugin.can_initialize(): | ||
logging.warning("Cannot initialize the plugin") | ||
return | ||
|
||
plugin.setup() | ||
|
||
if not plugin.post_setup_check(): | ||
return | ||
|
||
return plugin() | ||
|
||
|
||
def _run_app_thread(server: webserver.WebServer, port: int): | ||
"""Run the webserver in a separate thread. | ||
Args: | ||
server (webserver.WebServer): | ||
Required. A webserver to accept requests. | ||
port (int): | ||
Required. The port to run the webserver on. | ||
""" | ||
daemon = threading.Thread( | ||
name="profile_server", | ||
target=serving.run_simple, | ||
args=("0.0.0.0", port, server,), | ||
) | ||
daemon.setDaemon(True) | ||
daemon.start() | ||
|
||
|
||
def initialize(plugin: str = "tensorflow"): | ||
"""Initializes the profiling SDK. | ||
Args: | ||
plugin (str): | ||
Required. Name of the plugin to initialize. | ||
Current options are ["tensorflow"] | ||
Raises: | ||
ValueError: | ||
The plugin does not exist. | ||
MissingEnvironmentVariableException: | ||
An environment variable that is needed is not set. | ||
""" | ||
plugin_obj = _AVAILABLE_PLUGINS.get(plugin) | ||
|
||
if not plugin_obj: | ||
raise ValueError( | ||
"Plugin {} not available, must choose from {}".format( | ||
plugin, _AVAILABLE_PLUGINS.keys() | ||
) | ||
) | ||
|
||
prof_plugin = _build_plugin(plugin_obj) | ||
|
||
if prof_plugin is None: | ||
return | ||
|
||
server = webserver.WebServer([prof_plugin]) | ||
|
||
if not environment_variables.http_handler_port: | ||
raise MissingEnvironmentVariableException( | ||
"'AIP_HTTP_HANDLER_PORT' must be set." | ||
) | ||
|
||
port = int(environment_variables.http_handler_port) | ||
|
||
_run_app_thread(server, port) |
71 changes: 71 additions & 0 deletions
71
google/cloud/aiplatform/training_utils/cloud_profiler/plugins/base_plugin.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright 2021 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import abc | ||
from typing import Callable, Dict | ||
from werkzeug import Response | ||
|
||
|
||
class BasePlugin(abc.ABC): | ||
"""Base plugin for cloud training tools endpoints. | ||
The plugins support registering http handlers to be used for | ||
AI Platform training jobs. | ||
""" | ||
|
||
@staticmethod | ||
@abc.abstractmethod | ||
def setup() -> None: | ||
"""Run any setup code for the plugin before webserver is launched.""" | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
@abc.abstractmethod | ||
def can_initialize() -> bool: | ||
"""Check whether a plugin is able to be initialized. | ||
Used for checking if correct dependencies are installed, system requirements, etc. | ||
Returns: | ||
Bool indicating whether the plugin can be initialized. | ||
""" | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
@abc.abstractmethod | ||
def post_setup_check() -> bool: | ||
"""Check if after initialization, we need to use the plugin. | ||
Example: Web server only needs to run for main node for training, others | ||
just need to have 'setup()' run to start the rpc server. | ||
Returns: | ||
A boolean indicating whether post setup checks pass. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def get_routes(self) -> Dict[str, Callable[..., Response]]: | ||
"""Get the mapping from path to handler. | ||
This is the method in which plugins can assign different routes to | ||
different handlers. | ||
Returns: | ||
A mapping from a route to a handler. | ||
""" | ||
raise NotImplementedError |
Oops, something went wrong.