diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 58eb824454..e56e57a2ad 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -23,6 +23,7 @@ ImageDataset, TabularDataset, TextDataset, + TimeSeriesDataset, VideoDataset, ) from google.cloud.aiplatform.models import Endpoint @@ -33,10 +34,12 @@ CustomContainerTrainingJob, CustomPythonPackageTrainingJob, AutoMLTabularTrainingJob, + AutoMLForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTextTrainingJob, AutoMLVideoTrainingJob, ) +from google.cloud.aiplatform.metadata import metadata """ Usage: @@ -46,12 +49,25 @@ """ init = initializer.global_config.init +log_params = metadata.metadata_service.log_params +log_metrics = metadata.metadata_service.log_metrics +get_experiment_df = metadata.metadata_service.get_experiment_df +get_pipeline_df = metadata.metadata_service.get_pipeline_df +start_run = metadata.metadata_service.start_run + + __all__ = ( "explain", "gapic", "init", + "log_params", + "log_metrics", + "get_experiment_df", + "get_pipeline_df", + "start_run", "AutoMLImageTrainingJob", "AutoMLTabularTrainingJob", + "AutoMLForecastingTrainingJob", "AutoMLTextTrainingJob", "AutoMLVideoTrainingJob", "BatchPredictionJob", @@ -63,5 +79,6 @@ "Model", "TabularDataset", "TextDataset", + "TimeSeriesDataset", "VideoDataset", ) diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py index 36d805c6cb..980c554fe1 100644 --- a/google/cloud/aiplatform/compat/__init__.py +++ b/google/cloud/aiplatform/compat/__init__.py @@ -34,6 +34,8 @@ services.specialist_pool_service_client = ( services.specialist_pool_service_client_v1beta1 ) + services.metadata_service_client = services.metadata_service_client_v1beta1 + services.tensorboard_service_client = services.tensorboard_service_client_v1beta1 types.accelerator_type = types.accelerator_type_v1beta1 types.annotation = types.annotation_v1beta1 @@ -69,6 +71,13 @@ types.specialist_pool = types.specialist_pool_v1beta1 types.specialist_pool_service = types.specialist_pool_service_v1beta1 types.training_pipeline = types.training_pipeline_v1beta1 + types.metadata_service = types.metadata_service_v1beta1 + types.tensorboard_service = types.tensorboard_service_v1beta1 + types.tensorboard_data = types.tensorboard_data_v1beta1 + types.tensorboard_experiment = types.tensorboard_experiment_v1beta1 + types.tensorboard_run = types.tensorboard_run_v1beta1 + types.tensorboard_service = types.tensorboard_service_v1beta1 + types.tensorboard_time_series = types.tensorboard_time_series_v1beta1 if DEFAULT_VERSION == V1: diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py index 0888c27fbb..5c104ab41f 100644 --- a/google/cloud/aiplatform/compat/services/__init__.py +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -36,6 +36,12 @@ from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( client as specialist_pool_service_client_v1beta1, ) +from google.cloud.aiplatform_v1beta1.services.metadata_service import ( + client as metadata_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.tensorboard_service import ( + client as tensorboard_service_client_v1beta1, +) from google.cloud.aiplatform_v1.services.dataset_service import ( client as dataset_service_client_v1, @@ -76,4 +82,6 @@ pipeline_service_client_v1beta1, prediction_service_client_v1beta1, specialist_pool_service_client_v1beta1, + 
metadata_service_client_v1beta1, + tensorboard_service_client_v1beta1, ) diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index d03e0d2f3a..f45bb2e11e 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -50,6 +50,13 @@ specialist_pool as specialist_pool_v1beta1, specialist_pool_service as specialist_pool_service_v1beta1, training_pipeline as training_pipeline_v1beta1, + metadata_service as metadata_service_v1beta1, + tensorboard_service as tensorboard_service_v1beta1, + tensorboard_data as tensorboard_data_v1beta1, + tensorboard_experiment as tensorboard_experiment_v1beta1, + tensorboard_run as tensorboard_run_v1beta1, + tensorboard_service as tensorboard_service_v1beta1, + tensorboard_time_series as tensorboard_time_series_v1beta1, ) from google.cloud.aiplatform_v1.types import ( accelerator_type as accelerator_type_v1, @@ -155,4 +162,11 @@ specialist_pool_v1beta1, specialist_pool_service_v1beta1, training_pipeline_v1beta1, + metadata_service_v1beta1, + tensorboard_service_v1beta1, + tensorboard_data_v1beta1, + tensorboard_experiment_v1beta1, + tensorboard_run_v1beta1, + tensorboard_service_v1beta1, + tensorboard_time_series_v1beta1, ) diff --git a/google/cloud/aiplatform/datasets/__init__.py b/google/cloud/aiplatform/datasets/__init__.py index 57e2bad45d..b297530955 100644 --- a/google/cloud/aiplatform/datasets/__init__.py +++ b/google/cloud/aiplatform/datasets/__init__.py @@ -17,6 +17,7 @@ from google.cloud.aiplatform.datasets.dataset import _Dataset from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset +from google.cloud.aiplatform.datasets.time_series_dataset import TimeSeriesDataset from google.cloud.aiplatform.datasets.image_dataset import ImageDataset from google.cloud.aiplatform.datasets.text_dataset import TextDataset from google.cloud.aiplatform.datasets.video_dataset import VideoDataset @@ -25,6 +26,7 @@ __all__ = ( "_Dataset", "TabularDataset", + "TimeSeriesDataset", "ImageDataset", "TextDataset", "VideoDataset", diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py index a01e68c01f..ea436eb91b 100644 --- a/google/cloud/aiplatform/datasets/_datasources.py +++ b/google/cloud/aiplatform/datasets/_datasources.py @@ -225,6 +225,11 @@ def create_datasource( raise ValueError("tabular dataset does not support data import.") return TabularDatasource(gcs_source, bq_source) + if metadata_schema_uri == schema.dataset.metadata.time_series: + if import_schema_uri: + raise ValueError("time series dataset does not support data import.") + return TabularDatasource(gcs_source, bq_source) + if not import_schema_uri and not gcs_source: return NonTabularDatasource() elif import_schema_uri and gcs_source: diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 4bb98cbd77..44dadc4ee4 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -162,7 +162,7 @@ def create( if their content bytes are identical (e.g. image bytes or pdf bytes). These labels will be overridden by Annotation labels specified inside index file refenced by - [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + ``import_schema_uri``, e.g. jsonl file. project (str): Project to upload this model to. 
Overrides project set in @@ -449,7 +449,7 @@ def import_data( if their content bytes are identical (e.g. image bytes or pdf bytes). These labels will be overridden by Annotation labels specified inside index file refenced by - [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + ``import_schema_uri``, e.g. jsonl file. sync (bool): Whether to execute this method synchronously. If False, this method diff --git a/google/cloud/aiplatform/datasets/image_dataset.py b/google/cloud/aiplatform/datasets/image_dataset.py index fdc6c99a79..c2b3ca68b5 100644 --- a/google/cloud/aiplatform/datasets/image_dataset.py +++ b/google/cloud/aiplatform/datasets/image_dataset.py @@ -82,7 +82,7 @@ def create( if their content bytes are identical (e.g. image bytes or pdf bytes). These labels will be overridden by Annotation labels specified inside index file refenced by - [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + ``import_schema_uri``, e.g. jsonl file. project (str): Project to upload this model to. Overrides project set in diff --git a/google/cloud/aiplatform/datasets/text_dataset.py b/google/cloud/aiplatform/datasets/text_dataset.py index 568edc9e47..6f6fd57bda 100644 --- a/google/cloud/aiplatform/datasets/text_dataset.py +++ b/google/cloud/aiplatform/datasets/text_dataset.py @@ -89,7 +89,7 @@ def create( if their content bytes are identical (e.g. image bytes or pdf bytes). These labels will be overridden by Annotation labels specified inside index file refenced by - [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + ``import_schema_uri``, e.g. jsonl file. project (str): Project to upload this model to. Overrides project set in diff --git a/google/cloud/aiplatform/datasets/time_series_dataset.py b/google/cloud/aiplatform/datasets/time_series_dataset.py new file mode 100644 index 0000000000..92d8e60c37 --- /dev/null +++ b/google/cloud/aiplatform/datasets/time_series_dataset.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, Sequence, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class TimeSeriesDataset(datasets._Dataset): + """Managed time series dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.time_series, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "TimeSeriesDataset": + """Creates a new tabular dataset. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + BigQuery URI to the input table. + example: + "bq://project.dataset.table_name" + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + time_series_dataset (TimeSeriesDataset): + Instantiated representation of the managed time series dataset resource. 
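+            Example (illustrative sketch, not part of this change; the project,
+            BigQuery dataset, and table names are placeholders):
+
+                aiplatform.init(project="my-project", location="us-central1")
+
+                dataset = aiplatform.TimeSeriesDataset.create(
+                    display_name="sales-forecast-data",
+                    bq_source="bq://my-project.my_dataset.sales_table",
+                )
+                # a gcs_source such as "gs://my-bucket/data.csv" may be passed
+                # instead of bq_source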
+ + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.time_series + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + gcs_source=gcs_source, + bq_source=bq_source, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) + + def import_data(self): + raise NotImplementedError( + f"{self.__class__.__name__} class does not support 'import_data'" + ) diff --git a/google/cloud/aiplatform/datasets/video_dataset.py b/google/cloud/aiplatform/datasets/video_dataset.py index 4115365c64..7064c8b7cf 100644 --- a/google/cloud/aiplatform/datasets/video_dataset.py +++ b/google/cloud/aiplatform/datasets/video_dataset.py @@ -82,7 +82,7 @@ def create( if their content bytes are identical (e.g. image bytes or pdf bytes). These labels will be overridden by Annotation labels specified inside index file refenced by - [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + ``import_schema_uri``, e.g. jsonl file. project (str): Project to upload this model to. Overrides project set in diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 41c698b6d3..9adae3be9a 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -31,6 +31,7 @@ from google.cloud.aiplatform import compat from google.cloud.aiplatform import constants from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata import metadata from google.cloud.aiplatform.compat.types import ( encryption_spec as gca_encryption_spec_compat, @@ -44,7 +45,6 @@ class _Config: def __init__(self): self._project = None - self._experiment = None self._location = None self._staging_bucket = None self._credentials = None @@ -56,21 +56,23 @@ def init( project: Optional[str] = None, location: Optional[str] = None, experiment: Optional[str] = None, + experiment_description: Optional[str] = None, staging_bucket: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, encryption_spec_key_name: Optional[str] = None, ): - """Updates common initalization parameters with provided options. + """Updates common initialization parameters with provided options. Args: project (str): The default project to use when making API calls. location (str): The default location to use when making API calls. If not - set defaults to us-central-1 - experiment (str): The experiment to assign + set defaults to us-central-1. + experiment (str): The experiment name. + experiment_description (str): The description of the experiment. staging_bucket (str): The default staging bucket to use to stage artifacts when making API calls. In the form gs://... - credentials (google.auth.crendentials.Credentials): The default custom - credentials to use when making API calls. 
If not provided crendentials + credentials (google.auth.credentials.Credentials): The default custom + credentials to use when making API calls. If not provided credentials will be ascertained from the environment. encryption_spec_key_name (Optional[str]): Optional. The Cloud KMS resource identifier of the customer @@ -82,14 +84,27 @@ def init( If set, this resource and all sub-resources will be secured by this key. """ + + # reset metadata_service config if project or location is updated. + if (project and project != self._project) or ( + location and location != self._location + ): + if metadata.metadata_service.experiment_name: + logging.info("project/location updated, reset Metadata config.") + metadata.metadata_service.reset() if project: self._project = project if location: utils.validate_region(location) self._location = location if experiment: - logging.warning("Experiments currently not supported.") - self._experiment = experiment + metadata.metadata_service.set_experiment( + experiment=experiment, description=experiment_description + ) + if experiment_description and experiment is None: + raise ValueError( + "Experiment name needs to be set in `init` in order to add experiment descriptions." + ) if staging_bucket: self._staging_bucket = staging_bucket if credentials: @@ -154,11 +169,6 @@ def location(self) -> str: """Default location.""" return self._location or constants.DEFAULT_REGION - @property - def experiment(self) -> Optional[str]: - """Default experiment, if provided.""" - return self._experiment - @property def staging_bucket(self) -> Optional[str]: """Default staging bucket, if provided.""" diff --git a/google/cloud/aiplatform/metadata/__init__.py b/google/cloud/aiplatform/metadata/__init__.py new file mode 100644 index 0000000000..2144d2e268 --- /dev/null +++ b/google/cloud/aiplatform/metadata/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/google/cloud/aiplatform/metadata/artifact.py b/google/cloud/aiplatform/metadata/artifact.py new file mode 100644 index 0000000000..98eefacc5f --- /dev/null +++ b/google/cloud/aiplatform/metadata/artifact.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, Dict + +import proto + +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata.resource import _Resource +from google.cloud.aiplatform_v1beta1 import ListArtifactsRequest +from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact + + +class _Artifact(_Resource): + """Metadata Artifact resource for AI Platform""" + + _resource_noun = "artifacts" + _getter_method = "get_artifact" + + @classmethod + def _create_resource( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + ) -> proto.Message: + gapic_artifact = gca_artifact.Artifact( + schema_title=schema_title, + schema_version=schema_version, + display_name=display_name, + description=description, + metadata=metadata if metadata else {}, + ) + return client.create_artifact( + parent=parent, artifact=gapic_artifact, artifact_id=resource_id, + ) + + @classmethod + def _update_resource( + cls, client: utils.MetadataClientWithOverride, resource: proto.Message, + ) -> proto.Message: + """Update Artifacts with given input. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + resource (proto.Message): + Required. The proto.Message which contains the update information for the resource. + """ + + return client.update_artifact(artifact=resource) + + @classmethod + def _list_resources( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + filter: Optional[str] = None, + ): + """List artifacts in the parent path that matches the filter. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + parent (str): + Required. The path where Artifacts are stored. + filter (str): + Optional. filter string to restrict the list result + """ + list_request = ListArtifactsRequest(parent=parent, filter=filter,) + return client.list_artifacts(request=list_request) diff --git a/google/cloud/aiplatform/metadata/constants.py b/google/cloud/aiplatform/metadata/constants.py new file mode 100644 index 0000000000..62e7d6e075 --- /dev/null +++ b/google/cloud/aiplatform/metadata/constants.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SYSTEM_RUN = "system.Run" +SYSTEM_EXPERIMENT = "system.Experiment" +SYSTEM_PIPELINE = "system.Pipeline" +SYSTEM_METRICS = "system.Metrics" + +_DEFAULT_SCHEMA_VERSION = "0.0.1" + +SCHEMA_VERSIONS = { + SYSTEM_RUN: _DEFAULT_SCHEMA_VERSION, + SYSTEM_EXPERIMENT: _DEFAULT_SCHEMA_VERSION, + SYSTEM_PIPELINE: _DEFAULT_SCHEMA_VERSION, + SYSTEM_METRICS: _DEFAULT_SCHEMA_VERSION, +} + +# The EXPERIMENT_METADATA is needed until we support context deletion in backend service. +# TODO: delete EXPERIMENT_METADATA once backend supports context deletion. 
+EXPERIMENT_METADATA = {"experiment_deleted": False} diff --git a/google/cloud/aiplatform/metadata/context.py b/google/cloud/aiplatform/metadata/context.py new file mode 100644 index 0000000000..cb3340499b --- /dev/null +++ b/google/cloud/aiplatform/metadata/context.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Dict, Sequence + +import proto + +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata.resource import _Resource +from google.cloud.aiplatform_v1beta1 import ListContextsRequest +from google.cloud.aiplatform_v1beta1.types import context as gca_context + + +class _Context(_Resource): + """Metadata Context resource for AI Platform""" + + _resource_noun = "contexts" + _getter_method = "get_context" + + def add_artifacts_and_executions( + self, + artifact_resource_names: Optional[Sequence[str]] = None, + execution_resource_names: Optional[Sequence[str]] = None, + ): + """Associate Executions and attribute Artifacts to a given Context. + + Args: + artifact_resource_names (Sequence[str]): + Optional. The full resource name of Artifacts to attribute to the Context. + execution_resource_names (Sequence[str]): + Optional. The full resource name of Executions to associate with the Context. + """ + self.api_client.add_context_artifacts_and_executions( + context=self.resource_name, + artifacts=artifact_resource_names, + executions=execution_resource_names, + ) + + @classmethod + def _create_resource( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + ) -> proto.Message: + gapic_context = gca_context.Context( + schema_title=schema_title, + schema_version=schema_version, + display_name=display_name, + description=description, + metadata=metadata if metadata else {}, + ) + return client.create_context( + parent=parent, context=gapic_context, context_id=resource_id, + ) + + @classmethod + def _update_resource( + cls, client: utils.MetadataClientWithOverride, resource: proto.Message, + ) -> proto.Message: + """Update Contexts with given input. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + resource (proto.Message): + Required. The proto.Message which contains the update information for the resource. + """ + + return client.update_context(context=resource) + + @classmethod + def _list_resources( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + filter: Optional[str] = None, + ): + """List Contexts in the parent path that matches the filter. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + parent (str): + Required. The path where Contexts are stored. + filter (str): + Optional. 
filter string to restrict the list result + """ + + list_request = ListContextsRequest(parent=parent, filter=filter,) + return client.list_contexts(request=list_request) diff --git a/google/cloud/aiplatform/metadata/execution.py b/google/cloud/aiplatform/metadata/execution.py new file mode 100644 index 0000000000..39fc7a74b3 --- /dev/null +++ b/google/cloud/aiplatform/metadata/execution.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Dict, Sequence + +import proto +from google.api_core import exceptions + +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.metadata.artifact import _Artifact +from google.cloud.aiplatform.metadata.resource import _Resource +from google.cloud.aiplatform_v1beta1 import Event +from google.cloud.aiplatform_v1beta1.types import execution as gca_execution +from google.cloud.aiplatform_v1beta1.types.metadata_service import ListExecutionsRequest + + +class _Execution(_Resource): + """Metadata Execution resource for AI Platform""" + + _resource_noun = "executions" + _getter_method = "get_execution" + + def add_artifact( + self, artifact_resource_name: str, input: bool, + ): + """Connect Artifact to a given Execution. + + Args: + artifact_resource_name (str): + Required. The full resource name of the Artifact to connect to the Execution through an Event. + input (bool) + Required. Whether Artifact is an input event to the Execution or not. + """ + + event = Event( + artifact=artifact_resource_name, + type_=Event.Type.INPUT if input else Event.Type.OUTPUT, + ) + + self.api_client.add_execution_events( + execution=self.resource_name, events=[event], + ) + + def query_input_and_output_artifacts(self) -> Sequence[_Artifact]: + """query the input and output artifacts connected to the execution. 
+ + Returns: + A Sequence of _Artifacts + """ + + try: + artifacts = self.api_client.query_execution_inputs_and_outputs( + execution=self.resource_name + ).artifacts + except exceptions.NotFound: + return [] + + return [ + _Artifact( + resource=artifact, + project=self.project, + location=self.location, + credentials=self.credentials, + ) + for artifact in artifacts + ] + + @classmethod + def _create_resource( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + ) -> proto.Message: + gapic_execution = gca_execution.Execution( + schema_title=schema_title, + schema_version=schema_version, + display_name=display_name, + description=description, + metadata=metadata if metadata else {}, + ) + return client.create_execution( + parent=parent, execution=gapic_execution, execution_id=resource_id, + ) + + @classmethod + def _list_resources( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + filter: Optional[str] = None, + ): + """List Executions in the parent path that matches the filter. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + parent (str): + Required. The path where Executions are stored. + filter (str): + Optional. filter string to restrict the list result + """ + + list_request = ListExecutionsRequest(parent=parent, filter=filter,) + return client.list_executions(request=list_request) + + @classmethod + def _update_resource( + cls, client: utils.MetadataClientWithOverride, resource: proto.Message, + ) -> proto.Message: + """Update Executions with given input. + + Args: + client (utils.MetadataClientWithOverride): + Required. client to send require to Metadata Service. + resource (proto.Message): + Required. The proto.Message which contains the update information for the resource. + """ + + return client.update_execution(execution=resource) diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py new file mode 100644 index 0000000000..919eff8619 --- /dev/null +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -0,0 +1,377 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Dict, Union, Optional + +from google.cloud.aiplatform.metadata import constants +from google.cloud.aiplatform.metadata.artifact import _Artifact +from google.cloud.aiplatform.metadata.context import _Context +from google.cloud.aiplatform.metadata.execution import _Execution +from google.cloud.aiplatform.metadata.metadata_store import _MetadataStore + + +class _MetadataService: + """Contains the exposed APIs to interact with the Managed Metadata Service.""" + + def __init__(self): + self._experiment = None + self._run = None + self._metrics = None + + def reset(self): + """Reset all _MetadataService fields to None""" + self._experiment = None + self._run = None + self._metrics = None + + @property + def experiment_name(self) -> Optional[str]: + """Return the experiment name of the _MetadataService, if experiment is not set, return None""" + if self._experiment: + return self._experiment.display_name + return None + + @property + def run_name(self) -> Optional[str]: + """Return the run name of the _MetadataService, if run is not set, return None""" + if self._run: + return self._run.display_name + return None + + def set_experiment(self, experiment: str, description: Optional[str] = None): + """Setup a experiment to current session. + + Args: + experiment (str): + Required. Name of the experiment to assign current session with. + description (str): + Optional. Description of an experiment. + """ + + _MetadataStore.get_or_create() + context = _Context.get_or_create( + resource_id=experiment, + display_name=experiment, + description=description, + schema_title=constants.SYSTEM_EXPERIMENT, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + if context.schema_title != constants.SYSTEM_EXPERIMENT: + raise ValueError( + f"Experiment name {experiment} has been used to create other type of resources " + f"({context.schema_title}) in this MetadataStore, please choose a different experiment name." + ) + + if description and context.description != description: + context.update(metadata=context.metadata, description=description) + + self._experiment = context + + def start_run(self, run: str): + """Setup a run to current session. + + Args: + run (str): + Required. Name of the run to assign current session with. + Raise: + ValueError if experiment is not set. Or if run execution or metrics artifact + is already created but with a different schema. + """ + + if not self._experiment: + raise ValueError( + "No experiment set for this run. Make sure to call aiplatform.init(experiment='my-experiment') " + "before trying to start_run. " + ) + run_execution_id = f"{self._experiment.name}-{run}" + run_execution = _Execution.get_or_create( + resource_id=run_execution_id, + display_name=run, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + ) + if run_execution.schema_title != constants.SYSTEM_RUN: + raise ValueError( + f"Run name {run} has been used to create other type of resources ({run_execution.schema_title}) " + "in this MetadataStore, please choose a different run name." 
+ ) + self._experiment.add_artifacts_and_executions( + execution_resource_names=[run_execution.resource_name] + ) + + metrics_artifact_id = f"{self._experiment.name}-{run}-metrics" + metrics_artifact = _Artifact.get_or_create( + resource_id=metrics_artifact_id, + display_name=metrics_artifact_id, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + ) + if metrics_artifact.schema_title != constants.SYSTEM_METRICS: + raise ValueError( + f"Run name {run} has been used to create other type of resources ({metrics_artifact.schema_title}) " + "in this MetadataStore, please choose a different run name." + ) + run_execution.add_artifact( + artifact_resource_name=metrics_artifact.resource_name, input=False + ) + + self._run = run_execution + self._metrics = metrics_artifact + + def log_params(self, params: Dict[str, Union[float, int, str]]): + """Log single or multiple parameters with specified key and value pairs. + + Args: + params (Dict): + Required. Parameter key/value pairs. + """ + + self._validate_experiment_and_run(method_name="log_params") + # query the latest run execution resource before logging. + execution = _Execution.get_or_create( + resource_id=self._run.name, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + ) + execution.update(metadata=params) + + def log_metrics(self, metrics: Dict[str, Union[float, int]]): + """Log single or multiple Metrics with specified key and value pairs. + + Args: + metrics (Dict): + Required. Metrics key/value pairs. Only flot and int are supported format for value. + Raises: + TypeError if value contains unsupported types. + ValueError if Experiment or Run is not set. + """ + + self._validate_experiment_and_run(method_name="log_metrics") + self._validate_metrics_value_type(metrics) + # query the latest metrics artifact resource before logging. + artifact = _Artifact.get_or_create( + resource_id=self._metrics.name, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + ) + artifact.update(metadata=metrics) + + def get_experiment_df( + self, experiment: Optional[str] = None + ) -> "pd.DataFrame": # noqa: F821 + """Returns a Pandas DataFrame of the parameters and metrics associated with one experiment. + + Example: + + aiplatform.init(experiment='exp-1') + aiplatform.start_run(run='run-1') + aiplatform.log_params({'learning_rate': 0.1}) + aiplatform.log_metrics({'accuracy': 0.9}) + + aiplatform.start_run(run='run-2') + aiplatform.log_params({'learning_rate': 0.2}) + aiplatform.log_metrics({'accuracy': 0.95}) + + Will result in the following DataFrame + ___________________________________________________________________________ + | experiment_name | run_name | param.learning_rate | metric.accuracy | + --------------------------------------------------------------------------- + | exp-1 | run-1 | 0.1 | 0.9 | + | exp-1 | run-2 | 0.2 | 0.95 | + --------------------------------------------------------------------------- + + Args: + experiment (str): + Name of the Experiment to filter results. If not set, return results of current active experiment. + + Returns: + Pandas Dataframe of Experiment with metrics and parameters. + + Raise: + NotFound exception if experiment does not exist. + ValueError if given experiment is not associated with a wrong schema. 
+ """ + + if not experiment: + experiment = self._experiment.name + + source = "experiment" + experiment_resource_name = self._get_experiment_or_pipeline_resource_name( + name=experiment, source=source, expected_schema=constants.SYSTEM_EXPERIMENT, + ) + + return self._query_runs_to_data_frame( + context_id=experiment, + context_resource_name=experiment_resource_name, + source=source, + ) + + def get_pipeline_df(self, pipeline: str) -> "pd.DataFrame": # noqa: F821 + """Returns a Pandas DataFrame of the parameters and metrics associated with one pipeline. + + Args: + pipeline: Name of the Pipeline to filter results. + + Returns: + Pandas Dataframe of Pipeline with metrics and parameters. + + Raise: + NotFound exception if experiment does not exist. + ValueError if given experiment is not associated with a wrong schema. + """ + + source = "pipeline" + pipeline_resource_name = self._get_experiment_or_pipeline_resource_name( + name=pipeline, source=source, expected_schema=constants.SYSTEM_PIPELINE + ) + + return self._query_runs_to_data_frame( + context_id=pipeline, + context_resource_name=pipeline_resource_name, + source=source, + ) + + def _validate_experiment_and_run(self, method_name: str): + if not self._experiment: + raise ValueError( + f"No experiment set. Make sure to call aiplatform.init(experiment='my-experiment') " + f"before trying to {method_name}. " + ) + if not self._run: + raise ValueError( + f"No run set. Make sure to call aiplatform.start_run('my-run') before trying to {method_name}. " + ) + + @staticmethod + def _validate_metrics_value_type(metrics: Dict[str, Union[float, int]]): + """Verify that metrics value are with supported types. + + Args: + metrics (Dict): + Required. Metrics key/value pairs. Only flot and int are supported format for value. + Raises: + TypeError if value contains unsupported types. + """ + + for key, value in metrics.items(): + if isinstance(value, int) or isinstance(value, float): + continue + raise TypeError( + f"metrics contain unsupported value types. key: {key}; value: {value}; type: {type(value)}" + ) + + @staticmethod + def _get_experiment_or_pipeline_resource_name( + name: str, source: str, expected_schema: str + ) -> str: + """Get the full resource name of the Context representing an Experiment or Pipeline. + + Args: + name (str): + Name of the Experiment or Pipeline. + source (str): + Identify whether the this is an Experiment or a Pipeline. + expected_schema (str): + expected_schema identifies the expected schema used for Experiment or Pipeline. + + Returns: + The full resource name of the Experiment or Pipeline Context. + + Raise: + NotFound exception if experiment or pipeline does not exist. + """ + + context = _Context(resource_name=name) + + if context.schema_title != expected_schema: + raise ValueError( + f"Please provide a valid {source} name. {name} is not a {source}." + ) + return context.resource_name + + def _query_runs_to_data_frame( + self, context_id: str, context_resource_name: str, source: str + ) -> "pd.DataFrame": # noqa: F821 + """Get metrics and parameters associated with a given Context into a Dataframe. + + Args: + context_id (str): + Name of the Experiment or Pipeline. + context_resource_name (str): + Full resource name of the Context associated with an Experiment or Pipeline. + source (str): + Identify whether the this is an Experiment or a Pipeline. + + Returns: + The full resource name of the Experiment or Pipeline Context. 
+ """ + + filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{context_resource_name}")' + run_executions = _Execution.list(filter=filter) + + context_summary = [] + for run_execution in run_executions: + run_dict = { + f"{source}_name": context_id, + "run_name": run_execution.display_name, + } + run_dict.update( + self._execution_to_column_named_metadata( + "param", run_execution.metadata + ) + ) + + for metric_artifact in run_execution.query_input_and_output_artifacts(): + run_dict.update( + self._execution_to_column_named_metadata( + "metric", metric_artifact.metadata + ) + ) + + context_summary.append(run_dict) + + try: + import pandas as pd + except ImportError: + raise ImportError( + "Pandas is not installed and is required to get dataframe as the return format. " + 'Please install the SDK using "pip install python-aiplatform[full]"' + ) + + return pd.DataFrame(context_summary) + + @staticmethod + def _execution_to_column_named_metadata( + metadata_type: str, metadata: Dict, + ) -> Dict[str, Union[int, float, str]]: + """Returns a dict of the Execution/Artifact metadata with column names. + + Args: + metadata_type: The type of this execution properties (param, metric). + metadata: Either an Execution or Artifact metadata field. + + Returns: + Dict of custom properties with keys mapped to column names + """ + + return { + ".".join([metadata_type, key]): value for key, value in metadata.items() + } + + +metadata_service = _MetadataService() diff --git a/google/cloud/aiplatform/metadata/metadata_store.py b/google/cloud/aiplatform/metadata/metadata_store.py new file mode 100644 index 0000000000..2a55f066a8 --- /dev/null +++ b/google/cloud/aiplatform/metadata/metadata_store.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from typing import Optional + +from google.api_core import exceptions +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import base, initializer +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import utils +from google.cloud.aiplatform_v1beta1.types import metadata_store as gca_metadata_store + + +class _MetadataStore(base.AiPlatformResourceNounWithFutureManager): + """Managed MetadataStore resource for AI Platform""" + + client_class = utils.MetadataClientWithOverride + _is_client_prediction_client = False + _resource_noun = "metadataStores" + _getter_method = "get_metadata_store" + _delete_method = "delete_metadata_store" + + def __init__( + self, + metadata_store_name: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing MetadataStore given a MetadataStore name or ID. + + Args: + metadata_store_name (str): + Optional. A fully-qualified MetadataStore resource name or metadataStore ID. 
+ Example: "projects/123/locations/us-central1/metadataStores/my-store" or + "my-store" when project and location are initialized or passed. + If not set, metadata_store_name will be set to "default". + project (str): + Optional project to retrieve resource from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve resource from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + + """ + + super().__init__( + project=project, location=location, credentials=credentials, + ) + self._gca_resource = self._get_gca_resource(resource_name=metadata_store_name) + + @classmethod + def get_or_create( + cls, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + ) -> "_MetadataStore": + """"Retrieves or Creates (if it does not exist) a Metadata Store. + + Args: + metadata_store_id (str): + The portion of the resource name with the format: + projects/123/locations/us-central1/metadataStores/ + If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore. + project (str): + Project used to retrieve or create the metadata store. Overrides project set in + aiplatform.init. + location (str): + Location used to retrieve or create the metadata store. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to retrieve or create the metadata store. Overrides + credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the metadata store. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + + + Returns: + metadata_store (_MetadataStore): + Instantiated representation of the managed metadata store resource. + + """ + + store = cls._get( + metadata_store_name=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + if not store: + store = cls._create( + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + encryption_spec_key_name=encryption_spec_key_name, + ) + return store + + @classmethod + def _create( + cls, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + ) -> "_MetadataStore": + """Creates a new MetadataStore if it does not exist. + + Args: + metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/ + If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore. + project (str): + Project used to create the metadata store. Overrides project set in + aiplatform.init. + location (str): + Location used to create the metadata store. 
Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to create the metadata store. Overrides + credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the metadata store. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + + + Returns: + metadata_store (_MetadataStore): + Instantiated representation of the managed metadata store resource. + + """ + api_client = cls._instantiate_client(location=location, credentials=credentials) + gapic_metadata_store = gca_metadata_store.MetadataStore( + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name, + select_version=compat.V1BETA1, + ) + ) + + try: + api_client.create_metadata_store( + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + metadata_store=gapic_metadata_store, + metadata_store_id=metadata_store_id, + ).result() + except exceptions.AlreadyExists: + logging.info(f"MetadataStore '{metadata_store_id}' already exists") + + return cls( + metadata_store_name=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + + @classmethod + def _get( + cls, + metadata_store_name: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "Optional[_MetadataStore]": + """Returns a MetadataStore resource. + + Args: + metadata_store_name (str): + Optional. A fully-qualified MetadataStore resource name or metadataStore ID. + Example: "projects/123/locations/us-central1/metadataStores/my-store" or + "my-store" when project and location are initialized or passed. + If not set, metadata_store_name will be set to "default". + project (str): + Optional project to retrieve the metadata store from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve the metadata store from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to retrieve this metadata store. Overrides + credentials set in aiplatform.init. + + Returns: + metadata_store (Optional[_MetadataStore]): + An optional instantiated representation of the managed Metadata Store resource. + """ + + try: + return cls( + metadata_store_name=metadata_store_name, + project=project, + location=location, + credentials=credentials, + ) + except exceptions.NotFound: + logging.info(f"MetadataStore {metadata_store_name} not found.") diff --git a/google/cloud/aiplatform/metadata/resource.py b/google/cloud/aiplatform/metadata/resource.py new file mode 100644 index 0000000000..11f03b7af1 --- /dev/null +++ b/google/cloud/aiplatform/metadata/resource.py @@ -0,0 +1,465 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import logging +import re +from copy import deepcopy +from typing import Optional, Dict, Union, Sequence + +import proto +from google.api_core import exceptions +from google.auth import credentials as auth_credentials +from google.protobuf import json_format + +from google.cloud.aiplatform import base, initializer +from google.cloud.aiplatform import utils +from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact +from google.cloud.aiplatform_v1beta1 import Context as GapicContext +from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution + + +class _Resource(base.AiPlatformResourceNounWithFutureManager, abc.ABC): + """Metadata Resource for AI Platform""" + + client_class = utils.MetadataClientWithOverride + _is_client_prediction_client = False + _delete_method = None + + def __init__( + self, + resource_name: Optional[str] = None, + resource: Optional[Union[GapicContext, GapicArtifact, GapicExecution]] = None, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing Metadata resource given a resource name or ID. + + Args: + resource_name (str): + A fully-qualified resource name or ID + Example: "projects/123/locations/us-central1/metadataStores/default//my-resource". + or "my-resource" when project and location are initialized or passed. if ``resource`` is provided, this + should not be set. + resource (Union[GapicContext, GapicArtifact, GapicExecution]): + The proto.Message that contains the full information of the resource. If both set, this field overrides + ``resource_name`` field. + metadata_store_id (str): + MetadataStore to retrieve resource from. If not set, metadata_store_id is set to "default". + If resource_name is a fully-qualified resource, its metadata_store_id overrides this one. + project (str): + Optional project to retrieve the resource from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve the resource from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. 
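+            Example (illustrative sketch; the project, location, and IDs are
+            placeholders): a bare ID such as "my-experiment" is expanded against
+            the default MetadataStore before the getter call, roughly:
+
+                utils.full_resource_name(
+                    resource_name="my-experiment",
+                    resource_noun="metadataStores/default/contexts",
+                    project="my-project",
+                    location="us-central1",
+                )
+                # -> "projects/my-project/locations/us-central1/metadataStores/default/contexts/my-experiment"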
+ """ + + super().__init__( + project=project, location=location, credentials=credentials, + ) + + if resource: + self._gca_resource = resource + return + + full_resource_name = resource_name + # Construct the full_resource_name if input resource_name is the resource_id + if "/" not in resource_name: + full_resource_name = utils.full_resource_name( + resource_name=resource_name, + resource_noun=f"metadataStores/{metadata_store_id}/{self._resource_noun}", + project=self.project, + location=self.location, + ) + + self._gca_resource = getattr(self.api_client, self._getter_method)( + name=full_resource_name + ) + + @property + def metadata(self) -> Dict: + return json_format.MessageToDict(self._gca_resource._pb)["metadata"] + + @property + def schema_title(self) -> str: + return self._gca_resource.schema_title + + @property + def description(self) -> str: + return self._gca_resource.description + + @classmethod + def get_or_create( + cls, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "_Resource": + """Retrieves or Creates (if it does not exist) a Metadata resource. + + Args: + resource_id (str): + Required. The portion of the resource name with the format: + projects/123/locations/us-central1/metadataStores///. + schema_title (str): + Required. schema_title identifies the schema title used by the resource. + display_name (str): + Optional. The user-defined name of the resource. + schema_version (str): + Optional. schema_version specifies the version used by the resource. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the resource to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the resource. + metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/// + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project used to retrieve or create this resource. Overrides project set in + aiplatform.init. + location (str): + Location used to retrieve or create this resource. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to retrieve or create this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resource (_Resource): + Instantiated representation of the managed Metadata resource. + + """ + + resource = cls._get( + resource_name=resource_id, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + if not resource: + logging.info(f"Creating Resource {resource_id}") + resource = cls._create( + resource_id=resource_id, + schema_title=schema_title, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=metadata, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + return resource + + def update( + self, + metadata: Dict, + description: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Updates an existing Metadata resource with new metadata. + + Args: + metadata (Dict): + Required. 
metadata contains the updated metadata information. + description (str): + Optional. Description describes the resource to be updated. + credentials (auth_credentials.Credentials): + Custom credentials to use to update this resource. Overrides + credentials set in aiplatform.init. + + """ + + gca_resource = deepcopy(self._gca_resource) + if gca_resource.metadata: + gca_resource.metadata.update(metadata) + else: + gca_resource.metadata = metadata + if description: + gca_resource.description = description + + api_client = self._instantiate_client(credentials=credentials) + + update_gca_resource = self._update_resource( + client=api_client, resource=gca_resource, + ) + self._gca_resource = update_gca_resource + + @classmethod + def list( + cls, + filter: Optional[str] = None, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> Sequence["_Resource"]: + """List Metadata resources that match the list filter in target metadataStore. + + Args: + filter (str): + Optional. A query to filter available resources for + matching results. + metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/// + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project used to create this resource. Overrides project set in + aiplatform.init. + location (str): + Location used to create this resource. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to create this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resources (sequence[_Resource]): + a list of managed Metadata resource. + + """ + api_client = cls._instantiate_client(location=location, credentials=credentials) + + parent = ( + initializer.global_config.common_location_path( + project=project, location=location + ) + + f"/metadataStores/{metadata_store_id}" + ) + + try: + resources = cls._list_resources( + client=api_client, parent=parent, filter=filter, + ) + except exceptions.NotFound: + logging.info( + f"No matching resources in metadataStore: {metadata_store_id} with filter: {filter}" + ) + return [] + + return [ + cls( + resource=resource, + project=project, + location=location, + credentials=credentials, + ) + for resource in resources + ] + + @classmethod + def _create( + cls, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Creates a new Metadata resource. + + Args: + resource_id (str): + Required. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores///. + schema_title (str): + Required. schema_title identifies the schema title used by the resource. + display_name (str): + Optional. The user-defined name of the resource. + schema_version (str): + Optional. schema_version specifies the version used by the resource. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the resource to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the resource. 
+ metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/// + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project used to create this resource. Overrides project set in + aiplatform.init. + location (str): + Location used to create this resource. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to create this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resource (_Resource): + Instantiated representation of the managed Metadata resource. + + """ + api_client = cls._instantiate_client(location=location, credentials=credentials) + + parent = ( + initializer.global_config.common_location_path( + project=project, location=location + ) + + f"/metadataStores/{metadata_store_id}" + ) + + try: + resource = cls._create_resource( + client=api_client, + parent=parent, + resource_id=resource_id, + schema_title=schema_title, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=metadata, + ) + except exceptions.AlreadyExists: + logging.info(f"Resource '{resource_id}' already exist") + return + + return cls( + resource=resource, + project=project, + location=location, + credentials=credentials, + ) + + @classmethod + def _get( + cls, + resource_name: str, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> Optional["_Resource"]: + """Returns a metadata Resource. + + Args: + resource_name (str): + A fully-qualified resource name or resource ID + Example: "projects/123/locations/us-central1/metadataStores/default//my-resource". + or "my-resource" when project and location are initialized or passed. + metadata_store_id (str): + The metadata_store_id portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores///my-resource + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project to get this resource from. Overrides project set in + aiplatform.init. + location (str): + Location to get this resource from. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to get this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resource (Optional[_Resource]): + An optional instantiated representation of the managed Metadata resource. 
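# Descriptive note on the lookup below: exceptions.NotFound is caught and the method falls
# through to return None, which is what lets get_or_create() treat a missing resource as the
# cue to create one, roughly:
#
#   resource = cls._get(resource_name=resource_id, ...)   # None when not found
#   if not resource:
#       resource = cls._create(resource_id=resource_id, ...)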
+ + """ + + try: + return cls( + resource_name=resource_name, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + except exceptions.NotFound: + logging.info(f"Resource {resource_name} not found.") + + @classmethod + @abc.abstractmethod + def _create_resource( + cls, + client: utils.MetadataClientWithOverride, + parent: str, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + ) -> proto.Message: + """Create resource method.""" + pass + + @classmethod + @abc.abstractmethod + def _update_resource( + cls, client: utils.MetadataClientWithOverride, resource: proto.Message, + ) -> proto.Message: + """Update resource method.""" + pass + + @staticmethod + def _extract_metadata_store_id(resource_name, resource_noun) -> str: + """Extracts the metadata store id from the resource name. + + Args: + resource_name (str): + Required. A fully-qualified metadata resource name. For example + projects/{project}/locations/{location}/metadataStores/{metadata_store_id}/{resource_noun}/{resource_id}. + resource_noun (str): + Required. The resource_noun portion of the resource_name + Returns: + metadata_store_id (str): + The metadata store id for the particular resource name. + Raises: + ValueError if it does not exist. + """ + pattern = re.compile( + r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/metadataStores\/(?P[\w-]+)\/" + + resource_noun + + r"\/(?P[\w-]+)$" + ) + match = pattern.match(resource_name) + if not match: + raise ValueError( + f"failed to extract metadata_store_id from resource {resource_name}" + ) + return match["store"] diff --git a/google/cloud/aiplatform/schema.py b/google/cloud/aiplatform/schema.py index 04d2f026a1..6b2a3d7d66 100644 --- a/google/cloud/aiplatform/schema.py +++ b/google/cloud/aiplatform/schema.py @@ -22,6 +22,7 @@ class training_job: class definition: custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml" automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml" + automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml" automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml" automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml" automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml" @@ -37,6 +38,7 @@ class metadata: tabular = ( "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml" ) + time_series = "gs://google-cloud-aiplatform/schema/dataset/metadata/time_series_1.0.0.yaml" image = "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml" text = "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml" video = "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml" diff --git a/google/cloud/aiplatform/tensorboard/__init__.py b/google/cloud/aiplatform/tensorboard/__init__.py new file mode 100644 index 0000000000..a6fbe4122f --- /dev/null +++ b/google/cloud/aiplatform/tensorboard/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may 
not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py new file mode 100644 index 0000000000..57dcbedf60 --- /dev/null +++ b/google/cloud/aiplatform/tensorboard/uploader.py @@ -0,0 +1,1442 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Uploads a TensorBoard logdir to TensorBoard.gcp.""" +import contextlib +import functools +import json +import os +import time +import re +from typing import Callable, Dict, FrozenSet, Generator, Iterable, Optional, Tuple +import uuid + +import grpc +from tensorboard.backend import process_graph +from tensorboard.backend.event_processing.plugin_event_accumulator import ( + directory_loader, +) +from tensorboard.backend.event_processing.plugin_event_accumulator import ( + event_file_loader, +) +from tensorboard.backend.event_processing.plugin_event_accumulator import io_wrapper +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import summary_pb2 +from tensorboard.compat.proto import types_pb2 +from tensorboard.plugins.graph import metadata as graph_metadata +from tensorboard.uploader import logdir_loader +from tensorboard.uploader import upload_tracker +from tensorboard.uploader import util +from tensorboard.uploader.proto import server_info_pb2 +from tensorboard.util import tb_logging +from tensorboard.util import tensor_util +import tensorflow as tf + +from google.api_core import exceptions +from google.cloud import storage +from google.cloud.aiplatform.compat.services import tensorboard_service_client_v1beta1 +from google.cloud.aiplatform.compat.types import ( + tensorboard_data_v1beta1 as tensorboard_data, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_experiment_v1beta1 as tensorboard_experiment, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_run_v1beta1 as tensorboard_run, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_service_v1beta1 as tensorboard_service, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_time_series_v1beta1 as tensorboard_time_series, +) +from google.protobuf import message +from google.protobuf import timestamp_pb2 as timestamp + +TensorboardServiceClient = tensorboard_service_client_v1beta1.TensorboardServiceClient + +# Minimum length of a logdir polling cycle in seconds. Shorter cycles will +# sleep to avoid spinning over the logdir, which isn't great for disks and can +# be expensive for network file systems. 
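# Note on the defaults below: they are applied only when no server_info_pb2.UploadLimits
# proto is supplied to the uploader, and the *_REQUEST_INTERVAL values are expressed in
# milliseconds, so they are divided by 1000 when building util.RateLimiter objects.
# Minimal sketch of that conversion:
#
#   limits = server_info_pb2.UploadLimits()
#   limits.min_scalar_request_interval = 10                                # milliseconds
#   limiter = util.RateLimiter(limits.min_scalar_request_interval / 1000)  # 0.01 s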
+_MIN_LOGDIR_POLL_INTERVAL_SECS = 1 + +# Maximum length of a base-128 varint as used to encode a 64-bit value +# (without the "msb of last byte is bit 63" optimization, to be +# compatible with protobuf and golang varints). +_MAX_VARINT64_LENGTH_BYTES = 10 + +# Default minimum interval between initiating WriteTensorbordRunData RPCs in +# milliseconds. +_DEFAULT_MIN_SCALAR_REQUEST_INTERVAL = 10 + +# Default maximum WriteTensorbordRunData request size in bytes. +_DEFAULT_MAX_SCALAR_REQUEST_SIZE = 24 * (2 ** 10) # 24KiB + +# Default minimum interval between initiating WriteTensorbordRunData RPCs in +# milliseconds. +_DEFAULT_MIN_TENSOR_REQUEST_INTERVAL = 10 + +# Default minimum interval between initiating WriteTensorbordRunData RPCs in +# milliseconds. +_DEFAULT_MIN_BLOB_REQUEST_INTERVAL = 10 + +# Default maximum WriteTensorbordRunData request size in bytes. +_DEFAULT_MAX_TENSOR_REQUEST_SIZE = 512 * (2 ** 10) # 512KiB + +_DEFAULT_MAX_BLOB_REQUEST_SIZE = 4 * (2 ** 20) - 256 * (2 ** 10) # 4MiB-256KiB + +# Default maximum tensor point size in bytes. +_DEFAULT_MAX_TENSOR_POINT_SIZE = 16 * (2 ** 10) # 16KiB + +_DEFAULT_MAX_BLOB_SIZE = 10 * (2 ** 30) # 10GiB + +logger = tb_logging.get_logger() + + +class TensorBoardUploader(object): + """Uploads a TensorBoard logdir to TensorBoard.gcp.""" + + def __init__( + self, + experiment_name: str, + tensorboard_resource_name: str, + blob_storage_bucket: storage.Bucket, + blob_storage_folder: str, + writer_client: TensorboardServiceClient, + logdir: str, + allowed_plugins: FrozenSet[str], + experiment_display_name: Optional[str] = None, + upload_limits: Optional[server_info_pb2.UploadLimits] = None, + logdir_poll_rate_limiter: Optional[util.RateLimiter] = None, + rpc_rate_limiter: Optional[util.RateLimiter] = None, + tensor_rpc_rate_limiter: Optional[util.RateLimiter] = None, + blob_rpc_rate_limiter: Optional[util.RateLimiter] = None, + description: Optional[str] = None, + verbosity: int = 1, + one_shot: bool = False, + event_file_inactive_secs: Optional[int] = None, + run_name_prefix=None, + ): + """Constructs a TensorBoardUploader. + + Args: + experiment_name: Name of this experiment. Unique to the given + tensorboard_resource_name. + tensorboard_resource_name: Name of the Tensorboard resource with this + format + projects/{project}/locations/{location}/tensorboards/{tensorboard} + writer_client: a TensorBoardWriterService stub instance + logdir: path of the log directory to upload + experiment_display_name: The display name of the experiment. + allowed_plugins: collection of string plugin names; events will only be + uploaded if their time series's metadata specifies one of these plugin + names + upload_limits: instance of tensorboard.service.UploadLimits proto. + logdir_poll_rate_limiter: a `RateLimiter` to use to limit logdir polling + frequency, to avoid thrashing disks, especially on networked file + systems + rpc_rate_limiter: a `RateLimiter` to use to limit write RPC frequency. + Note this limit applies at the level of single RPCs in the Scalar and + Tensor case, but at the level of an entire blob upload in the Blob + case-- which may require a few preparatory RPCs and a stream of chunks. + Note the chunk stream is internally rate-limited by backpressure from + the server, so it is not a concern that we do not explicitly rate-limit + within the stream here. + description: String description to assign to the experiment. + verbosity: Level of verbosity, an integer. Supported value: 0 - No upload + statistics is printed. 
1 - Print upload statistics while uploading data + (default). + one_shot: Once uploading starts, upload only the existing data in the + logdir and then return immediately, instead of the default behavior of + continuing to listen for new data in the logdir and upload them when it + appears. + event_file_inactive_secs: Age in seconds of last write after which an + event file is considered inactive. If none then event file is never + considered inactive. + run_name_prefix: If present, all runs created by this invocation will have + their name prefixed by this value. + """ + self._experiment_name = experiment_name + self._experiment_display_name = experiment_display_name + self._tensorboard_resource_name = tensorboard_resource_name + self._blob_storage_bucket = blob_storage_bucket + self._blob_storage_folder = blob_storage_folder + self._api = writer_client + self._logdir = logdir + self._allowed_plugins = frozenset(allowed_plugins) + self._run_name_prefix = run_name_prefix + + self._upload_limits = upload_limits + if not self._upload_limits: + self._upload_limits = server_info_pb2.UploadLimits() + self._upload_limits.max_scalar_request_size = ( + _DEFAULT_MAX_SCALAR_REQUEST_SIZE + ) + self._upload_limits.min_scalar_request_interval = ( + _DEFAULT_MIN_SCALAR_REQUEST_INTERVAL + ) + self._upload_limits.min_tensor_request_interval = ( + _DEFAULT_MIN_TENSOR_REQUEST_INTERVAL + ) + self._upload_limits.max_tensor_request_size = ( + _DEFAULT_MAX_TENSOR_REQUEST_SIZE + ) + self._upload_limits.max_tensor_point_size = _DEFAULT_MAX_TENSOR_POINT_SIZE + self._upload_limits.min_blob_request_interval = ( + _DEFAULT_MIN_BLOB_REQUEST_INTERVAL + ) + self._upload_limits.max_blob_request_size = _DEFAULT_MAX_BLOB_REQUEST_SIZE + self._upload_limits.max_blob_size = _DEFAULT_MAX_BLOB_SIZE + + self._description = description + self._verbosity = verbosity + self._one_shot = one_shot + self._request_sender = None + if logdir_poll_rate_limiter is None: + self._logdir_poll_rate_limiter = util.RateLimiter( + _MIN_LOGDIR_POLL_INTERVAL_SECS + ) + else: + self._logdir_poll_rate_limiter = logdir_poll_rate_limiter + + if rpc_rate_limiter is None: + self._rpc_rate_limiter = util.RateLimiter( + self._upload_limits.min_scalar_request_interval / 1000 + ) + else: + self._rpc_rate_limiter = rpc_rate_limiter + + if tensor_rpc_rate_limiter is None: + self._tensor_rpc_rate_limiter = util.RateLimiter( + self._upload_limits.min_tensor_request_interval / 1000 + ) + else: + self._tensor_rpc_rate_limiter = tensor_rpc_rate_limiter + + if blob_rpc_rate_limiter is None: + self._blob_rpc_rate_limiter = util.RateLimiter( + self._upload_limits.min_blob_request_interval / 1000 + ) + else: + self._blob_rpc_rate_limiter = blob_rpc_rate_limiter + + def active_filter(secs): + return ( + not bool(event_file_inactive_secs) + or secs + event_file_inactive_secs >= time.time() + ) + + directory_loader_factory = functools.partial( + directory_loader.DirectoryLoader, + loader_factory=event_file_loader.TimestampedEventFileLoader, + path_filter=io_wrapper.IsTensorFlowEventsFile, + active_filter=active_filter, + ) + self._logdir_loader = logdir_loader.LogdirLoader( + self._logdir, directory_loader_factory + ) + self._tracker = upload_tracker.UploadTracker(verbosity=self._verbosity) + + def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperiment: + """Create an experiment or get an experiment. + + Attempts to create an experiment. If the experiment already exists and + creation fails then the experiment will be retrieved. 
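# Typical driver flow, as a minimal sketch (mirrors uploader_main.py later in this change;
# the keyword values are placeholders):
#
#   tb_uploader = TensorBoardUploader(experiment_name=..., tensorboard_resource_name=..., ...)
#   tb_uploader.create_experiment()    # resolves or creates the experiment resource
#   tb_uploader.start_uploading()      # blocks; one-shot mode calls _upload_once() instead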
+ + Returns: + The created or retrieved experiment. + """ + logger.info("Creating experiment") + + tb_experiment = tensorboard_experiment.TensorboardExperiment( + description=self._description, display_name=self._experiment_display_name + ) + + try: + experiment = self._api.create_tensorboard_experiment( + parent=self._tensorboard_resource_name, + tensorboard_experiment=tb_experiment, + tensorboard_experiment_id=self._experiment_name, + ) + except exceptions.AlreadyExists: + logger.info("Creating experiment failed. Retrieving experiment.") + experiment_name = os.path.join( + self._tensorboard_resource_name, "experiments", self._experiment_name + ) + experiment = self._api.get_tensorboard_experiment(name=experiment_name) + return experiment + + def create_experiment(self): + """Creates an Experiment for this upload session and returns the ID.""" + + experiment = self._create_or_get_experiment() + self._experiment = experiment + self._request_sender = _BatchedRequestSender( + self._experiment.name, + self._api, + allowed_plugins=self._allowed_plugins, + upload_limits=self._upload_limits, + rpc_rate_limiter=self._rpc_rate_limiter, + tensor_rpc_rate_limiter=self._tensor_rpc_rate_limiter, + blob_rpc_rate_limiter=self._blob_rpc_rate_limiter, + blob_storage_bucket=self._blob_storage_bucket, + blob_storage_folder=self._blob_storage_folder, + tracker=self._tracker, + ) + + def get_experiment_resource_name(self): + return self._experiment.name + + def start_uploading(self): + """Blocks forever to continuously upload data from the logdir. + + Raises: + RuntimeError: If `create_experiment` has not yet been called. + ExperimentNotFoundError: If the experiment is deleted during the + course of the upload. + """ + if self._request_sender is None: + raise RuntimeError("Must call create_experiment() before start_uploading()") + while True: + self._logdir_poll_rate_limiter.tick() + self._upload_once() + if self._one_shot: + break + if self._one_shot and not self._tracker.has_data(): + logger.warning( + "One-shot mode was used on a logdir (%s) " + "without any uploadable data" % self._logdir + ) + + def _upload_once(self): + """Runs one upload cycle, sending zero or more RPCs.""" + logger.info("Starting an upload cycle") + + sync_start_time = time.time() + self._logdir_loader.synchronize_runs() + sync_duration_secs = time.time() - sync_start_time + logger.info("Logdir sync took %.3f seconds", sync_duration_secs) + + run_to_events = self._logdir_loader.get_run_events() + if self._run_name_prefix: + run_to_events = { + self._run_name_prefix + k: v for k, v in run_to_events.items() + } + with self._tracker.send_tracker(): + self._request_sender.send_requests(run_to_events) + + +class ExperimentNotFoundError(RuntimeError): + pass + + +class PermissionDeniedError(RuntimeError): + pass + + +class ExistingResourceNotFoundError(RuntimeError): + """Resource could not be created or retrieved.""" + + +class _OutOfSpaceError(Exception): + """Action could not proceed without overflowing request budget. + + This is a signaling exception (like `StopIteration`) used internally + by `_*RequestSender`; it does not mean that anything has gone wrong. + """ + + pass + + +class _BatchedRequestSender(object): + """Helper class for building requests that fit under a size limit. + + This class maintains stateful request builders for each of the possible + request types (scalars, tensors, and blobs). These accumulate batches + independently, each maintaining its own byte budget and emitting a request + when the batch becomes full. 
As a consequence, events of different types + will likely be sent to the backend out of order. E.g., in the extreme case, + a single tensor-flavored request may be sent only when the event stream is + exhausted, even though many more recent scalar events were sent earlier. + + This class is not threadsafe. Use external synchronization if + calling its methods concurrently. + """ + + def __init__( + self, + experiment_resource_name: str, + api: TensorboardServiceClient, + allowed_plugins: Iterable[str], + upload_limits: server_info_pb2.UploadLimits, + rpc_rate_limiter: util.RateLimiter, + tensor_rpc_rate_limiter: util.RateLimiter, + blob_rpc_rate_limiter: util.RateLimiter, + blob_storage_bucket: storage.Bucket, + blob_storage_folder: str, + tracker: upload_tracker.UploadTracker, + ): + """Constructs _BatchedRequestSender for the given experiment resource. + + Args: + experiment_resource_name: Name of the experiment resource of the form + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment} + api: Tensorboard service stub used to interact with experiment resource. + allowed_plugins: The plugins supported by the Tensorboard.gcp resource. + upload_limits: Upload limits for for api calls. + rpc_rate_limiter: a `RateLimiter` to use to limit write RPC frequency. + Note this limit applies at the level of single RPCs in the Scalar and + Tensor case, but at the level of an entire blob upload in the Blob + case-- which may require a few preparatory RPCs and a stream of chunks. + Note the chunk stream is internally rate-limited by backpressure from + the server, so it is not a concern that we do not explicitly rate-limit + within the stream here. + tracker: Upload tracker to track information about uploads. + """ + self._experiment_resource_name = experiment_resource_name + self._api = api + self._tag_metadata = {} + self._allowed_plugins = frozenset(allowed_plugins) + self._tracker = tracker + self._run_to_request_sender: Dict[str, _ScalarBatchedRequestSender] = {} + self._run_to_tensor_request_sender: Dict[str, _TensorBatchedRequestSender] = {} + self._run_to_blob_request_sender: Dict[str, _BlobRequestSender] = {} + self._run_to_run_resource: Dict[str, tensorboard_run.TensorboardRun] = {} + self._scalar_request_sender_factory = functools.partial( + _ScalarBatchedRequestSender, + api=api, + rpc_rate_limiter=rpc_rate_limiter, + max_request_size=upload_limits.max_scalar_request_size, + tracker=self._tracker, + ) + self._tensor_request_sender_factory = functools.partial( + _TensorBatchedRequestSender, + api=api, + rpc_rate_limiter=tensor_rpc_rate_limiter, + max_request_size=upload_limits.max_tensor_request_size, + max_tensor_point_size=upload_limits.max_tensor_point_size, + tracker=self._tracker, + ) + self._blob_request_sender_factory = functools.partial( + _BlobRequestSender, + api=api, + rpc_rate_limiter=blob_rpc_rate_limiter, + max_blob_request_size=upload_limits.max_blob_request_size, + max_blob_size=upload_limits.max_blob_size, + blob_storage_bucket=blob_storage_bucket, + blob_storage_folder=blob_storage_folder, + tracker=self._tracker, + ) + + def send_requests( + self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]] + ): + """Accepts a stream of TF events and sends batched write RPCs. + + Each sent request will be batched, the size of each batch depending on + the type of data (Scalar vs Tensor vs Blob) being sent. 
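# Descriptive note on the routing performed below: each summary value is dispatched by its
# metadata.data_class to a per-run sender, roughly:
#
#   DATA_CLASS_SCALAR        -> _ScalarBatchedRequestSender.add_event(...)
#   DATA_CLASS_TENSOR        -> _TensorBatchedRequestSender.add_event(...)
#   DATA_CLASS_BLOB_SEQUENCE -> _BlobRequestSender.add_event(...)
#
# and every sender is flushed once the event stream has been consumed.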
+ + Args: + run_to_events: Mapping from run name to generator of `tf.compat.v1.Event` + values, as returned by `LogdirLoader.get_run_events`. + + Raises: + RuntimeError: If no progress can be made because even a single + point is too large (say, due to a gigabyte-long tag name). + """ + + for (run_name, event, value) in self._run_values(run_to_events): + time_series_key = (run_name, value.tag) + + # The metadata for a time series is memorized on the first event. + # If later events arrive with a mismatching plugin_name, they are + # ignored with a warning. + metadata = self._tag_metadata.get(time_series_key) + first_in_time_series = False + if metadata is None: + first_in_time_series = True + metadata = value.metadata + self._tag_metadata[time_series_key] = metadata + + plugin_name = metadata.plugin_data.plugin_name + if value.HasField("metadata") and ( + plugin_name != value.metadata.plugin_data.plugin_name + ): + logger.warning( + "Mismatching plugin names for %s. Expected %s, found %s.", + time_series_key, + metadata.plugin_data.plugin_name, + value.metadata.plugin_data.plugin_name, + ) + continue + if plugin_name not in self._allowed_plugins: + if first_in_time_series: + logger.info( + "Skipping time series %r with unsupported plugin name %r", + time_series_key, + plugin_name, + ) + continue + self._tracker.add_plugin_name(plugin_name) + # If this is the first time we've seen this run create a new run resource + # and an associated request sender. + if run_name not in self._run_to_run_resource: + self._create_or_get_run_resource(run_name) + self._run_to_request_sender[ + run_name + ] = self._scalar_request_sender_factory( + self._run_to_run_resource[run_name].name + ) + self._run_to_tensor_request_sender[ + run_name + ] = self._tensor_request_sender_factory( + self._run_to_run_resource[run_name].name + ) + self._run_to_blob_request_sender[ + run_name + ] = self._blob_request_sender_factory( + self._run_to_run_resource[run_name].name + ) + + if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR: + self._run_to_request_sender[run_name].add_event(event, value, metadata) + elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR: + self._run_to_tensor_request_sender[run_name].add_event( + event, value, metadata + ) + elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE: + self._run_to_blob_request_sender[run_name].add_event( + event, value, metadata + ) + + for scalar_request_sender in self._run_to_request_sender.values(): + scalar_request_sender.flush() + + for tensor_request_sender in self._run_to_tensor_request_sender.values(): + tensor_request_sender.flush() + + for blob_request_sender in self._run_to_blob_request_sender.values(): + blob_request_sender.flush() + + def _create_or_get_run_resource(self, run_name: str): + """Creates a new Run Resource in current Tensorboard Experiment resource. + + Args: + run_name: The display name of this run. 
+ """ + tb_run = tensorboard_run.TensorboardRun() + tb_run.display_name = run_name + try: + tb_run = self._api.create_tensorboard_run( + parent=self._experiment_resource_name, + tensorboard_run=tb_run, + tensorboard_run_id=str(uuid.uuid4()), + ) + except exceptions.InvalidArgument as e: + # If the run name already exists then retrieve it + if "already exist" in e.message: + runs_pages = self._api.list_tensorboard_runs( + parent=self._experiment_resource_name + ) + for tb_run in runs_pages: + if tb_run.display_name == run_name: + break + + if tb_run.display_name != run_name: + raise ExistingResourceNotFoundError( + "Run with name %s already exists but is not resource list." + % run_name + ) + else: + raise + + self._run_to_run_resource[run_name] = tb_run + + def _run_values( + self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]] + ) -> Generator[ + Tuple[str, tf.compat.v1.Event, tf.compat.v1.Summary.Value], None, None + ]: + """Helper generator to create a single stream of work items. + + Note that `dataclass_compat` may emit multiple variants of + the same event, for backwards compatibility. Thus this stream should + be filtered to obtain the desired version of each event. Here, we + ignore any event that does not have a `summary` field. + + Furthermore, the events emitted here could contain values that do not + have `metadata.data_class` set; these too should be ignored. In + `_send_summary_value(...)` above, we switch on `metadata.data_class` + and drop any values with an unknown (i.e., absent or unrecognized) + `data_class`. + + Args: + run_to_events: Mapping from run name to generator of `tf.compat.v1.Event` + values, as returned by `LogdirLoader.get_run_events`. + + Yields: + Tuple of run name, tf.compat.v1.Event, tf.compat.v1.Summary.Value per + value. + """ + # Note that this join in principle has deletion anomalies: if the input + # stream contains runs with no events, or events with no values, we'll + # lose that information. This is not a problem: we would need to prune + # such data from the request anyway. + for (run_name, events) in run_to_events.items(): + for event in events: + _filter_graph_defs(event) + for value in event.summary.value: + yield (run_name, event, value) + + +class _TimeSeriesResourceManager(object): + """Helper class managing Time Series resources.""" + + def __init__(self, run_resource_id: str, api: TensorboardServiceClient): + """Constructor for _TimeSeriesResourceManager. + + Args: + run_resource_id: The resource id for the run with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + api: TensorboardServiceStub + """ + self._run_resource_id = run_resource_id + self._api = api + self._tag_to_time_series_proto: Dict[ + str, tensorboard_time_series.TensorboardTimeSeries + ] = {} + + def get_or_create( + self, + tag_name: str, + time_series_resource_creator: Callable[ + [], tensorboard_time_series.TensorboardTimeSeries + ], + ) -> tensorboard_time_series.TensorboardTimeSeries: + """get a time series resource with given tag_name, and create a new one on + + OnePlatform if not present. + + Args: + tag_name: The tag name of the time series in the Tensorboard log dir. + time_series_resource_creator: A callable that produces a TimeSeries for + creation. 
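# Usage sketch of this method, modeled on the request senders below (the tag name and the
# omitted plugin fields are placeholders); the creator callable only runs on a cache miss:
#
#   time_series = manager.get_or_create(
#       "loss",
#       lambda: tensorboard_time_series.TensorboardTimeSeries(
#           display_name="loss",
#           value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR,
#       ),
#   )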
+ """ + if tag_name in self._tag_to_time_series_proto: + return self._tag_to_time_series_proto[tag_name] + + time_series = time_series_resource_creator() + time_series.display_name = tag_name + try: + time_series = self._api.create_tensorboard_time_series( + parent=self._run_resource_id, tensorboard_time_series=time_series + ) + except exceptions.InvalidArgument as e: + # If the time series display name already exists then retrieve it + if "already exist" in e.message: + list_of_time_series = self._api.list_tensorboard_time_series( + request=tensorboard_service.ListTensorboardTimeSeriesRequest( + parent=self._run_resource_id, + filter="display_name = {}".format(json.dumps(str(tag_name))), + ) + ) + num = 0 + for ts in list_of_time_series: + time_series = ts + num += 1 + break + if num != 1: + raise ValueError( + "More than one time series resource found with display_name: {}".format( + tag_name + ) + ) + else: + raise + + self._tag_to_time_series_proto[tag_name] = time_series + return time_series + + +class _ScalarBatchedRequestSender(object): + """Helper class for building requests that fit under a size limit. + + This class accumulates a current request. `add_event(...)` may or may not + send the request (and start a new one). After all `add_event(...)` calls + are complete, a final call to `flush()` is needed to send the final request. + + This class is not threadsafe. Use external synchronization if calling its + methods concurrently. + """ + + def __init__( + self, + run_resource_id: str, + api: TensorboardServiceClient, + rpc_rate_limiter: util.RateLimiter, + max_request_size: int, + tracker: upload_tracker.UploadTracker, + ): + """Constructer for _ScalarBatchedRequestSender. + + Args: + run_resource_id: The resource id for the run with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + api: TensorboardServiceStub + rpc_rate_limiter: until.RateLimiter to limit rate of this request sender + max_request_size: max number of bytes to send + tracker: + """ + self._run_resource_id = run_resource_id + self._api = api + self._rpc_rate_limiter = rpc_rate_limiter + self._byte_budget_manager = _ByteBudgetManager(max_request_size) + self._tracker = tracker + + # cache: map from Tensorboard tag to TimeSeriesData + # cleared whenever a new request is created + self._tag_to_time_series_data: Dict[str, tensorboard_data.TimeSeriesData] = {} + + self._time_series_resource_manager = _TimeSeriesResourceManager( + self._run_resource_id, self._api + ) + self._new_request() + + def _new_request(self): + """Allocates a new request and refreshes the budget.""" + self._request = tensorboard_service.WriteTensorboardRunDataRequest() + self._tag_to_time_series_data.clear() + self._num_values = 0 + self._request.tensorboard_run = self._run_resource_id + self._byte_budget_manager.reset(self._request) + + def add_event( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + """Attempts to add the given event to the current request. + + If the event cannot be added to the current request because the byte + budget is exhausted, the request is flushed, and the event is added + to the next request. + + Args: + event: The tf.compat.v1.Event event containing the value. + value: A scalar tf.compat.v1.Summary.Value. + metadata: SummaryMetadata of the event. + """ + try: + self._add_event_internal(event, value, metadata) + except _OutOfSpaceError: + self.flush() + # Try again. 
This attempt should never produce OutOfSpaceError + # because we just flushed. + try: + self._add_event_internal(event, value, metadata) + except _OutOfSpaceError: + raise RuntimeError("add_event failed despite flush") + + def _add_event_internal( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + self._num_values += 1 + time_series_data_proto = self._tag_to_time_series_data.get(value.tag) + if time_series_data_proto is None: + time_series_data_proto = self._create_time_series_data(value.tag, metadata) + self._create_point(time_series_data_proto, event, value) + + def flush(self): + """Sends the active request after removing empty runs and tags. + + Starts a new, empty active request. + """ + request = self._request + request.time_series_data = list(self._tag_to_time_series_data.values()) + _prune_empty_time_series(request) + if not request.time_series_data: + return + + self._rpc_rate_limiter.tick() + + with _request_logger(request): + with self._tracker.scalars_tracker(self._num_values): + try: + self._api.write_tensorboard_run_data( + tensorboard_run=self._run_resource_id, + time_series_data=request.time_series_data, + ) + except grpc.RpcError as e: + if ( + hasattr(e, "code") + and getattr(e, "code")() == grpc.StatusCode.NOT_FOUND + ): + raise ExperimentNotFoundError() + logger.error("Upload call failed with error %s", e) + + self._new_request() + + def _create_time_series_data( + self, tag_name: str, metadata: tf.compat.v1.SummaryMetadata + ) -> tensorboard_data.TimeSeriesData: + """Adds a time_series for the tag_name, if there's space. + + Args: + tag_name: String name of the tag to add (as `value.tag`). + + Returns: + The TimeSeriesData in _request proto with the given tag name. + + Raises: + _OutOfSpaceError: If adding the tag would exceed the remaining + request budget. + """ + time_series_data_proto = tensorboard_data.TimeSeriesData( + tensorboard_time_series_id=self._time_series_resource_manager.get_or_create( + tag_name, + lambda: tensorboard_time_series.TensorboardTimeSeries( + display_name=tag_name, + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + plugin_name=metadata.plugin_data.plugin_name, + plugin_data=metadata.plugin_data.content, + ), + ).name.split("/")[-1], + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR, + ) + + self._request.time_series_data.extend([time_series_data_proto]) + self._byte_budget_manager.add_time_series(time_series_data_proto) + self._tag_to_time_series_data[tag_name] = time_series_data_proto + return time_series_data_proto + + def _create_point( + self, + time_series_proto: tensorboard_data.TimeSeriesData, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + ): + """Adds a scalar point to the given tag, if there's space. + + Args: + time_series_proto: TimeSeriesData proto to which to add a point. + event: Enclosing `Event` proto with the step and wall time data. + value: Scalar `Summary.Value` proto with the actual scalar data. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining + request budget. 
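# Worked example of the wall-time conversion performed below: an event wall_time of
# 1618886400.25 seconds becomes
#
#   timestamp.Timestamp(seconds=1618886400, nanos=250000000)
#
# the integral part feeds `seconds`, and the fractional part, scaled by 10 ** 9 and
# rounded, feeds `nanos`.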
+ """ + scalar_proto = tensorboard_data.Scalar( + value=tensor_util.make_ndarray(value.tensor).item() + ) + point = tensorboard_data.TimeSeriesDataPoint( + step=event.step, + scalar=scalar_proto, + wall_time=timestamp.Timestamp( + seconds=int(event.wall_time), + nanos=int(round((event.wall_time % 1) * 10 ** 9)), + ), + ) + time_series_proto.values.extend([point]) + try: + self._byte_budget_manager.add_point(point) + except _OutOfSpaceError: + time_series_proto.values.pop() + raise + + +class _TensorBatchedRequestSender(object): + """Helper class for building WriteTensor() requests that fit under a size limit. + + This class accumulates a current request. `add_event(...)` may or may not + send the request (and start a new one). After all `add_event(...)` calls + are complete, a final call to `flush()` is needed to send the final request. + This class is not threadsafe. Use external synchronization if calling its + methods concurrently. + """ + + def __init__( + self, + run_resource_id: str, + api: TensorboardServiceClient, + rpc_rate_limiter: util.RateLimiter, + max_request_size: int, + max_tensor_point_size: int, + tracker: upload_tracker.UploadTracker, + ): + """Constructer for _TensorBatchedRequestSender. + + Args: + run_resource_id: The resource id for the run with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + api: TensorboardServiceStub + rpc_rate_limiter: until.RateLimiter to limit rate of this request sender + max_request_size: max number of bytes to send + tracker: + """ + self._run_resource_id = run_resource_id + self._api = api + self._rpc_rate_limiter = rpc_rate_limiter + self._byte_budget_manager = _ByteBudgetManager(max_request_size) + self._max_tensor_point_size = max_tensor_point_size + self._tracker = tracker + + # cache: map from Tensorboard tag to TimeSeriesData + # cleared whenever a new request is created + self._tag_to_time_series_data: Dict[str, tensorboard_data.TimeSeriesData] = {} + + self._time_series_resource_manager = _TimeSeriesResourceManager( + run_resource_id, api + ) + self._new_request() + + def _new_request(self): + """Allocates a new request and refreshes the budget.""" + self._request = tensorboard_service.WriteTensorboardRunDataRequest() + self._tag_to_time_series_data.clear() + self._num_values = 0 + self._request.tensorboard_run = self._run_resource_id + self._byte_budget_manager.reset(self._request) + self._num_values = 0 + self._num_values_skipped = 0 + self._tensor_bytes = 0 + self._tensor_bytes_skipped = 0 + + def add_event( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + """Attempts to add the given event to the current request. + + If the event cannot be added to the current request because the byte + budget is exhausted, the request is flushed, and the event is added + to the next request. + """ + try: + self._add_event_internal(event, value, metadata) + except _OutOfSpaceError: + self.flush() + # Try again. This attempt should never produce OutOfSpaceError + # because we just flushed. 
+ try: + self._add_event_internal(event, value, metadata) + except _OutOfSpaceError: + raise RuntimeError("add_event failed despite flush") + + def _add_event_internal( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + self._num_values += 1 + time_series_data_proto = self._tag_to_time_series_data.get(value.tag) + if time_series_data_proto is None: + time_series_data_proto = self._create_time_series_data(value.tag, metadata) + self._create_point(time_series_data_proto, event, value) + + def flush(self): + """Sends the active request after removing empty runs and tags. + + Starts a new, empty active request. + """ + request = self._request + request.time_series_data = list(self._tag_to_time_series_data.values()) + _prune_empty_time_series(request) + if not request.time_series_data: + return + + self._rpc_rate_limiter.tick() + + with _request_logger(request): + with self._tracker.tensors_tracker( + self._num_values, + self._num_values_skipped, + self._tensor_bytes, + self._tensor_bytes_skipped, + ): + try: + self._api.write_tensorboard_run_data( + tensorboard_run=self._run_resource_id, + time_series_data=request.time_series_data, + ) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + logger.error("Upload call failed with error %s", e) + + self._new_request() + + def _create_time_series_data( + self, tag_name: str, metadata: tf.compat.v1.SummaryMetadata + ) -> tensorboard_data.TimeSeriesData: + """Adds a time_series for the tag_name, if there's space. + + Args: + tag_name: String name of the tag to add (as `value.tag`). + metadata: SummaryMetadata of the event. + + Returns: + The TimeSeriesData in _request proto with the given tag name. + + Raises: + _OutOfSpaceError: If adding the tag would exceed the remaining + request budget. + """ + time_series_data_proto = tensorboard_data.TimeSeriesData( + tensorboard_time_series_id=self._time_series_resource_manager.get_or_create( + tag_name, + lambda: tensorboard_time_series.TensorboardTimeSeries( + display_name=tag_name, + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR, + plugin_name=metadata.plugin_data.plugin_name, + plugin_data=metadata.plugin_data.content, + ), + ).name.split("/")[-1], + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR, + ) + + self._request.time_series_data.extend([time_series_data_proto]) + self._byte_budget_manager.add_time_series(time_series_data_proto) + self._tag_to_time_series_data[tag_name] = time_series_data_proto + return time_series_data_proto + + def _create_point( + self, + time_series_proto: tensorboard_data.TimeSeriesData, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + ): + """Adds a tensor point to the given tag, if there's space. + + Args: + tag_proto: `WriteTensorRequest.Tag` proto to which to add a point. + event: Enclosing `Event` proto with the step and wall time data. + value: Tensor `Summary.Value` proto with the actual tensor data. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining + request budget. 
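# Descriptive note on the size guard below: a serialized tensor point larger than
# self._max_tensor_point_size (16 KiB with the defaults) is never added to the request;
# it is only logged and counted, roughly:
#
#   if tensor_size > self._max_tensor_point_size:
#       self._num_values_skipped += 1
#       self._tensor_bytes_skipped += tensor_size
#       return    # the point is dropped, the upload continues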
+ """ + point = tensorboard_data.TimeSeriesDataPoint( + step=event.step, + tensor=tensorboard_data.TensorboardTensor( + value=value.tensor.SerializeToString() + ), + wall_time=timestamp.Timestamp( + seconds=int(event.wall_time), + nanos=int(round((event.wall_time % 1) * 10 ** 9)), + ), + ) + + self._num_values += 1 + tensor_size = len(point.tensor.value) + self._tensor_bytes += tensor_size + if tensor_size > self._max_tensor_point_size: + logger.warning( + "Tensor too large; skipping. " "Size %d exceeds limit of %d bytes.", + tensor_size, + self._max_tensor_point_size, + ) + self._num_values_skipped += 1 + self._tensor_bytes_skipped += tensor_size + return + + self._validate_tensor_value( + value.tensor, value.tag, event.step, event.wall_time + ) + + time_series_proto.values.extend([point]) + + try: + self._byte_budget_manager.add_point(point) + except _OutOfSpaceError: + time_series_proto.values.pop() + raise + + def _validate_tensor_value(self, tensor_proto, tag, step, wall_time): + """Validate a TensorProto by attempting to parse it.""" + try: + tensor_util.make_ndarray(tensor_proto) + except ValueError as error: + raise ValueError( + "The uploader failed to upload a tensor. This seems to be " + "due to a malformation in the tensor, which may be caused by " + "a bug in the process that wrote the tensor.\n\n" + "The tensor has tag '%s' and is at step %d and wall_time %.6f.\n\n" + "Original error:\n%s" % (tag, step, wall_time, error) + ) + + +class _ByteBudgetManager(object): + """Helper class for managing the request byte budget for certain RPCs. + + This should be used for RPCs that organize data by Runs, Tags, and Points, + specifically WriteScalar and WriteTensor. + + Any call to add_time_series() or add_point() may raise an + _OutOfSpaceError, which is non-fatal. It signals to the caller that they + should flush the current request and begin a new one. + + For more information on the protocol buffer encoding and how byte cost + can be calculated, visit: + + https://developers.google.com/protocol-buffers/docs/encoding + """ + + def __init__(self, max_bytes: int): + # The remaining number of bytes that we may yet add to the request. + self._byte_budget = None # type: int + self._max_bytes = max_bytes + + def reset(self, base_request: tensorboard_service.WriteTensorboardRunDataRequest): + """Resets the byte budget and calculates the cost of the base request. + + Args: + base_request: Base request. + + Raises: + _OutOfSpaceError: If the size of the request exceeds the entire + request byte budget. + """ + self._byte_budget = self._max_bytes + self._byte_budget -= ( + base_request._pb.ByteSize() + ) # pylint: disable=protected-access + if self._byte_budget < 0: + raise _OutOfSpaceError("Byte budget too small for base request") + + def add_time_series(self, time_series_proto: tensorboard_data.TimeSeriesData): + """Integrates the cost of a tag proto into the byte budget. + + Args: + time_series_proto: The proto representing a time series. + + Raises: + _OutOfSpaceError: If adding the time_series would exceed the remaining + request budget. + """ + cost = ( + # The size of the tag proto without any tag fields set. + time_series_proto._pb.ByteSize() # pylint: disable=protected-access + # The size of the varint that describes the length of the tag + # proto. We can't yet know the final size of the tag proto -- we + # haven't yet set any point values -- so we can't know the final + # size of this length varint. We conservatively assume it is maximum + # size. 
+ + _MAX_VARINT64_LENGTH_BYTES + # The size of the proto key. + + 1 + ) + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + + def add_point(self, point_proto: tensorboard_data.TimeSeriesDataPoint): + """Integrates the cost of a point proto into the byte budget. + + Args: + point_proto: The proto representing a point. + + Raises: + _OutOfSpaceError: If adding the point would exceed the remaining request + budget. + """ + submessage_cost = point_proto._pb.ByteSize() # pylint: disable=protected-access + cost = ( + # The size of the point proto. + submessage_cost + # The size of the varint that describes the length of the point + # proto. + + _varint_cost(submessage_cost) + # The size of the proto key. + + 1 + ) + if cost > self._byte_budget: + raise _OutOfSpaceError() + self._byte_budget -= cost + + +class _BlobRequestSender(object): + """Uploader for blob-type event data. + + Unlike the other types, this class does not accumulate events in batches; + every blob is sent individually and immediately. Nonetheless we retain + the `add_event()`/`flush()` structure for symmetry. + + This class is not threadsafe. Use external synchronization if calling its + methods concurrently. + """ + + def __init__( + self, + run_resource_id: str, + api: TensorboardServiceClient, + rpc_rate_limiter: util.RateLimiter, + max_blob_request_size: int, + max_blob_size: int, + blob_storage_bucket: storage.Bucket, + blob_storage_folder: str, + tracker: upload_tracker.UploadTracker, + ): + self._run_resource_id = run_resource_id + self._api = api + self._rpc_rate_limiter = rpc_rate_limiter + self._max_blob_request_size = max_blob_request_size + self._max_blob_size = max_blob_size + self._tracker = tracker + self._time_series_resource_manager = _TimeSeriesResourceManager( + run_resource_id, api + ) + + self._bucket = blob_storage_bucket + self._folder = blob_storage_folder + + self._new_request() + + def _new_request(self): + """Declares the previous event complete.""" + self._event = None + self._value = None + self._metadata = None + + def add_event( + self, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, + metadata: tf.compat.v1.SummaryMetadata, + ): + """Attempts to add the given event to the current request. + + If the event cannot be added to the current request because the byte + budget is exhausted, the request is flushed, and the event is added + to the next request. + """ + if self._value: + raise RuntimeError("Tried to send blob while another is pending") + self._event = event # provides step and possibly plugin_name + self._value = value + self._blobs = tensor_util.make_ndarray(self._value.tensor) + if self._blobs.ndim == 1: + self._metadata = metadata + self.flush() + else: + logger.warning( + "A blob sequence must be represented as a rank-1 Tensor. " + "Provided data has rank %d, for run %s, tag %s, step %s ('%s' plugin) .", + self._blobs.ndim, + self._run_resource_id, + self._value.tag, + self._event.step, + metadata.plugin_data.plugin_name, + ) + # Skip this upload. 
+ self._new_request() + + def flush(self): + """Sends the current blob sequence fully, and clears it to make way for the next.""" + if not self._value: + self._new_request() + return + + time_series_proto = self._time_series_resource_manager.get_or_create( + self._value.tag, + lambda: tensorboard_time_series.TensorboardTimeSeries( + display_name=self._value.tag, + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE, + plugin_name=self._metadata.plugin_data.plugin_name, + plugin_data=self._metadata.plugin_data.content, + ), + ) + m = re.match( + ".*/tensorboards/(.*)/experiments/(.*)/runs/(.*)/timeSeries/(.*)", + time_series_proto.name, + ) + blob_path_prefix = "tensorboard-{}/{}/{}/{}".format(m[1], m[2], m[3], m[4]) + blob_path_prefix = ( + "{}/{}".format(self._folder, blob_path_prefix) + if self._folder + else blob_path_prefix + ) + sent_blob_ids = [] + for blob in self._blobs: + self._rpc_rate_limiter.tick() + with self._tracker.blob_tracker(len(blob)) as blob_tracker: + blob_id = self._send_blob(blob, blob_path_prefix) + if blob_id is not None: + sent_blob_ids.append(str(blob_id)) + blob_tracker.mark_uploaded(blob_id is not None) + + data_point = tensorboard_data.TimeSeriesDataPoint( + step=self._event.step, + blobs=tensorboard_data.TensorboardBlobSequence( + values=[ + tensorboard_data.TensorboardBlob(id=blob_id) + for blob_id in sent_blob_ids + ] + ), + wall_time=timestamp.Timestamp( + seconds=int(self._event.wall_time), + nanos=int(round((self._event.wall_time % 1) * 10 ** 9)), + ), + ) + + time_series_data_proto = tensorboard_data.TimeSeriesData( + tensorboard_time_series_id=time_series_proto.name.split("/")[-1], + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE, + values=[data_point], + ) + request = tensorboard_service.WriteTensorboardRunDataRequest( + time_series_data=[time_series_data_proto] + ) + + _prune_empty_time_series(request) + if not request.time_series_data: + return + + with _request_logger(request): + try: + self._api.write_tensorboard_run_data( + tensorboard_run=self._run_resource_id, + time_series_data=request.time_series_data, + ) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + raise ExperimentNotFoundError() + logger.error("Upload call failed with error %s", e) + + self._new_request() + + def _send_blob(self, blob, blob_path_prefix): + """Sends a single blob to a GCS bucket in the consumer project. + + The blob will not be sent if it is too large. + + Returns: + The ID of blob successfully sent. + """ + if len(blob) > self._max_blob_size: + logger.warning( + "Blob too large; skipping. Size %d exceeds limit of %d bytes.", + len(blob), + self._max_blob_size, + ) + return None + + blob_id = uuid.uuid4() + blob_path = ( + "{}/{}".format(blob_path_prefix, blob_id) if blob_path_prefix else blob_id + ) + self._bucket.blob(blob_path).upload_from_string(blob) + return blob_id + + +@contextlib.contextmanager +def _request_logger(request: tensorboard_service.WriteTensorboardRunDataRequest): + """Context manager to log request size and duration.""" + upload_start_time = time.time() + request_bytes = request._pb.ByteSize() # pylint: disable=protected-access + logger.info("Trying request of %d bytes", request_bytes) + yield + upload_duration_secs = time.time() - upload_start_time + logger.info( + "Upload of (%d bytes) took %.3f seconds", request_bytes, upload_duration_secs, + ) + + +def _varint_cost(n: int): + """Computes the size of `n` encoded as an unsigned base-128 varint. 
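# Worked examples for the loop below: one byte covers values up to 127, and each further
# 7 bits of payload costs one more byte:
#
#   _varint_cost(127)   == 1
#   _varint_cost(128)   == 2
#   _varint_cost(300)   == 2
#   _varint_cost(16384) == 3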
+ + This should be consistent with the proto wire format: + + + Args: + n: A non-negative integer. + + Returns: + An integer number of bytes. + """ + result = 1 + while n >= 128: + result += 1 + n >>= 7 + return result + + +def _prune_empty_time_series( + request: tensorboard_service.WriteTensorboardRunDataRequest, +): + """Removes empty time_series from request.""" + for (time_series_idx, time_series_data) in reversed( + list(enumerate(request.time_series_data)) + ): + if not time_series_data.values: + del request.time_series_data[time_series_idx] + + +def _filter_graph_defs(event: tf.compat.v1.Event): + """Filters graph definitions. + + Args: + event: tf.compat.v1.Event to filter. + """ + for v in event.summary.value: + if v.metadata.plugin_data.plugin_name != graph_metadata.PLUGIN_NAME: + continue + if v.tag == graph_metadata.RUN_GRAPH_NAME: + data = list(v.tensor.string_val) + filtered_data = [_filtered_graph_bytes(x) for x in data] + filtered_data = [x for x in filtered_data if x is not None] + if filtered_data != data: + new_tensor = tensor_util.make_tensor_proto( + filtered_data, dtype=types_pb2.DT_STRING + ) + v.tensor.CopyFrom(new_tensor) + + +def _filtered_graph_bytes(graph_bytes: bytes): + """Prepares the graph to be served to the front-end. + + For now, it supports filtering out attributes that are too large to be shown + in the graph UI. + + Args: + graph_bytes: Graph definition. + + Returns: + Filtered graph. + """ + try: + graph_def = graph_pb2.GraphDef().FromString(graph_bytes) + # The reason for the RuntimeWarning catch here is b/27494216, whereby + # some proto parsers incorrectly raise that instead of DecodeError + # on certain kinds of malformed input. Triggering this seems to require + # a combination of mysterious circumstances. + except (message.DecodeError, RuntimeWarning): + logger.warning( + "Could not parse GraphDef of size %d. Skipping.", len(graph_bytes), + ) + return None + # Use the default filter parameters: + # limit_attr_size=1024, large_attrs_key="_too_large_attrs" + process_graph.prepare_graph_for_ui(graph_def) + return graph_def.SerializeToString() diff --git a/google/cloud/aiplatform/tensorboard/uploader_main.py b/google/cloud/aiplatform/tensorboard/uploader_main.py new file mode 100644 index 0000000000..734d647fb4 --- /dev/null +++ b/google/cloud/aiplatform/tensorboard/uploader_main.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Launches Tensorboard Uploader for TB.GCP.""" +import re + +from absl import app +from absl import flags +import grpc +from tensorboard.plugins.scalar import metadata as scalar_metadata +from tensorboard.plugins.distribution import metadata as distribution_metadata +from tensorboard.plugins.histogram import metadata as histogram_metadata +from tensorboard.plugins.text import metadata as text_metadata +from tensorboard.plugins.hparams import metadata as hparams_metadata +from tensorboard.plugins.image import metadata as images_metadata +from tensorboard.plugins.graph import metadata as graphs_metadata + +from google.cloud import storage +from google.cloud import aiplatform +from google.cloud.aiplatform.tensorboard import uploader +from google.cloud.aiplatform.utils import TensorboardClientWithOverride + +FLAGS = flags.FLAGS +flags.DEFINE_string("experiment_name", None, "The name of the Cloud AI Experiment.") +flags.DEFINE_string( + "experiment_display_name", None, "The display name of the Cloud AI Experiment." +) +flags.DEFINE_string("logdir", None, "Tensorboard log directory to upload") +flags.DEFINE_bool("one_shot", False, "Iterate through logdir once to upload.") +flags.DEFINE_string("env", "prod", "Environment which this tensorboard belongs to.") +flags.DEFINE_string( + "tensorboard_resource_name", + None, + "Tensorboard resource to create this experiment in. ", +) +flags.DEFINE_integer( + "event_file_inactive_secs", + None, + "Age in seconds of last write after which an event file is considered " "inactive.", +) +flags.DEFINE_string( + "run_name_prefix", + None, + "If present, all runs created by this invocation will have their name " + "prefixed by this value.", +) + +flags.DEFINE_multi_string( + "allowed_plugins", + [ + scalar_metadata.PLUGIN_NAME, + histogram_metadata.PLUGIN_NAME, + distribution_metadata.PLUGIN_NAME, + text_metadata.PLUGIN_NAME, + hparams_metadata.PLUGIN_NAME, + images_metadata.PLUGIN_NAME, + graphs_metadata.PLUGIN_NAME, + ], + "Plugins allowed by the Uploader.", +) + +flags.mark_flags_as_required(["experiment_name", "logdir", "tensorboard_resource_name"]) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + m = re.match( + "projects/(.*)/locations/(.*)/tensorboards/.*", FLAGS.tensorboard_resource_name + ) + project_id = m[1] + region = m[2] + api_client = aiplatform.initializer.global_config.create_client( + client_class=TensorboardClientWithOverride, location_override=region, + ) + + try: + tensorboard = api_client.get_tensorboard(name=FLAGS.tensorboard_resource_name) + except grpc.RpcError as rpc_error: + if rpc_error.code() == grpc.StatusCode.NOT_FOUND: + raise app.UsageError( + "Tensorboard resource %s not found" % FLAGS.tensorboard_resource_name, + exitcode=0, + ) + raise + + if tensorboard.blob_storage_path_prefix: + path_prefix = tensorboard.blob_storage_path_prefix + "/" + first_slash_index = path_prefix.find("/") + bucket_name = path_prefix[:first_slash_index] + blob_storage_bucket = storage.Client(project=project_id).bucket(bucket_name) + blob_storage_folder = path_prefix[first_slash_index + 1 :] + else: + raise app.UsageError( + "Tensorboard resource {} is obsolete. 
Please create a new one.".format( + FLAGS.tensorboard_resource_name + ), + exitcode=0, + ) + + tb_uploader = uploader.TensorBoardUploader( + experiment_name=FLAGS.experiment_name, + experiment_display_name=FLAGS.experiment_display_name, + tensorboard_resource_name=tensorboard.name, + blob_storage_bucket=blob_storage_bucket, + blob_storage_folder=blob_storage_folder, + allowed_plugins=FLAGS.allowed_plugins, + writer_client=api_client, + logdir=FLAGS.logdir, + one_shot=FLAGS.one_shot, + event_file_inactive_secs=FLAGS.event_file_inactive_secs, + run_name_prefix=FLAGS.run_name_prefix, + ) + + tb_uploader.create_experiment() + + print( + "View your Tensorboard at https://{}.{}/experiment/{}".format( + region, + "tensorboard.googleusercontent.com", + tb_uploader.get_experiment_resource_name().replace("/", "+"), + ) + ) + if FLAGS.one_shot: + tb_uploader._upload_once() # pylint: disable=protected-access + else: + tb_uploader.start_uploading() + + +def run_main(): + app.run(main) + + +if __name__ == "__main__": + run_main() diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index f3f447deb6..2912806a12 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -281,7 +281,7 @@ def _create_input_data_config( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. gcs_destination_uri_prefix (str): Optional. The Google Cloud Storage location. @@ -322,12 +322,12 @@ def _create_input_data_config( # Create predefined split spec predefined_split = None if predefined_split_column_name: - if ( - dataset._gca_resource.metadata_schema_uri - != schema.dataset.metadata.tabular + if dataset._gca_resource.metadata_schema_uri not in ( + schema.dataset.metadata.tabular, + schema.dataset.metadata.time_series, ): raise ValueError( - "A pre-defined split may only be used with a tabular Dataset" + "A pre-defined split may only be used with a tabular or time series Dataset" ) predefined_split = gca_training_pipeline.PredefinedSplit( @@ -440,7 +440,7 @@ def _run_job( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. model (~.model.Model): Optional. Describes the Model that may be uploaded (via [ModelService.UploadMode][]) by this TrainingPipeline. The @@ -1962,7 +1962,7 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -2115,7 +2115,7 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -2543,7 +2543,7 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. 
If False, this method will be executed in concurrent Future and any downstream object will
@@ -2690,7 +2690,7 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline.
- Supported only for tabular Datasets.
+ Supported only for tabular and time series Datasets.
 sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will
@@ -2921,7 +2921,7 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline.
- Supported only for tabular Datasets.
+ Supported only for tabular and time series Datasets.
 weight_column (str): Optional. Name of the column that should be used as the weight column. Higher values in this column give more importance to the row
@@ -3036,7 +3036,7 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline.
- Supported only for tabular Datasets.
+ Supported only for tabular and time series Datasets.
 weight_column (str): Optional. Name of the column that should be used as the weight column. Higher values in this column give more importance to the row
@@ -3143,6 +3143,458 @@ def _model_upload_fail_string(self) -> str: )
+
+class AutoMLForecastingTrainingJob(_TrainingJob):
+ _supported_training_schemas = (schema.training_job.definition.automl_forecasting,)
+
+ def __init__(
+ self,
+ display_name: str,
+ optimization_objective: Optional[str] = None,
+ column_transformations: Optional[Union[Dict, List[Dict]]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Constructs an AutoML Forecasting Training Job.
+
+ Args:
+ display_name (str):
+ Required. The user-defined name of this TrainingPipeline.
+ optimization_objective (str):
+ Optional. Objective function the model is to be optimized towards.
+ The training process creates a Model that optimizes the value of the objective
+ function over the validation set. The supported optimization objectives:
+ "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
+ "minimize-mae" - Minimize mean-absolute error (MAE).
+ "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
+ "minimize-rmspe" - Minimize root-mean-squared percentage error (RMSPE).
+ "minimize-wape-mae" - Minimize the combination of weighted absolute percentage error (WAPE)
+ and mean-absolute-error (MAE).
+ "minimize-quantile-loss" - Minimize the quantile loss at the defined quantiles.
+ (Set this objective to build quantile forecasts.)
+ column_transformations (Optional[Union[Dict, List[Dict]]]):
+ Optional. Transformations to apply to the input columns (i.e. columns other
+ than the targetColumn). Each transformation may produce multiple
+ result values from the column's value, and all are used for training.
+ When creating a transformation for a BigQuery Struct column, the column
+ should be flattened using "." as the delimiter.
+ If an input column has no transformations on it, such a column is
+ ignored by the training, except for the targetColumn, which should have
+ no transformations defined on it.
+ project (str):
+ Optional. Project to run training in. Overrides project set in aiplatform.init.
+ location (str):
+ Optional. Location to run training in. Overrides location set in aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use when calling the training service. Overrides
+ credentials set in aiplatform.init.
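+
+ Example (an illustrative sketch only; the dataset resource name, column
+ names, objective, and display names below are placeholders, not values
+ required by this API)::
+
+ ds = aiplatform.TimeSeriesDataset(
+ "projects/my-project/locations/us-central1/datasets/123"
+ )
+ job = aiplatform.AutoMLForecastingTrainingJob(
+ display_name="sales-forecasting-job",
+ optimization_objective="minimize-rmse",
+ )
+ model = job.run(
+ dataset=ds,
+ target_column="sales",
+ time_column="date",
+ time_series_identifier_column="store_id",
+ unavailable_at_forecast_columns=["sales"],
+ available_at_forecast_columns=["date"],
+ forecast_horizon=30,
+ data_granularity_unit="day",
+ data_granularity_count=1,
+ )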
+ """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + ) + self._column_transformations = column_transformations + self._optimization_objective = optimization_objective + + def run( + self, + dataset: datasets.TimeSeriesDataset, + target_column: str, + time_column: str, + time_series_identifier_column: str, + unavailable_at_forecast_columns: List[str], + available_at_forecast_columns: List[str], + forecast_horizon: int, + data_granularity_unit: str, + data_granularity_count: int, + predefined_split_column_name: Optional[str] = None, + weight_column: Optional[str] = None, + time_series_attribute_columns: Optional[List[str]] = None, + context_window: Optional[int] = None, + export_evaluated_data_items: bool = False, + export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None, + export_evaluated_data_items_override_destination: bool = False, + quantiles: Optional[List[float]] = None, + validation_options: Optional[str] = None, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + The training data splits are set by default: Roughly 80% will be used for training, + 10% for validation, and 10% for test. + + Args: + dataset (datasets.Dataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For time series Datasets, all their data is exported to + training, to pick and choose from. + target_column (str): + Required. Name of the column that the Model is to predict values for. + time_column (str): + Required. Name of the column that identifies time order in the time series. + time_series_identifier_column (str): + Required. Name of the column that identifies the time series. + unavailable_at_forecast_columns (List[str]): + Required. Column names of columns that are unavailable at forecast. + Each column contains information for the given entity (identified by the + [time_series_identifier_column]) that is unknown before the forecast + (e.g. population of a city in a given year, or weather on a given day). + available_at_forecast_columns (List[str]): + Required. Column names of columns that are available at forecast. + Each column contains information for the given entity (identified by the + [time_series_identifier_column]) that is known at forecast. + forecast_horizon: (int): + Required. The amount of time into the future for which forecasted values for the target are + returned. Expressed in number of units defined by the [data_granularity_unit] and + [data_granularity_count] field. Inclusive. + data_granularity_unit (str): + Required. The data granularity unit. Accepted values are ``minute``, + ``hour``, ``day``, ``week``, ``month``, ``year``. + data_granularity_count (int): + Required. The number of data granularity units between data points in the training + data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other + values of [data_granularity_unit], must be 1. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. 
The value of the key (either the label's value or
+ value in the column) must be one of {``TRAIN``,
+ ``VALIDATE``, ``TEST``}, and it defines to which set the
+ given piece of data is assigned. If for a piece of data the
+ key is not present or has an invalid value, that piece is
+ ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ weight_column (str):
+ Optional. Name of the column that should be used as the weight column.
+ Higher values in this column give more importance to the row
+ during Model training. The column must have numeric values between 0 and
+ 10000 inclusively, and 0 value means that the row is ignored.
+ If the weight column field is not set, then all rows are assumed to have
+ equal weight of 1.
+ time_series_attribute_columns (List[str]):
+ Optional. Column names that should be used as attribute columns.
+ Each column is constant within a time series.
+ context_window (int):
+ Optional. The amount of time into the past training and prediction data is used for
+ model training and prediction respectively. Expressed in number of units defined by the
+ [data_granularity_unit] and [data_granularity_count] fields. When not provided uses the
+ default value of 0 which means the model sets each series context window to be 0 (also
+ known as "cold start"). Inclusive.
+ export_evaluated_data_items (bool):
+ Whether to export the test set predictions to a BigQuery table.
+ If False, then the export is not performed.
+ export_evaluated_data_items_bigquery_destination_uri (string):
+ Optional. URI of desired destination BigQuery table for exported test set predictions.
+
+ Expected format:
+ ``bq://<project_id>:<dataset_id>:<table>``
+
+ If not specified, then results are exported to the following auto-created BigQuery
+ table:
+ ``<project_id>:export_evaluated_examples_<model_name>_<timestamp>.evaluated_examples``
+
+ Applies only if [export_evaluated_data_items] is True.
+ export_evaluated_data_items_override_destination (bool):
+ Whether to override the contents of [export_evaluated_data_items_bigquery_destination_uri],
+ if the table exists, for exported test set predictions. If False, and the
+ table exists, then the training job will fail.
+
+ Applies only if [export_evaluated_data_items] is True and
+ [export_evaluated_data_items_bigquery_destination_uri] is specified.
+ quantiles (List[float]):
+ Quantiles to use for the `minimize-quantile-loss`
+ [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
+ this case.
+
+ Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive.
+ Each quantile must be unique.
+ validation_options (str):
+ Validation options for the data validation component. The available options are:
+ "fail-pipeline" - (default), will validate against the validation and fail the pipeline
+ if it fails.
+ "ignore-validation" - ignore the results of the validation and continue the pipeline
+ budget_milli_node_hours (int):
+ Optional. The train budget of creating this Model, expressed in milli node
+ hours, i.e. a value of 1,000 in this field means 1 node hour.
+ The training cost of the model will not exceed this budget. The final
+ cost will be attempted to be close to the budget, though may end up
+ being (even) noticeably smaller - at the backend's discretion. This
+ especially may happen when further model training ceases to provide
+ any improvements.
+ If the budget is set to a value known to be insufficient to train a
+ Model for the given training set, the training won't be attempted and
+ will error.
+ The minimum value is 1000 and the maximum is 72000. + model_display_name (str): + Optional. If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + + Raises: + RuntimeError if Training job has already been run or is waiting to run. + """ + + if self._is_waiting_to_run(): + raise RuntimeError( + "AutoML Forecasting Training is already scheduled to run." + ) + + if self._has_run: + raise RuntimeError("AutoML Forecasting Training has already run.") + + return self._run( + dataset=dataset, + target_column=target_column, + time_column=time_column, + time_series_identifier_column=time_series_identifier_column, + unavailable_at_forecast_columns=unavailable_at_forecast_columns, + available_at_forecast_columns=available_at_forecast_columns, + forecast_horizon=forecast_horizon, + data_granularity_unit=data_granularity_unit, + data_granularity_count=data_granularity_count, + predefined_split_column_name=predefined_split_column_name, + weight_column=weight_column, + time_series_attribute_columns=time_series_attribute_columns, + context_window=context_window, + budget_milli_node_hours=budget_milli_node_hours, + export_evaluated_data_items=export_evaluated_data_items, + export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri, + export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination, + quantiles=quantiles, + validation_options=validation_options, + model_display_name=model_display_name, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.TimeSeriesDataset, + target_column: str, + time_column: str, + time_series_identifier_column: str, + unavailable_at_forecast_columns: List[str], + available_at_forecast_columns: List[str], + forecast_horizon: int, + data_granularity_unit: str, + data_granularity_count: int, + predefined_split_column_name: Optional[str] = None, + weight_column: Optional[str] = None, + time_series_attribute_columns: Optional[List[str]] = None, + context_window: Optional[int] = None, + export_evaluated_data_items: bool = False, + export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None, + export_evaluated_data_items_override_destination: bool = False, + quantiles: Optional[List[float]] = None, + validation_options: Optional[str] = None, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + The training data splits are set by default: Roughly 80% will be used for training, + 10% for validation, and 10% for test. + + Args: + dataset (datasets.Dataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. 
+ For time series Datasets, all their data is exported to + training, to pick and choose from. + target_column (str): + Required. Name of the column that the Model is to predict values for. + time_column (str): + Required. Name of the column that identifies time order in the time series. + time_series_identifier_column (str): + Required. Name of the column that identifies the time series. + unavailable_at_forecast_columns (List[str]): + Required. Column names of columns that are unavailable at forecast. + Each column contains information for the given entity (identified by the + [time_series_identifier_column]) that is unknown before the forecast + (e.g. population of a city in a given year, or weather on a given day). + available_at_forecast_columns (List[str]): + Required. Column names of columns that are available at forecast. + Each column contains information for the given entity (identified by the + [time_series_identifier_column]) that is known at forecast. + forecast_horizon: (int): + Required. The amount of time into the future for which forecasted values for the target are + returned. Expressed in number of units defined by the [data_granularity_unit] and + [data_granularity_count] field. Inclusive. + data_granularity_unit (str): + Required. The data granularity unit. Accepted values are ``minute``, + ``hour``, ``day``, ``week``, ``month``, ``year``. + data_granularity_count (int): + Required. The number of data granularity units between data points in the training + data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other + values of [data_granularity_unit], must be 1. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``TRAIN``, + ``VALIDATE``, ``TEST``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + weight_column (str): + Optional. Name of the column that should be used as the weight column. + Higher values in this column give more importance to the row + during Model training. The column must have numeric values between 0 and + 10000 inclusively, and 0 value means that the row is ignored. + If the weight column field is not set, then all rows are assumed to have + equal weight of 1. + time_series_attribute_columns (List[str]): + Optional. Column names that should be used as attribute columns. + Each column is constant within a time series. + context_window (int): + Optional. The number of periods offset into the past to restrict past sequence, where each + period is one unit of granularity as defined by [period]. When not provided uses the + default value of 0 which means the model sets each series historical window to be 0 (also + known as "cold start"). Inclusive. + export_evaluated_data_items (bool): + Whether to export the test set predictions to a BigQuery table. + If False, then the export is not performed. + export_evaluated_data_items_bigquery_destination_uri (string): + Optional. URI of desired destination BigQuery table for exported test set predictions. + + Expected format: + ``bq://::
``
+
+ If not specified, then results are exported to the following auto-created BigQuery
+ table:
+ ``<project_id>:export_evaluated_examples_<model_name>_<timestamp>.evaluated_examples``
+
+ Applies only if [export_evaluated_data_items] is True.
+ export_evaluated_data_items_override_destination (bool):
+ Whether to override the contents of [export_evaluated_data_items_bigquery_destination_uri],
+ if the table exists, for exported test set predictions. If False, and the
+ table exists, then the training job will fail.
+
+ Applies only if [export_evaluated_data_items] is True and
+ [export_evaluated_data_items_bigquery_destination_uri] is specified.
+ quantiles (List[float]):
+ Quantiles to use for the `minimize-quantile-loss`
+ [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
+ this case.
+
+ Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive.
+ Each quantile must be unique.
+ validation_options (str):
+ Validation options for the data validation component. The available options are:
+ "fail-pipeline" - (default), will validate against the validation and fail the pipeline
+ if it fails.
+ "ignore-validation" - ignore the results of the validation and continue the pipeline
+ budget_milli_node_hours (int):
+ Optional. The train budget of creating this Model, expressed in milli node
+ hours, i.e. a value of 1,000 in this field means 1 node hour.
+ The training cost of the model will not exceed this budget. The final
+ cost will be attempted to be close to the budget, though may end up
+ being (even) noticeably smaller - at the backend's discretion. This
+ especially may happen when further model training ceases to provide
+ any improvements.
+ If the budget is set to a value known to be insufficient to train a
+ Model for the given training set, the training won't be attempted and
+ will error.
+ The minimum value is 1000 and the maximum is 72000.
+ model_display_name (str):
+ Optional. If the script produces a managed AI Platform Model. The display name of
+ the Model. The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+
+ If not provided upon creation, the job's display_name is used.
+ sync (bool):
+ Whether to execute this method synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ Returns:
+ model: The trained AI Platform Model resource or None if training did not
+ produce an AI Platform Model.
+ """ + + training_task_definition = schema.training_job.definition.automl_forecasting + + training_task_inputs_dict = { + # required inputs + "targetColumn": target_column, + "timeColumn": time_column, + "timeSeriesIdentifierColumn": time_series_identifier_column, + "timeSeriesAttributeColumns": time_series_attribute_columns, + "unavailableAtForecastColumns": unavailable_at_forecast_columns, + "availableAtForecastColumns": available_at_forecast_columns, + "forecastHorizon": forecast_horizon, + "dataGranularity": { + "unit": data_granularity_unit, + "quantity": data_granularity_count, + }, + "transformations": self._column_transformations, + "trainBudgetMilliNodeHours": budget_milli_node_hours, + # optional inputs + "weightColumn": weight_column, + "contextWindow": context_window, + "quantiles": quantiles, + "validationOptions": validation_options, + "optimizationObjective": self._optimization_objective, + } + + final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri + if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith( + "bq://" + ): + final_export_eval_bq_uri = f"bq://{final_export_eval_bq_uri}" + + if export_evaluated_data_items: + training_task_inputs_dict["exportEvaluatedDataItemsConfig"] = { + "destinationBigqueryUri": final_export_eval_bq_uri, + "overrideExistingTable": export_evaluated_data_items_override_destination, + } + + if model_display_name is None: + model_display_name = self._display_name + + model = gca_model.Model(display_name=model_display_name) + + return self._run_job( + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs_dict, + dataset=dataset, + training_fraction_split=0.8, + validation_fraction_split=0.1, + test_fraction_split=0.1, + predefined_split_column_name=predefined_split_column_name, + model=model, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"Training Pipeline {self.resource_name} is not configured to upload a " + "Model." + ) + + class AutoMLImageTrainingJob(_TrainingJob): _supported_training_schemas = ( schema.training_job.definition.automl_image_classification, @@ -3893,7 +4345,7 @@ def run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -4022,7 +4474,7 @@ def _run( key is not present or has an invalid value, that piece is ignored by the pipeline. - Supported only for tabular Datasets. + Supported only for tabular and time series Datasets. sync (bool): Whether to execute this method synchronously. 
If False, this method will be executed in concurrent Future and any downstream object will diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index 64f7a29671..ff86fc1cb8 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -36,6 +36,8 @@ model_service_client_v1beta1, pipeline_service_client_v1beta1, prediction_service_client_v1beta1, + metadata_service_client_v1beta1, + tensorboard_service_client_v1beta1, ) from google.cloud.aiplatform.compat.services import ( dataset_service_client_v1, @@ -59,6 +61,7 @@ prediction_service_client_v1beta1.PredictionServiceClient, pipeline_service_client_v1beta1.PipelineServiceClient, job_service_client_v1beta1.JobServiceClient, + metadata_service_client_v1beta1.MetadataServiceClient, # v1 dataset_service_client_v1.DatasetServiceClient, endpoint_service_client_v1.EndpointServiceClient, @@ -68,12 +71,10 @@ job_service_client_v1.JobServiceClient, ) -# TODO(b/170334193): Add support for resource names with non-integer IDs -# TODO(b/170334098): Add support for resource names more than one level deep RESOURCE_NAME_PATTERN = re.compile( - r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P\w+)\/(?P\d+)$" + r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P[\w\-\/]+)\/(?P[\w-]+)$" ) -RESOURCE_ID_PATTERN = re.compile(r"^\d+$") +RESOURCE_ID_PATTERN = re.compile(r"^[\w-]+$") Fields = namedtuple("Fields", ["project", "location", "resource", "id"],) @@ -108,10 +109,12 @@ def extract_fields_from_resource_name( Required. A fully-qualified AI Platform (Unified) resource name resource_noun (str): - A plural resource noun to validate the resource name against. + A resource noun to validate the resource name against. For example, you would pass "datasets" to validate "projects/123/locations/us-central1/datasets/456". - + In the case of deeper naming structures, e.g., + "projects/123/locations/us-central1/metadataStores/123/contexts/456", + you would pass "metadataStores/123/contexts" as the resource_noun. Returns: fields (Fields): A named tuple containing four extracted fields from a resource name: @@ -141,9 +144,12 @@ def full_resource_name( Required. A fully-qualified AI Platform (Unified) resource name or resource ID. resource_noun (str): - A plural resource noun to validate the resource name against. + A resource noun to validate the resource name against. For example, you would pass "datasets" to validate "projects/123/locations/us-central1/datasets/456". + In the case of deeper naming structures, e.g., + "projects/123/locations/us-central1/metadataStores/123/contexts/456", + you would pass "metadataStores/123/contexts" as the resource_noun. project (str): Optional project to retrieve resource_noun from. If not set, project set in aiplatform.init will be used. @@ -160,7 +166,8 @@ def full_resource_name( If resource name, resource ID or project ID not provided. """ validate_resource_noun(resource_noun) - # Fully qualified resource name, i.e. 
"projects/.../locations/.../datasets/12345" + # Fully qualified resource name, e.g., "projects/.../locations/.../datasets/12345" or + # "projects/.../locations/.../metadataStores/.../contexts/12345" valid_name = extract_fields_from_resource_name( resource_name=resource_name, resource_noun=resource_noun ) @@ -457,6 +464,22 @@ class PredictionClientWithOverride(ClientWithOverride): ) +class MetadataClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.V1BETA1 + _version_map = ( + (compat.V1BETA1, metadata_service_client_v1beta1.MetadataServiceClient), + ) + + +class TensorboardClientWithOverride(ClientWithOverride): + _is_temporary = False + _default_version = compat.V1BETA1 + _version_map = ( + (compat.V1BETA1, tensorboard_service_client_v1beta1.TensorboardServiceClient), + ) + + AiPlatformServiceClientWithOverride = TypeVar( "AiPlatformServiceClientWithOverride", DatasetClientWithOverride, @@ -465,6 +488,8 @@ class PredictionClientWithOverride(ClientWithOverride): ModelClientWithOverride, PipelineClientWithOverride, PredictionClientWithOverride, + MetadataClientWithOverride, + TensorboardClientWithOverride, ) diff --git a/noxfile.py b/noxfile.py index 4ea506a2a2..cd85c2b17e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -88,7 +88,7 @@ def default(session): session.install("mock", "pytest", "pytest-cov", "-c", constraints_path) - session.install("-e", ".", "-c", constraints_path) + session.install("-e", ".[testing]", "-c", constraints_path) # Run py.test against the unit tests. session.run( @@ -101,7 +101,7 @@ def default(session): "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=0", - os.path.join("tests", "unit", "gapic"), + os.path.join("tests", "unit"), *session.posargs, ) diff --git a/setup.py b/setup.py index f3f94568b1..a0a1a29bf2 100644 --- a/setup.py +++ b/setup.py @@ -29,12 +29,27 @@ with io.open(readme_filename, encoding="utf-8") as readme_file: readme = readme_file.read() +tensorboard_extra_require = [ + "tensorflow-cpu >= 2.3.0, <=2.5.0rc", + "grpcio~=1.34.0", + "six~=1.15.0", +] +metadata_extra_require = ["pandas >= 1.0.0"] +full_extra_require = tensorboard_extra_require + metadata_extra_require +testing_extra_require = full_extra_require + ["grpcio-testing ~= 1.34.0"] + + setuptools.setup( name=name, version=version, description=description, long_description=readme, packages=setuptools.PEP420PackageFinder.find(), + entry_points={ + "console_scripts": [ + "tb-gcp-uploader=google.cloud.aiplatform.tensorboard.uploader_main:run_main" + ], + }, namespace_packages=("google", "google.cloud"), author="Google LLC", author_email="googleapis-packages@google.com", @@ -48,6 +63,12 @@ "google-cloud-storage >= 1.32.0, < 2.0.0dev", "google-cloud-bigquery >= 1.15.0, < 3.0.0dev", ), + extras_require={ + "full": full_extra_require, + "metadata": metadata_extra_require, + "tensorboard": tensorboard_extra_require, + "testing": testing_extra_require, + }, python_requires=">=3.6", scripts=[], classifiers=[ diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py new file mode 100644 index 0000000000..5d89360566 --- /dev/null +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -0,0 +1,488 @@ +import importlib +import pytest +from unittest import mock + +from google.cloud import aiplatform +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from 
google.cloud.aiplatform.training_jobs import AutoMLForecastingTrainingJob + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +_TEST_BUCKET_NAME = "test-bucket" +_TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder" +_TEST_GCS_PATH = f"{_TEST_BUCKET_NAME}/{_TEST_GCS_PATH_WITHOUT_BUCKET}" +_TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/" +_TEST_PROJECT = "test-project" + +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image" +_TEST_METADATA_SCHEMA_URI_TIMESERIES = schema.dataset.metadata.time_series +_TEST_METADATA_SCHEMA_URI_NONTIMESERIES = schema.dataset.metadata.image + +_TEST_TRAINING_COLUMN_TRANSFORMATIONS = [ + {"auto": {"column_name": "time"}}, + {"auto": {"column_name": "time_series_identifier"}}, + {"auto": {"column_name": "target"}}, + {"auto": {"column_name": "weight"}}, +] +_TEST_TRAINING_TARGET_COLUMN = "target" +_TEST_TRAINING_TIME_COLUMN = "time" +_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN = "time_series_identifier" +_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS = [] +_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS = [] +_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS = [] +_TEST_TRAINING_FORECAST_HORIZON = 10 +_TEST_TRAINING_DATA_GRANULARITY_UNIT = "day" +_TEST_TRAINING_DATA_GRANULARITY_COUNT = 1 +_TEST_TRAINING_CONTEXT_WINDOW = None +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS = True +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI = ( + "bq://path.to.table" +) +_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION = False +_TEST_TRAINING_QUANTILES = None +_TEST_TRAINING_VALIDATION_OPTIONS = None +_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS = 1000 +_TEST_TRAINING_WEIGHT_COLUMN = "weight" +_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME = "minimize-rmse" +_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict( + { + # required inputs + "targetColumn": _TEST_TRAINING_TARGET_COLUMN, + "timeColumn": _TEST_TRAINING_TIME_COLUMN, + "timeSeriesIdentifierColumn": _TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + "timeSeriesAttributeColumns": _TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + "unavailableAtForecastColumns": _TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + "availableAtForecastColumns": _TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + "forecastHorizon": _TEST_TRAINING_FORECAST_HORIZON, + "dataGranularity": { + "unit": _TEST_TRAINING_DATA_GRANULARITY_UNIT, + "quantity": _TEST_TRAINING_DATA_GRANULARITY_COUNT, + }, + "transformations": _TEST_TRAINING_COLUMN_TRANSFORMATIONS, + "trainBudgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + # optional inputs + "weightColumn": _TEST_TRAINING_WEIGHT_COLUMN, + "contextWindow": _TEST_TRAINING_CONTEXT_WINDOW, + "exportEvaluatedDataItemsConfig": { + "destinationBigqueryUri": _TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + "overrideExistingTable": _TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + }, + "quantiles": _TEST_TRAINING_QUANTILES, + "validationOptions": _TEST_TRAINING_VALIDATION_OPTIONS, 
+ "optimizationObjective": _TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + }, + struct_pb2.Value(), +) + +_TEST_DATASET_NAME = "test-dataset-name" + +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_TRAINING_FRACTION_SPLIT = 0.8 +_TEST_VALIDATION_FRACTION_SPLIT = 0.1 +_TEST_TEST_FRACTION_SPLIT = 0.1 +_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" + +_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz" + +_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345" + +_TEST_PIPELINE_RESOURCE_NAME = ( + "projects/my-project/locations/us-central1/trainingPipeline/12345" +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_time_series(): + ds = mock.MagicMock(datasets.TimeSeriesDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TIMESERIES, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_dataset_nontimeseries(): + ds = mock.MagicMock(datasets.ImageDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTIMESERIES, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +class TestAutoMLForecastingTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def 
teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_time_series, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + model_from_job = job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_managed_model = gca_model.Model(display_name=_TEST_MODEL_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_dataset_time_series.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_forecasting, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def 
test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_time_series, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + model_from_job = job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_dataset_time_series.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_forecasting, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_dataset_time_series, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + 
available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, + mock_pipeline_service_create_and_get_with_fail, + mock_dataset_time_series, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + 
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 91d86409f9..7e65c99b4c 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -19,11 +19,13 @@ import os import pytest from unittest import mock +from unittest.mock import patch import google.auth from google.auth import credentials from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.metadata.metadata import metadata_service from google.cloud.aiplatform import constants from google.cloud.aiplatform import utils @@ -37,6 +39,7 @@ _TEST_LOCATION_2 = "europe-west4" _TEST_INVALID_LOCATION = "test-invalid-location" _TEST_EXPERIMENT = "test-experiment" +_TEST_DESCRIPTION = "test-description" _TEST_STAGING_BUCKET = "test-bucket" @@ -69,9 +72,27 @@ def test_init_location_with_invalid_location_raises(self): with pytest.raises(ValueError): initializer.global_config.init(location=_TEST_INVALID_LOCATION) - def test_init_experiment_sets_experiment(self): + @patch.object(metadata_service, "set_experiment") + def test_init_experiment_sets_experiment(self, set_experiment_mock): initializer.global_config.init(experiment=_TEST_EXPERIMENT) - assert initializer.global_config.experiment == _TEST_EXPERIMENT + set_experiment_mock.assert_called_once_with( + experiment=_TEST_EXPERIMENT, description=None + ) + + @patch.object(metadata_service, "set_experiment") + def test_init_experiment_sets_experiment_with_description( + self, set_experiment_mock + ): + initializer.global_config.init( + experiment=_TEST_EXPERIMENT, experiment_description=_TEST_DESCRIPTION + ) + set_experiment_mock.assert_called_once_with( + experiment=_TEST_EXPERIMENT, description=_TEST_DESCRIPTION + ) + + def test_init_experiment_description_fail_without_experiment(self): + with pytest.raises(ValueError): + initializer.global_config.init(experiment_description=_TEST_DESCRIPTION) def test_init_staging_bucket_sets_staging_bucket(self): initializer.global_config.init(staging_bucket=_TEST_STAGING_BUCKET) diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py new file mode 100644 index 0000000000..9a930dd3f5 --- /dev/null +++ b/tests/unit/aiplatform/test_metadata.py @@ -0,0 +1,661 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +from unittest.mock import patch, call + +import pytest +from google.api_core import exceptions + +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.metadata import constants +from google.cloud.aiplatform.metadata import metadata +from google.cloud.aiplatform_v1beta1 import ( + AddContextArtifactsAndExecutionsResponse, + Event, + LineageSubgraph, + ListExecutionsRequest, +) +from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact +from google.cloud.aiplatform_v1beta1 import Context as GapicContext +from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution +from google.cloud.aiplatform_v1beta1 import ( + MetadataServiceClient, + AddExecutionEventsResponse, +) +from google.cloud.aiplatform_v1beta1 import MetadataStore as GapicMetadataStore + +# project + +_TEST_PROJECT = "test-project" +_TEST_OTHER_PROJECT = "test-project-1" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" +) +_TEST_EXPERIMENT = "test-experiment" +_TEST_EXPERIMENT_DESCRIPTION = "test-experiment-description" +_TEST_OTHER_EXPERIMENT_DESCRIPTION = "test-other-experiment-description" +_TEST_PIPELINE = _TEST_EXPERIMENT +_TEST_RUN = "run-1" +_TEST_OTHER_RUN = "run-2" + +# resource attributes +_TEST_METADATA = {"test-param1": 1, "test-param2": "test-value", "test-param3": True} + +# metadataStore +_TEST_METADATASTORE = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" +) + +# context +_TEST_CONTEXT_ID = _TEST_EXPERIMENT +_TEST_CONTEXT_NAME = f"{_TEST_PARENT}/contexts/{_TEST_CONTEXT_ID}" + +# execution +_TEST_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}" +_TEST_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_EXECUTION_ID}" +_TEST_OTHER_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_OTHER_RUN}" +_TEST_OTHER_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_OTHER_EXECUTION_ID}" + +# artifact +_TEST_ARTIFACT_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}-metrics" +_TEST_ARTIFACT_NAME = f"{_TEST_PARENT}/artifacts/{_TEST_ARTIFACT_ID}" +_TEST_OTHER_ARTIFACT_ID = f"{_TEST_EXPERIMENT}-{_TEST_OTHER_RUN}-metrics" +_TEST_OTHER_ARTIFACT_NAME = f"{_TEST_PARENT}/artifacts/{_TEST_OTHER_ARTIFACT_ID}" + +# parameters +_TEST_PARAM_KEY_1 = "learning_rate" +_TEST_PARAM_KEY_2 = "dropout" +_TEST_PARAMS = {_TEST_PARAM_KEY_1: 0.01, _TEST_PARAM_KEY_2: 0.2} +_TEST_OTHER_PARAMS = {_TEST_PARAM_KEY_1: 0.02, _TEST_PARAM_KEY_2: 0.3} + +# metrics +_TEST_METRIC_KEY_1 = "rmse" +_TEST_METRIC_KEY_2 = "accuracy" +_TEST_METRICS = {_TEST_METRIC_KEY_1: 222, _TEST_METRIC_KEY_2: 1} +_TEST_OTHER_METRICS = {_TEST_METRIC_KEY_2: 0.9} + +# schema +_TEST_WRONG_SCHEMA_TITLE = "system.WrongSchema" + + +@pytest.fixture +def get_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = GapicMetadataStore( + name=_TEST_METADATASTORE, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def 
get_context_mock(): + with patch.object(MetadataServiceClient, "get_context") as get_context_mock: + get_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + description=_TEST_EXPERIMENT_DESCRIPTION, + schema_title=constants.SYSTEM_EXPERIMENT, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + yield get_context_mock + + +@pytest.fixture +def get_context_wrong_schema_mock(): + with patch.object( + MetadataServiceClient, "get_context" + ) as get_context_wrong_schema_mock: + get_context_wrong_schema_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + schema_title=_TEST_WRONG_SCHEMA_TITLE, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + yield get_context_wrong_schema_mock + + +@pytest.fixture +def get_pipeline_context_mock(): + with patch.object( + MetadataServiceClient, "get_context" + ) as get_pipeline_context_mock: + get_pipeline_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + schema_title=constants.SYSTEM_PIPELINE, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_PIPELINE], + metadata=constants.EXPERIMENT_METADATA, + ) + yield get_pipeline_context_mock + + +@pytest.fixture +def get_context_not_found_mock(): + with patch.object( + MetadataServiceClient, "get_context" + ) as get_context_not_found_mock: + get_context_not_found_mock.side_effect = exceptions.NotFound("test: not found") + yield get_context_not_found_mock + + +@pytest.fixture +def update_context_mock(): + with patch.object(MetadataServiceClient, "update_context") as update_context_mock: + update_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + description=_TEST_OTHER_EXPERIMENT_DESCRIPTION, + schema_title=constants.SYSTEM_EXPERIMENT, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + yield update_context_mock + + +@pytest.fixture +def add_context_artifacts_and_executions_mock(): + with patch.object( + MetadataServiceClient, "add_context_artifacts_and_executions" + ) as add_context_artifacts_and_executions_mock: + add_context_artifacts_and_executions_mock.return_value = ( + AddContextArtifactsAndExecutionsResponse() + ) + yield add_context_artifacts_and_executions_mock + + +@pytest.fixture +def get_execution_mock(): + with patch.object(MetadataServiceClient, "get_execution") as get_execution_mock: + get_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + ) + yield get_execution_mock + + +@pytest.fixture +def get_execution_wrong_schema_mock(): + with patch.object( + MetadataServiceClient, "get_execution" + ) as get_execution_wrong_schema_mock: + get_execution_wrong_schema_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=_TEST_WRONG_SCHEMA_TITLE, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + ) + yield get_execution_wrong_schema_mock + + +@pytest.fixture +def update_execution_mock(): + with patch.object( + MetadataServiceClient, "update_execution" + ) as update_execution_mock: + update_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + 
display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_PARAMS, + ) + yield update_execution_mock + + +@pytest.fixture +def add_execution_events_mock(): + with patch.object( + MetadataServiceClient, "add_execution_events" + ) as add_execution_events_mock: + add_execution_events_mock.return_value = AddExecutionEventsResponse() + yield add_execution_events_mock + + +@pytest.fixture +def list_executions_mock(): + with patch.object(MetadataServiceClient, "list_executions") as list_executions_mock: + list_executions_mock.return_value = [ + GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_PARAMS, + ), + GapicExecution( + name=_TEST_OTHER_EXECUTION_NAME, + display_name=_TEST_OTHER_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_OTHER_PARAMS, + ), + ] + yield list_executions_mock + + +@pytest.fixture +def query_execution_inputs_and_outputs_mock(): + with patch.object( + MetadataServiceClient, "query_execution_inputs_and_outputs" + ) as query_execution_inputs_and_outputs_mock: + query_execution_inputs_and_outputs_mock.side_effect = [ + LineageSubgraph( + artifacts=[ + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[ + constants.SYSTEM_METRICS + ], + metadata=_TEST_METRICS, + ), + ], + ), + LineageSubgraph( + artifacts=[ + GapicArtifact( + name=_TEST_OTHER_ARTIFACT_NAME, + display_name=_TEST_OTHER_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[ + constants.SYSTEM_METRICS + ], + metadata=_TEST_OTHER_METRICS, + ), + ], + ), + ] + yield query_execution_inputs_and_outputs_mock + + +@pytest.fixture +def get_artifact_mock(): + with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock: + get_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + ) + yield get_artifact_mock + + +@pytest.fixture +def get_artifact_wrong_schema_mock(): + with patch.object( + MetadataServiceClient, "get_artifact" + ) as get_artifact_wrong_schema_mock: + get_artifact_wrong_schema_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=_TEST_WRONG_SCHEMA_TITLE, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + ) + yield get_artifact_wrong_schema_mock + + +@pytest.fixture +def update_artifact_mock(): + with patch.object(MetadataServiceClient, "update_artifact") as update_artifact_mock: + update_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + metadata=_TEST_METRICS, + ) + yield update_artifact_mock + + +def _assert_frame_equal_with_sorted_columns(dataframe_1, dataframe_2): + try: + import pandas as pd + except ImportError: + raise ImportError( + "Pandas is not installed and is required to test the get_experiment_df/pipeline_df method. 
" + 'Please install the SDK using "pip install python-aiplatform[full]"' + ) + + pd.testing.assert_frame_equal( + dataframe_1.sort_index(axis=1), dataframe_2.sort_index(axis=1), check_names=True + ) + + +class TestMetadata: + def setup_method(self): + reload(initializer) + reload(metadata) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_experiment_with_existing_metadataStore_and_context( + self, get_metadata_store_mock, get_context_mock + ): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + + get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + def test_init_experiment_with_existing_description( + self, get_metadata_store_mock, get_context_mock + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + experiment=_TEST_EXPERIMENT, + experiment_description=_TEST_EXPERIMENT_DESCRIPTION, + ) + + get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + def test_init_experiment_without_existing_description(self, update_context_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + experiment=_TEST_EXPERIMENT, + experiment_description=_TEST_OTHER_EXPERIMENT_DESCRIPTION, + ) + + experiment_context = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + description=_TEST_OTHER_EXPERIMENT_DESCRIPTION, + schema_title=constants.SYSTEM_EXPERIMENT, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], + metadata=constants.EXPERIMENT_METADATA, + ) + + update_context_mock.assert_called_once_with(context=experiment_context) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_wrong_schema_mock") + def test_init_experiment_wrong_schema(self): + with pytest.raises(ValueError): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + experiment=_TEST_EXPERIMENT, + ) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("add_execution_events_mock") + def test_init_experiment_reset(self): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.start_run(_TEST_RUN) + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + assert metadata.metadata_service.experiment_name == _TEST_EXPERIMENT + assert metadata.metadata_service.run_name == _TEST_RUN + + aiplatform.init(project=_TEST_OTHER_PROJECT, location=_TEST_LOCATION) + + assert metadata.metadata_service.experiment_name is None + assert metadata.metadata_service.run_name is None + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + def test_start_run_with_existing_execution_and_artifact( + self, + get_execution_mock, + add_context_artifacts_and_executions_mock, + get_artifact_mock, + add_execution_events_mock, + ): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.start_run(_TEST_RUN) + + 
get_execution_mock.assert_called_once_with(name=_TEST_EXECUTION_NAME) + add_context_artifacts_and_executions_mock.assert_called_once_with( + context=_TEST_CONTEXT_NAME, + artifacts=None, + executions=[_TEST_EXECUTION_NAME], + ) + get_artifact_mock.assert_called_once_with(name=_TEST_ARTIFACT_NAME) + add_execution_events_mock.assert_called_once_with( + execution=_TEST_EXECUTION_NAME, + events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)], + ) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_wrong_schema_mock") + def test_start_run_with_wrong_run_execution_schema(self,): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + with pytest.raises(ValueError): + aiplatform.start_run(_TEST_RUN) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_wrong_schema_mock") + def test_start_run_with_wrong_metrics_artifact_schema(self,): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + with pytest.raises(ValueError): + aiplatform.start_run(_TEST_RUN) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("add_execution_events_mock") + def test_log_params( + self, update_execution_mock, + ): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.start_run(_TEST_RUN) + aiplatform.log_params(_TEST_PARAMS) + + updated_execution = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_PARAMS, + ) + + update_execution_mock.assert_called_once_with(execution=updated_execution) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("add_execution_events_mock") + def test_log_metrics( + self, update_artifact_mock, + ): + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.start_run(_TEST_RUN) + aiplatform.log_metrics(_TEST_METRICS) + + updated_artifact = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_ARTIFACT_ID, + schema_title=constants.SYSTEM_METRICS, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS], + metadata=_TEST_METRICS, + ) + + update_artifact_mock.assert_called_once_with(artifact=updated_artifact) + + @pytest.mark.usefixtures("get_metadata_store_mock") + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("add_context_artifacts_and_executions_mock") + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("add_execution_events_mock") + def test_log_metrics_string_value_raise_error(self): + aiplatform.init( + project=_TEST_PROJECT, 
location=_TEST_LOCATION, experiment=_TEST_EXPERIMENT + ) + aiplatform.start_run(_TEST_RUN) + with pytest.raises(TypeError): + aiplatform.log_metrics({"test": "string"}) + + @pytest.mark.usefixtures("get_context_mock") + def test_get_experiment_df( + self, list_executions_mock, query_execution_inputs_and_outputs_mock + ): + try: + import pandas as pd + except ImportError: + raise ImportError( + "Pandas is not installed and is required to test the get_experiment_df method. " + 'Please install the SDK using "pip install python-aiplatform[full]"' + ) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + experiment_df = aiplatform.get_experiment_df(_TEST_EXPERIMENT) + + expected_filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{_TEST_CONTEXT_NAME}")' + list_executions_mock.assert_called_once_with( + request=ListExecutionsRequest(parent=_TEST_PARENT, filter=expected_filter,) + ) + query_execution_inputs_and_outputs_mock.assert_has_calls( + [ + call(execution=_TEST_EXECUTION_NAME), + call(execution=_TEST_OTHER_EXECUTION_NAME), + ] + ) + experiment_df_truth = pd.DataFrame( + [ + { + "experiment_name": _TEST_EXPERIMENT, + "run_name": _TEST_RUN, + "param.%s" % _TEST_PARAM_KEY_1: 0.01, + "param.%s" % _TEST_PARAM_KEY_2: 0.2, + "metric.%s" % _TEST_METRIC_KEY_1: 222, + "metric.%s" % _TEST_METRIC_KEY_2: 1, + }, + { + "experiment_name": _TEST_EXPERIMENT, + "run_name": _TEST_OTHER_RUN, + "param.%s" % _TEST_PARAM_KEY_1: 0.02, + "param.%s" % _TEST_PARAM_KEY_2: 0.3, + "metric.%s" % _TEST_METRIC_KEY_2: 0.9, + }, + ] + ) + + _assert_frame_equal_with_sorted_columns(experiment_df, experiment_df_truth) + + @pytest.mark.usefixtures("get_context_not_found_mock") + def test_get_experiment_df_not_exist(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(exceptions.NotFound): + aiplatform.get_experiment_df(_TEST_EXPERIMENT) + + @pytest.mark.usefixtures("get_pipeline_context_mock") + def test_get_experiment_df_wrong_schema(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(ValueError): + aiplatform.get_experiment_df(_TEST_EXPERIMENT) + + @pytest.mark.usefixtures("get_pipeline_context_mock") + def test_get_pipeline_df( + self, list_executions_mock, query_execution_inputs_and_outputs_mock + ): + try: + import pandas as pd + except ImportError: + raise ImportError( + "Pandas is not installed and is required to test the get_pipeline_df method. 
" + 'Please install the SDK using "pip install python-aiplatform[full]"' + ) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + pipeline_df = aiplatform.get_pipeline_df(_TEST_PIPELINE) + + expected_filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{_TEST_CONTEXT_NAME}")' + list_executions_mock.assert_called_once_with( + request=ListExecutionsRequest(parent=_TEST_PARENT, filter=expected_filter,) + ) + query_execution_inputs_and_outputs_mock.assert_has_calls( + [ + call(execution=_TEST_EXECUTION_NAME), + call(execution=_TEST_OTHER_EXECUTION_NAME), + ] + ) + pipeline_df_truth = pd.DataFrame( + [ + { + "pipeline_name": _TEST_PIPELINE, + "run_name": _TEST_RUN, + "param.%s" % _TEST_PARAM_KEY_1: 0.01, + "param.%s" % _TEST_PARAM_KEY_2: 0.2, + "metric.%s" % _TEST_METRIC_KEY_1: 222, + "metric.%s" % _TEST_METRIC_KEY_2: 1, + }, + { + "pipeline_name": _TEST_PIPELINE, + "run_name": _TEST_OTHER_RUN, + "param.%s" % _TEST_PARAM_KEY_1: 0.02, + "param.%s" % _TEST_PARAM_KEY_2: 0.3, + "metric.%s" % _TEST_METRIC_KEY_2: 0.9, + }, + ] + ) + + _assert_frame_equal_with_sorted_columns(pipeline_df, pipeline_df_truth) + + @pytest.mark.usefixtures("get_context_not_found_mock") + def test_get_pipeline_df_not_exist(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(exceptions.NotFound): + aiplatform.get_pipeline_df(_TEST_PIPELINE) + + @pytest.mark.usefixtures("get_context_mock") + def test_get_pipeline_df_wrong_schema(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(ValueError): + aiplatform.get_pipeline_df(_TEST_PIPELINE) diff --git a/tests/unit/aiplatform/test_metadata_resources.py b/tests/unit/aiplatform/test_metadata_resources.py new file mode 100644 index 0000000000..19258aef3c --- /dev/null +++ b/tests/unit/aiplatform/test_metadata_resources.py @@ -0,0 +1,801 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from importlib import reload +from unittest.mock import patch + +import pytest +from google.api_core import exceptions + +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.metadata import artifact +from google.cloud.aiplatform.metadata import context +from google.cloud.aiplatform.metadata import execution +from google.cloud.aiplatform_v1beta1 import AddContextArtifactsAndExecutionsResponse +from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact +from google.cloud.aiplatform_v1beta1 import Context as GapicContext +from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution +from google.cloud.aiplatform_v1beta1 import LineageSubgraph +from google.cloud.aiplatform_v1beta1 import ( + MetadataServiceClient, + AddExecutionEventsResponse, + Event, + ListExecutionsRequest, + ListArtifactsRequest, + ListContextsRequest, +) + +# project +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_METADATA_STORE = "test-metadata-store" +_TEST_ALT_LOCATION = "europe-west4" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/{_TEST_METADATA_STORE}" + +# resource attributes +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_SCHEMA_TITLE = "test.Example" +_TEST_SCHEMA_VERSION = "0.0.1" +_TEST_DESCRIPTION = "test description" +_TEST_METADATA = {"test-param1": 1, "test-param2": "test-value", "test-param3": True} +_TEST_UPDATED_METADATA = { + "test-param1": 2, + "test-param2": "test-value-1", + "test-param3": False, +} + +# context +_TEST_CONTEXT_ID = "test-context-id" +_TEST_CONTEXT_NAME = f"{_TEST_PARENT}/contexts/{_TEST_CONTEXT_ID}" + +# artifact +_TEST_ARTIFACT_ID = "test-artifact-id" +_TEST_ARTIFACT_NAME = f"{_TEST_PARENT}/artifacts/{_TEST_ARTIFACT_ID}" + +# execution +_TEST_EXECUTION_ID = "test-execution-id" +_TEST_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_EXECUTION_ID}" + + +@pytest.fixture +def get_context_mock(): + with patch.object(MetadataServiceClient, "get_context") as get_context_mock: + get_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield get_context_mock + + +@pytest.fixture +def get_context_for_get_or_create_mock(): + with patch.object( + MetadataServiceClient, "get_context" + ) as get_context_for_get_or_create_mock: + get_context_for_get_or_create_mock.side_effect = exceptions.NotFound( + "test: Context Not Found" + ) + yield get_context_for_get_or_create_mock + + +@pytest.fixture +def create_context_mock(): + with patch.object(MetadataServiceClient, "create_context") as create_context_mock: + create_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield create_context_mock + + +@pytest.fixture +def list_contexts_mock(): + with patch.object(MetadataServiceClient, "list_contexts") as list_contexts_mock: + list_contexts_mock.return_value = [ + GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + 
schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield list_contexts_mock + + +@pytest.fixture +def update_context_mock(): + with patch.object(MetadataServiceClient, "update_context") as update_context_mock: + update_context_mock.return_value = GapicContext( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + yield update_context_mock + + +@pytest.fixture +def add_context_artifacts_and_executions_mock(): + with patch.object( + MetadataServiceClient, "add_context_artifacts_and_executions" + ) as add_context_artifacts_and_executions_mock: + add_context_artifacts_and_executions_mock.return_value = ( + AddContextArtifactsAndExecutionsResponse() + ) + yield add_context_artifacts_and_executions_mock + + +@pytest.fixture +def get_execution_mock(): + with patch.object(MetadataServiceClient, "get_execution") as get_execution_mock: + get_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield get_execution_mock + + +@pytest.fixture +def get_execution_for_get_or_create_mock(): + with patch.object( + MetadataServiceClient, "get_execution" + ) as get_execution_for_get_or_create_mock: + get_execution_for_get_or_create_mock.side_effect = exceptions.NotFound( + "test: Execution Not Found" + ) + yield get_execution_for_get_or_create_mock + + +@pytest.fixture +def create_execution_mock(): + with patch.object( + MetadataServiceClient, "create_execution" + ) as create_execution_mock: + create_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield create_execution_mock + + +@pytest.fixture +def list_executions_mock(): + with patch.object(MetadataServiceClient, "list_executions") as list_executions_mock: + list_executions_mock.return_value = [ + GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield list_executions_mock + + +@pytest.fixture +def query_execution_inputs_and_outputs_mock(): + with patch.object( + MetadataServiceClient, "query_execution_inputs_and_outputs" + ) as query_execution_inputs_and_outputs_mock: + query_execution_inputs_and_outputs_mock.return_value = LineageSubgraph( + artifacts=[ + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ], + ) + yield 
query_execution_inputs_and_outputs_mock + + +@pytest.fixture +def update_execution_mock(): + with patch.object( + MetadataServiceClient, "update_execution" + ) as update_execution_mock: + update_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + yield update_execution_mock + + +@pytest.fixture +def add_execution_events_mock(): + with patch.object( + MetadataServiceClient, "add_execution_events" + ) as add_execution_events_mock: + add_execution_events_mock.return_value = AddExecutionEventsResponse() + yield add_execution_events_mock + + +@pytest.fixture +def get_artifact_mock(): + with patch.object(MetadataServiceClient, "get_artifact") as get_artifact_mock: + get_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield get_artifact_mock + + +@pytest.fixture +def get_artifact_for_get_or_create_mock(): + with patch.object( + MetadataServiceClient, "get_artifact" + ) as get_artifact_for_get_or_create_mock: + get_artifact_for_get_or_create_mock.side_effect = exceptions.NotFound( + "test: Artifact Not Found" + ) + yield get_artifact_for_get_or_create_mock + + +@pytest.fixture +def create_artifact_mock(): + with patch.object(MetadataServiceClient, "create_artifact") as create_artifact_mock: + create_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield create_artifact_mock + + +@pytest.fixture +def list_artifacts_mock(): + with patch.object(MetadataServiceClient, "list_artifacts") as list_artifacts_mock: + list_artifacts_mock.return_value = [ + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ), + ] + yield list_artifacts_mock + + +@pytest.fixture +def update_artifact_mock(): + with patch.object(MetadataServiceClient, "update_artifact") as update_artifact_mock: + update_artifact_mock.return_value = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + yield update_artifact_mock + + +class TestContext: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_context(self, get_context_mock): + aiplatform.init(project=_TEST_PROJECT) + context._Context(resource_name=_TEST_CONTEXT_NAME) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + def test_init_context_with_id(self, get_context_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + context._Context( + resource_name=_TEST_CONTEXT_ID, 
metadata_store_id=_TEST_METADATA_STORE + ) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + def test_get_or_create_context( + self, get_context_for_get_or_create_mock, create_context_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_context = context._Context.get_or_create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_context = GapicContext( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + get_context_for_get_or_create_mock.assert_called_once_with( + name=_TEST_CONTEXT_NAME + ) + create_context_mock.assert_called_once_with( + parent=_TEST_PARENT, context_id=_TEST_CONTEXT_ID, context=expected_context, + ) + + expected_context.name = _TEST_CONTEXT_NAME + assert my_context._gca_resource == expected_context + + @pytest.mark.usefixtures("get_context_mock") + @pytest.mark.usefixtures("create_context_mock") + def test_update_context(self, update_context_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_context = context._Context._create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_context.update(_TEST_UPDATED_METADATA) + + updated_context = GapicContext( + name=_TEST_CONTEXT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + + update_context_mock.assert_called_once_with(context=updated_context) + assert my_context._gca_resource == updated_context + + @pytest.mark.usefixtures("get_context_mock") + def test_list_contexts(self, list_contexts_mock): + aiplatform.init(project=_TEST_PROJECT) + + filter = "test-filter" + context_list = context._Context.list( + filter=filter, metadata_store_id=_TEST_METADATA_STORE + ) + + expected_context = GapicContext( + name=_TEST_CONTEXT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + list_contexts_mock.assert_called_once_with( + request=ListContextsRequest(parent=_TEST_PARENT, filter=filter,) + ) + assert len(context_list) == 2 + assert context_list[0]._gca_resource == expected_context + assert context_list[1]._gca_resource == expected_context + + @pytest.mark.usefixtures("get_context_mock") + def test_add_artifacts_and_executions( + self, add_context_artifacts_and_executions_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + my_context = context._Context.get_or_create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + my_context.add_artifacts_and_executions( + artifact_resource_names=[_TEST_ARTIFACT_NAME], + execution_resource_names=[_TEST_EXECUTION_NAME], + ) + add_context_artifacts_and_executions_mock.assert_called_once_with( + context=_TEST_CONTEXT_NAME, + artifacts=[_TEST_ARTIFACT_NAME], + 
executions=[_TEST_EXECUTION_NAME], + ) + + @pytest.mark.usefixtures("get_context_mock") + def test_add_artifacts_only(self, add_context_artifacts_and_executions_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_context = context._Context.get_or_create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + my_context.add_artifacts_and_executions( + artifact_resource_names=[_TEST_ARTIFACT_NAME] + ) + add_context_artifacts_and_executions_mock.assert_called_once_with( + context=_TEST_CONTEXT_NAME, + artifacts=[_TEST_ARTIFACT_NAME], + executions=None, + ) + + @pytest.mark.usefixtures("get_context_mock") + def test_add_executions_only(self, add_context_artifacts_and_executions_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_context = context._Context.get_or_create( + resource_id=_TEST_CONTEXT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + my_context.add_artifacts_and_executions( + execution_resource_names=[_TEST_EXECUTION_NAME] + ) + add_context_artifacts_and_executions_mock.assert_called_once_with( + context=_TEST_CONTEXT_NAME, + artifacts=None, + executions=[_TEST_EXECUTION_NAME], + ) + + +class TestExecution: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_execution(self, get_execution_mock): + aiplatform.init(project=_TEST_PROJECT) + execution._Execution(resource_name=_TEST_EXECUTION_NAME) + get_execution_mock.assert_called_once_with(name=_TEST_EXECUTION_NAME) + + def test_init_execution_with_id(self, get_execution_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + execution._Execution( + resource_name=_TEST_EXECUTION_ID, metadata_store_id=_TEST_METADATA_STORE + ) + get_execution_mock.assert_called_once_with(name=_TEST_EXECUTION_NAME) + + def test_get_or_create_execution( + self, get_execution_for_get_or_create_mock, create_execution_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_execution = execution._Execution.get_or_create( + resource_id=_TEST_EXECUTION_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_execution = GapicExecution( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + get_execution_for_get_or_create_mock.assert_called_once_with( + name=_TEST_EXECUTION_NAME + ) + create_execution_mock.assert_called_once_with( + parent=_TEST_PARENT, + execution_id=_TEST_EXECUTION_ID, + execution=expected_execution, + ) + + expected_execution.name = _TEST_EXECUTION_NAME + assert my_execution._gca_resource == expected_execution + + @pytest.mark.usefixtures("get_execution_mock") + @pytest.mark.usefixtures("create_execution_mock") + def test_update_execution(self, update_execution_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_execution = execution._Execution._create( + resource_id=_TEST_EXECUTION_ID, + 
schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_execution.update(_TEST_UPDATED_METADATA) + + updated_execution = GapicExecution( + name=_TEST_EXECUTION_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + + update_execution_mock.assert_called_once_with(execution=updated_execution) + assert my_execution._gca_resource == updated_execution + + @pytest.mark.usefixtures("get_execution_mock") + def test_list_executions(self, list_executions_mock): + aiplatform.init(project=_TEST_PROJECT) + + filter = "test-filter" + execution_list = execution._Execution.list( + filter=filter, metadata_store_id=_TEST_METADATA_STORE + ) + + expected_execution = GapicExecution( + name=_TEST_EXECUTION_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + list_executions_mock.assert_called_once_with( + request=ListExecutionsRequest(parent=_TEST_PARENT, filter=filter,) + ) + assert len(execution_list) == 2 + assert execution_list[0]._gca_resource == expected_execution + assert execution_list[1]._gca_resource == expected_execution + + @pytest.mark.usefixtures("get_execution_mock") + def test_add_artifact(self, add_execution_events_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + my_execution = execution._Execution.get_or_create( + resource_id=_TEST_EXECUTION_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_execution.add_artifact( + artifact_resource_name=_TEST_ARTIFACT_NAME, input=False, + ) + add_execution_events_mock.assert_called_once_with( + execution=_TEST_EXECUTION_NAME, + events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)], + ) + + @pytest.mark.usefixtures("get_execution_mock") + def test_query_input_and_output_artifacts( + self, query_execution_inputs_and_outputs_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_execution = execution._Execution.get_or_create( + resource_id=_TEST_EXECUTION_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + artifact_list = my_execution.query_input_and_output_artifacts() + + expected_artifact = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + query_execution_inputs_and_outputs_mock.assert_called_once_with( + execution=_TEST_EXECUTION_NAME, + ) + assert len(artifact_list) == 2 + assert artifact_list[0]._gca_resource == expected_artifact + assert artifact_list[1]._gca_resource == expected_artifact + + +class TestArtifact: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_artifact(self, get_artifact_mock): + aiplatform.init(project=_TEST_PROJECT) + 
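# Hypothetical clarifying comment (not in the original diff): _Artifact accepts either a full resource name (used here) or a bare ID plus metadata_store_id (used in the next test); both forms should resolve to the same get_artifact call. +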
artifact._Artifact(resource_name=_TEST_ARTIFACT_NAME) + get_artifact_mock.assert_called_once_with(name=_TEST_ARTIFACT_NAME) + + def test_init_artifact_with_id(self, get_artifact_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + artifact._Artifact( + resource_name=_TEST_ARTIFACT_ID, metadata_store_id=_TEST_METADATA_STORE + ) + get_artifact_mock.assert_called_once_with(name=_TEST_ARTIFACT_NAME) + + def test_get_or_create_artifact( + self, get_artifact_for_get_or_create_mock, create_artifact_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_artifact = artifact._Artifact.get_or_create( + resource_id=_TEST_ARTIFACT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + + expected_artifact = GapicArtifact( + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + get_artifact_for_get_or_create_mock.assert_called_once_with( + name=_TEST_ARTIFACT_NAME + ) + create_artifact_mock.assert_called_once_with( + parent=_TEST_PARENT, + artifact_id=_TEST_ARTIFACT_ID, + artifact=expected_artifact, + ) + + expected_artifact.name = _TEST_ARTIFACT_NAME + assert my_artifact._gca_resource == expected_artifact + + @pytest.mark.usefixtures("get_artifact_mock") + @pytest.mark.usefixtures("create_artifact_mock") + def test_update_artifact(self, update_artifact_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_artifact = artifact._Artifact._create( + resource_id=_TEST_ARTIFACT_ID, + schema_title=_TEST_SCHEMA_TITLE, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + metadata_store_id=_TEST_METADATA_STORE, + ) + my_artifact.update(_TEST_UPDATED_METADATA) + + updated_artifact = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + + update_artifact_mock.assert_called_once_with(artifact=updated_artifact) + assert my_artifact._gca_resource == updated_artifact + + @pytest.mark.usefixtures("get_artifact_mock") + def test_list_artifacts(self, list_artifacts_mock): + aiplatform.init(project=_TEST_PROJECT) + + filter = "test-filter" + artifact_list = artifact._Artifact.list( + filter=filter, metadata_store_id=_TEST_METADATA_STORE + ) + + expected_artifact = GapicArtifact( + name=_TEST_ARTIFACT_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + + list_artifacts_mock.assert_called_once_with( + request=ListArtifactsRequest(parent=_TEST_PARENT, filter=filter,) + ) + assert len(artifact_list) == 2 + assert artifact_list[0]._gca_resource == expected_artifact + assert artifact_list[1]._gca_resource == expected_artifact diff --git a/tests/unit/aiplatform/test_metadata_store.py b/tests/unit/aiplatform/test_metadata_store.py new file mode 100644 index 0000000000..516e61d849 --- /dev/null +++ b/tests/unit/aiplatform/test_metadata_store.py @@ -0,0 +1,227 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from importlib import reload +from unittest import mock +from unittest.mock import patch + +import pytest +from google.api_core import operation +from google.auth import credentials as auth_credentials +from google.auth.exceptions import GoogleAuthError + +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.metadata import metadata_store +from google.cloud.aiplatform_v1beta1 import MetadataServiceClient +from google.cloud.aiplatform_v1beta1 import MetadataStore as GapicMetadataStore +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1beta1.types import metadata_service + +# project +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_ALT_LOCATION = "europe-west4" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + +# metadata_store +_TEST_ID = "test-id" +_TEST_DEFAULT_ID = "default" + +_TEST_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/{_TEST_ID}" +) +_TEST_ALT_LOC_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_ALT_LOCATION}/metadataStores/{_TEST_ID}" +) +_TEST_DEFAULT_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/{_TEST_DEFAULT_ID}" + +_TEST_INVALID_NAME = f"prj/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/{_TEST_ID}" + +# CMEK encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def get_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = GapicMetadataStore( + name=_TEST_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def get_default_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = GapicMetadataStore( + name=_TEST_DEFAULT_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def get_metadata_store_without_name_mock(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = GapicMetadataStore( + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def create_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "create_metadata_store" + ) as create_metadata_store_mock: + create_metadata_store_lro_mock = mock.Mock(operation.Operation) + create_metadata_store_lro_mock.result.return_value = GapicMetadataStore( + name=_TEST_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_metadata_store_mock.return_value = create_metadata_store_lro_mock + yield create_metadata_store_mock + + +@pytest.fixture +def create_default_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "create_metadata_store" + ) as 
create_metadata_store_mock: + create_metadata_store_lro_mock = mock.Mock(operation.Operation) + create_metadata_store_lro_mock.result.return_value = GapicMetadataStore( + name=_TEST_DEFAULT_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_metadata_store_mock.return_value = create_metadata_store_lro_mock + yield create_metadata_store_mock + + +@pytest.fixture +def delete_metadata_store_mock(): + with mock.patch.object( + MetadataServiceClient, "delete_metadata_store" + ) as delete_metadata_store_mock: + delete_metadata_store_lro_mock = mock.Mock(operation.Operation) + delete_metadata_store_lro_mock.result.return_value = ( + metadata_service.DeleteMetadataStoreRequest() + ) + delete_metadata_store_mock.return_value = delete_metadata_store_lro_mock + yield delete_metadata_store_mock + + +class TestMetadataStore: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_metadata_store(self, get_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT) + metadata_store._MetadataStore(metadata_store_name=_TEST_NAME) + get_metadata_store_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_metadata_store_with_id(self, get_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + metadata_store._MetadataStore(metadata_store_name=_TEST_ID) + get_metadata_store_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_metadata_store_with_default_id(self, get_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + metadata_store._MetadataStore() + get_metadata_store_mock.assert_called_once_with(name=_TEST_DEFAULT_NAME) + + @pytest.mark.usefixtures("get_metadata_store_without_name_mock") + @patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "", "GOOGLE_APPLICATION_CREDENTIALS": ""} + ) + def test_init_metadata_store_with_id_without_project_or_location(self): + with pytest.raises(GoogleAuthError): + metadata_store._MetadataStore( + metadata_store_name=_TEST_ID, + credentials=auth_credentials.AnonymousCredentials(), + ) + + def test_init_metadata_store_with_location_override(self, get_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + metadata_store._MetadataStore( + metadata_store_name=_TEST_ID, location=_TEST_ALT_LOCATION + ) + get_metadata_store_mock.assert_called_once_with(name=_TEST_ALT_LOC_NAME) + + @pytest.mark.usefixtures("get_metadata_store_mock") + def test_init_metadata_store_with_invalid_name(self): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + metadata_store._MetadataStore(metadata_store_name=_TEST_INVALID_NAME) + + @pytest.mark.usefixtures("get_default_metadata_store_mock") + def test_init_aiplatform_with_encryption_key_name_and_create_default_metadata_store( + self, create_default_metadata_store_mock + ): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_metadata_store = metadata_store._MetadataStore._create( + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + expected_metadata_store = GapicMetadataStore( + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_default_metadata_store_mock.assert_called_once_with( + parent=_TEST_PARENT, + metadata_store_id=_TEST_DEFAULT_ID, + metadata_store=expected_metadata_store, + ) + + expected_metadata_store.name = _TEST_DEFAULT_NAME + assert my_metadata_store._gca_resource == 
expected_metadata_store + + @pytest.mark.usefixtures("get_metadata_store_mock") + def test_create_non_default_metadata_store(self, create_metadata_store_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_metadata_store = metadata_store._MetadataStore._create( + metadata_store_id=_TEST_ID, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + expected_metadata_store = GapicMetadataStore( + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_metadata_store_mock.assert_called_once_with( + parent=_TEST_PARENT, + metadata_store_id=_TEST_ID, + metadata_store=expected_metadata_store, + ) + + expected_metadata_store.name = _TEST_NAME + assert my_metadata_store._gca_resource == expected_metadata_store diff --git a/tests/unit/aiplatform/test_uploader.py b/tests/unit/aiplatform/test_uploader.py new file mode 100644 index 0000000000..c63f729fd3 --- /dev/null +++ b/tests/unit/aiplatform/test_uploader.py @@ -0,0 +1,1454 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for uploader.py.""" + +import logging +import os +import re +from unittest import mock + +import grpc +import grpc_testing +from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto import graph_pb2 +from tensorboard.compat.proto import meta_graph_pb2 +from tensorboard.compat.proto import summary_pb2 +from tensorboard.compat.proto import tensor_pb2 +from tensorboard.compat.proto import types_pb2 +from tensorboard.plugins.scalar import metadata as scalars_metadata +from tensorboard.plugins.graph import metadata as graphs_metadata +from tensorboard.summary import v1 as summary_v1 +from tensorboard.uploader import logdir_loader +from tensorboard.uploader import upload_tracker +from tensorboard.uploader import util +from tensorboard.uploader.proto import server_info_pb2 +import tensorflow as tf + +from google.api_core import datetime_helpers +import google.cloud.aiplatform.tensorboard.uploader as uploader_lib +from google.cloud import storage +from google.cloud.aiplatform.compat.services import tensorboard_service_client_v1beta1 +from google.cloud.aiplatform_v1beta1.services.tensorboard_service.transports import ( + grpc as transports_grpc, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_data_v1beta1 as tensorboard_data, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_experiment_v1beta1 as tensorboard_experiment_type, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_run_v1beta1 as tensorboard_run_type, +) +from google.cloud.aiplatform.compat.types import ( + tensorboard_time_series_v1beta1 as tensorboard_time_series_type, +) +from google.protobuf import timestamp_pb2 +from google.protobuf import message + +data_compat = uploader_lib.event_file_loader.data_compat +dataclass_compat = uploader_lib.event_file_loader.dataclass_compat +scalar_v2_pb = summary_v1._scalar_summary.scalar_pb +image_pb = summary_v1._image_summary.pb + +_SCALARS_HISTOGRAMS_AND_GRAPHS = frozenset( + 
(scalars_metadata.PLUGIN_NAME, graphs_metadata.PLUGIN_NAME,) +) + +# Sentinel for `_create_*` helpers, for arguments for which we want to +# supply a default other than the `None` used by the code under test. +_USE_DEFAULT = object() + +_TEST_EXPERIMENT_NAME = "test-experiment" +_TEST_TENSORBOARD_RESOURCE_NAME = ( + "projects/test_project/locations/us-central1/tensorboards/test_tensorboard" +) +_TEST_LOG_DIR_NAME = "/logs/foo" +_TEST_RUN_NAME = "test-run" +_TEST_ONE_PLATFORM_EXPERIMENT_NAME = "{}/experiments/{}".format( + _TEST_TENSORBOARD_RESOURCE_NAME, _TEST_EXPERIMENT_NAME +) +_TEST_ONE_PLATFORM_RUN_NAME = "{}/runs/{}".format( + _TEST_ONE_PLATFORM_EXPERIMENT_NAME, _TEST_RUN_NAME +) +_TEST_TIME_SERIES_NAME = "test-time-series" +_TEST_ONE_PLATFORM_TIME_SERIES_NAME = "{}/timeSeries/{}".format( + _TEST_ONE_PLATFORM_RUN_NAME, _TEST_TIME_SERIES_NAME +) +_TEST_BLOB_STORAGE_FOLDER = "test_folder" + + +def _create_example_graph_bytes(large_attr_size): + graph_def = graph_pb2.GraphDef() + graph_def.node.add(name="alice", op="Person") + graph_def.node.add(name="bob", op="Person") + + graph_def.node[1].attr["small"].s = b"small_attr_value" + graph_def.node[1].attr["large"].s = b"l" * large_attr_size + graph_def.node.add(name="friendship", op="Friendship", input=["alice", "bob"]) + return graph_def.SerializeToString() + + +class AbortUploadError(Exception): + """Exception used in testing to abort the upload process.""" + + +def _create_mock_client(): + # Create a stub instance (using a test channel) in order to derive a mock + # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself + # doesn't work with autospec because grpc constructs stubs via metaclassing. + def create_experiment_response( + tensorboard_experiment_id=None, + tensorboard_experiment=None, # pylint: disable=unused-argument + parent=None, + ): # pylint: disable=unused-argument + return tensorboard_experiment_type.TensorboardExperiment( + name=tensorboard_experiment_id + ) + + def create_run_response( + tensorboard_run=None, # pylint: disable=unused-argument + tensorboard_run_id=None, + parent=None, + ): # pylint: disable=unused-argument + return tensorboard_run_type.TensorboardRun(name=tensorboard_run_id) + + def create_tensorboard_time_series( + tensorboard_time_series=None, parent=None + ): # pylint: disable=unused-argument + return tensorboard_time_series_type.TensorboardTimeSeries( + name=tensorboard_time_series.display_name, + display_name=tensorboard_time_series.display_name, + ) + + test_channel = grpc_testing.channel( + service_descriptors=[], time=grpc_testing.strict_real_time() + ) + mock_client = mock.Mock( + spec=tensorboard_service_client_v1beta1.TensorboardServiceClient( + transport=transports_grpc.TensorboardServiceGrpcTransport( + channel=test_channel + ) + ) + ) + mock_client.create_tensorboard_experiment.side_effect = create_experiment_response + mock_client.create_tensorboard_run.side_effect = create_run_response + mock_client.create_tensorboard_time_series.side_effect = ( + create_tensorboard_time_series + ) + return mock_client + + +def _create_uploader( + writer_client=_USE_DEFAULT, + logdir=None, + max_scalar_request_size=_USE_DEFAULT, + max_tensor_request_size=_USE_DEFAULT, + max_tensor_point_size=_USE_DEFAULT, + max_blob_request_size=_USE_DEFAULT, + max_blob_size=_USE_DEFAULT, + logdir_poll_rate_limiter=_USE_DEFAULT, + rpc_rate_limiter=_USE_DEFAULT, + experiment_name=_TEST_EXPERIMENT_NAME, + tensorboard_resource_name=_TEST_TENSORBOARD_RESOURCE_NAME, + blob_storage_bucket=None, + 
blob_storage_folder=_TEST_BLOB_STORAGE_FOLDER, + description=None, + verbosity=0, # Use 0 to minimize littering the test output. + one_shot=None, +): + if writer_client is _USE_DEFAULT: + writer_client = _create_mock_client() + if max_scalar_request_size is _USE_DEFAULT: + max_scalar_request_size = 128000 + if max_tensor_request_size is _USE_DEFAULT: + max_tensor_request_size = 512000 + if max_blob_request_size is _USE_DEFAULT: + max_blob_request_size = 128000 + if max_blob_size is _USE_DEFAULT: + max_blob_size = 12345 + if max_tensor_point_size is _USE_DEFAULT: + max_tensor_point_size = 16000 + if logdir_poll_rate_limiter is _USE_DEFAULT: + logdir_poll_rate_limiter = util.RateLimiter(0) + if rpc_rate_limiter is _USE_DEFAULT: + rpc_rate_limiter = util.RateLimiter(0) + + upload_limits = server_info_pb2.UploadLimits( + max_scalar_request_size=max_scalar_request_size, + max_tensor_request_size=max_tensor_request_size, + max_tensor_point_size=max_tensor_point_size, + max_blob_request_size=max_blob_request_size, + max_blob_size=max_blob_size, + ) + + return uploader_lib.TensorBoardUploader( + experiment_name=experiment_name, + tensorboard_resource_name=tensorboard_resource_name, + writer_client=writer_client, + logdir=logdir, + allowed_plugins=_SCALARS_HISTOGRAMS_AND_GRAPHS, + upload_limits=upload_limits, + blob_storage_bucket=blob_storage_bucket, + blob_storage_folder=blob_storage_folder, + logdir_poll_rate_limiter=logdir_poll_rate_limiter, + rpc_rate_limiter=rpc_rate_limiter, + description=description, + verbosity=verbosity, + one_shot=one_shot, + ) + + +def _create_request_sender( + experiment_resource_name, api=None, allowed_plugins=_USE_DEFAULT +): + if api is _USE_DEFAULT: + api = _create_mock_client() + if allowed_plugins is _USE_DEFAULT: + allowed_plugins = _SCALARS_HISTOGRAMS_AND_GRAPHS + + upload_limits = server_info_pb2.UploadLimits( + max_scalar_request_size=128000, + max_tensor_request_size=128000, + max_tensor_point_size=52000, + ) + + rpc_rate_limiter = util.RateLimiter(0) + tensor_rpc_rate_limiter = util.RateLimiter(0) + blob_rpc_rate_limiter = util.RateLimiter(0) + + return uploader_lib._BatchedRequestSender( + experiment_resource_name=experiment_resource_name, + api=api, + allowed_plugins=allowed_plugins, + upload_limits=upload_limits, + rpc_rate_limiter=rpc_rate_limiter, + tensor_rpc_rate_limiter=tensor_rpc_rate_limiter, + blob_rpc_rate_limiter=blob_rpc_rate_limiter, + blob_storage_bucket=None, + blob_storage_folder=None, + tracker=upload_tracker.UploadTracker(verbosity=0), + ) + + +def _create_scalar_request_sender( + run_resource_id, api=_USE_DEFAULT, max_request_size=_USE_DEFAULT +): + if api is _USE_DEFAULT: + api = _create_mock_client() + if max_request_size is _USE_DEFAULT: + max_request_size = 128000 + return uploader_lib._ScalarBatchedRequestSender( + run_resource_id=run_resource_id, + api=api, + rpc_rate_limiter=util.RateLimiter(0), + max_request_size=max_request_size, + tracker=upload_tracker.UploadTracker(verbosity=0), + ) + + +def _scalar_event(tag, value): + return event_pb2.Event(summary=scalar_v2_pb(tag, value)) + + +def _grpc_error(code, details): + # Monkey patch insertion for the methods a real grpc.RpcError would have. + error = grpc.RpcError("RPC error %r: %s" % (code, details)) + error.code = lambda: code + error.details = lambda: details + return error + + +def _timestamp_pb(nanos): + result = timestamp_pb2.Timestamp() + result.FromNanoseconds(nanos) + return result + + +class FileWriter(tf.compat.v1.summary.FileWriter): + """FileWriter for test. 
+ + TensorFlow FileWriter uses TensorFlow's Protobuf Python binding + which is largely discouraged in TensorBoard. We do not want a + TB.Writer but require one for testing in integrational style + (writing out event files and use the real event readers). + """ + + def __init__(self, *args, **kwargs): + # Briefly enter graph mode context so this testing FileWriter can be + # created from an eager mode context without triggering a usage error. + with tf.compat.v1.Graph().as_default(): + super(FileWriter, self).__init__(*args, **kwargs) + + def add_test_summary(self, tag, simple_value=1.0, step=None): + """Convenience for writing a simple summary for a given tag.""" + value = summary_pb2.Summary.Value(tag=tag, simple_value=simple_value) + summary = summary_pb2.Summary(value=[value]) + self.add_summary(summary, global_step=step) + + def add_test_tensor_summary(self, tag, tensor, step=None, value_metadata=None): + """Convenience for writing a simple summary for a given tag.""" + value = summary_pb2.Summary.Value( + tag=tag, tensor=tensor, metadata=value_metadata + ) + summary = summary_pb2.Summary(value=[value]) + self.add_summary(summary, global_step=step) + + def add_event(self, event): + if isinstance(event, event_pb2.Event): + tf_event = tf.compat.v1.Event.FromString(event.SerializeToString()) + else: + tf_event = event + if not isinstance(event, bytes): + logging.error( + "Added TensorFlow event proto. " + "Please prefer TensorBoard copy of the proto" + ) + super(FileWriter, self).add_event(tf_event) + + def add_summary(self, summary, global_step=None): + if isinstance(summary, summary_pb2.Summary): + tf_summary = tf.compat.v1.Summary.FromString(summary.SerializeToString()) + else: + tf_summary = summary + if not isinstance(summary, bytes): + logging.error( + "Added TensorFlow summary proto. " + "Please prefer TensorBoard copy of the proto" + ) + super(FileWriter, self).add_summary(tf_summary, global_step) + + def add_session_log(self, session_log, global_step=None): + if isinstance(session_log, event_pb2.SessionLog): + tf_session_log = tf.compat.v1.SessionLog.FromString( + session_log.SerializeToString() + ) + else: + tf_session_log = session_log + if not isinstance(session_log, bytes): + logging.error( + "Added TensorFlow session_log proto. 
" + "Please prefer TensorBoard copy of the proto" + ) + super(FileWriter, self).add_session_log(tf_session_log, global_step) + + def add_graph(self, graph, global_step=None, graph_def=None): + if isinstance(graph_def, graph_pb2.GraphDef): + tf_graph_def = tf.compat.v1.GraphDef.FromString( + graph_def.SerializeToString() + ) + else: + tf_graph_def = graph_def + + super(FileWriter, self).add_graph( + graph, global_step=global_step, graph_def=tf_graph_def + ) + + def add_meta_graph(self, meta_graph_def, global_step=None): + if isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef): + tf_meta_graph_def = tf.compat.v1.MetaGraphDef.FromString( + meta_graph_def.SerializeToString() + ) + else: + tf_meta_graph_def = meta_graph_def + + super(FileWriter, self).add_meta_graph( + meta_graph_def=tf_meta_graph_def, global_step=global_step + ) + + +class TensorboardUploaderTest(tf.test.TestCase): + def test_create_experiment(self): + logdir = _TEST_LOG_DIR_NAME + uploader = _create_uploader(_create_mock_client(), logdir) + uploader.create_experiment() + self.assertEqual(uploader._experiment.name, _TEST_EXPERIMENT_NAME) + + def test_create_experiment_with_name(self): + logdir = _TEST_LOG_DIR_NAME + mock_client = _create_mock_client() + new_name = "This is the new name" + uploader = _create_uploader(mock_client, logdir, experiment_name=new_name) + uploader.create_experiment() + mock_client.create_tensorboard_experiment.assert_called_once() + call_args = mock_client.create_tensorboard_experiment.call_args + self.assertEqual( + call_args[1]["tensorboard_experiment"], + tensorboard_experiment_type.TensorboardExperiment(), + ) + self.assertEqual(call_args[1]["parent"], _TEST_TENSORBOARD_RESOURCE_NAME) + self.assertEqual(call_args[1]["tensorboard_experiment_id"], new_name) + + def test_create_experiment_with_description(self): + logdir = _TEST_LOG_DIR_NAME + mock_client = _create_mock_client() + new_description = """ + **description**" + may have "strange" unicode chars 🌴 \\/<> + """ + uploader = _create_uploader(mock_client, logdir, description=new_description) + uploader.create_experiment() + self.assertEqual(uploader._experiment_name, _TEST_EXPERIMENT_NAME) + mock_client.create_tensorboard_experiment.assert_called_once() + call_args = mock_client.create_tensorboard_experiment.call_args + + tb_experiment = tensorboard_experiment_type.TensorboardExperiment( + description=new_description + ) + + expected_call_args = mock.call( + parent=_TEST_TENSORBOARD_RESOURCE_NAME, + tensorboard_experiment_id=_TEST_EXPERIMENT_NAME, + tensorboard_experiment=tb_experiment, + ) + + self.assertEqual(expected_call_args, call_args) + + def test_create_experiment_with_all_metadata(self): + logdir = _TEST_LOG_DIR_NAME + mock_client = _create_mock_client() + new_description = """ + **description**" + may have "strange" unicode chars 🌴 \\/<> + """ + new_name = "This is a cool name." 
+ uploader = _create_uploader( + mock_client, logdir, experiment_name=new_name, description=new_description + ) + uploader.create_experiment() + self.assertEqual(uploader._experiment_name, new_name) + mock_client.create_tensorboard_experiment.assert_called_once() + call_args = mock_client.create_tensorboard_experiment.call_args + + tb_experiment = tensorboard_experiment_type.TensorboardExperiment( + description=new_description + ) + expected_call_args = mock.call( + parent=_TEST_TENSORBOARD_RESOURCE_NAME, + tensorboard_experiment_id=new_name, + tensorboard_experiment=tb_experiment, + ) + self.assertEqual(call_args, expected_call_args) + + def test_start_uploading_without_create_experiment_fails(self): + mock_client = _create_mock_client() + uploader = _create_uploader(mock_client, _TEST_LOG_DIR_NAME) + with self.assertRaisesRegex(RuntimeError, "call create_experiment()"): + uploader.start_uploading() + + def test_start_uploading_scalars(self): + mock_client = _create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_tensor_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_blob_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_tracker = mock.MagicMock() + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ): + uploader = _create_uploader( + writer_client=mock_client, + logdir=_TEST_LOG_DIR_NAME, + # Send each Event below in a separate WriteScalarRequest + max_scalar_request_size=100, + rpc_rate_limiter=mock_rate_limiter, + verbosity=1, # In order to test the upload tracker. + ) + uploader.create_experiment() + + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat( + [_scalar_event("1.1", 5.0), _scalar_event("1.2", 5.0)] + ), + "run 2": _apply_compat( + [_scalar_event("2.1", 5.0), _scalar_event("2.2", 5.0)] + ), + }, + { + "run 3": _apply_compat( + [_scalar_event("3.1", 5.0), _scalar_event("3.2", 5.0)] + ), + "run 4": _apply_compat( + [_scalar_event("4.1", 5.0), _scalar_event("4.2", 5.0)] + ), + "run 5": _apply_compat( + [_scalar_event("5.1", 5.0), _scalar_event("5.2", 5.0)] + ), + }, + AbortUploadError, + ] + + with mock.patch.object( + uploader, "_logdir_loader", mock_logdir_loader + ), self.assertRaises(AbortUploadError): + uploader.start_uploading() + self.assertEqual(10, mock_client.write_tensorboard_run_data.call_count) + self.assertEqual(10, mock_rate_limiter.tick.call_count) + self.assertEqual(0, mock_tensor_rate_limiter.tick.call_count) + self.assertEqual(0, mock_blob_rate_limiter.tick.call_count) + + # Check upload tracker calls. 
+ self.assertEqual(mock_tracker.send_tracker.call_count, 2) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 10) + self.assertLen(mock_tracker.scalars_tracker.call_args[0], 1) + self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) + self.assertEqual(mock_tracker.blob_tracker.call_count, 0) + + def test_start_uploading_scalars_one_shot(self): + """Check that one-shot uploading stops without AbortUploadError.""" + mock_client = _create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_tracker = mock.MagicMock() + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ): + uploader = _create_uploader( + writer_client=mock_client, + logdir=_TEST_LOG_DIR_NAME, + # Send each Event below in a separate WriteScalarRequest + max_scalar_request_size=100, + rpc_rate_limiter=mock_rate_limiter, + verbosity=1, # In order to test the upload tracker. + one_shot=True, + ) + uploader.create_experiment() + + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat( + [_scalar_event("1.1", 5.0), _scalar_event("1.2", 5.0)] + ), + "run 2": _apply_compat( + [_scalar_event("2.1", 5.0), _scalar_event("2.2", 5.0)] + ), + }, + # Note the lack of AbortUploadError here. + ] + + with mock.patch.object(uploader, "_logdir_loader", mock_logdir_loader): + uploader.start_uploading() + + self.assertEqual(4, mock_client.write_tensorboard_run_data.call_count) + self.assertEqual(4, mock_rate_limiter.tick.call_count) + + # Check upload tracker calls. + self.assertEqual(mock_tracker.send_tracker.call_count, 1) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 4) + self.assertLen(mock_tracker.scalars_tracker.call_args[0], 1) + self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) + self.assertEqual(mock_tracker.blob_tracker.call_count, 0) + + def test_upload_empty_logdir(self): + logdir = self.get_temp_dir() + mock_client = _create_mock_client() + uploader = _create_uploader(mock_client, logdir) + uploader.create_experiment() + uploader._upload_once() + mock_client.write_tensorboard_run_data.assert_not_called() + + def test_upload_polls_slowly_once_done(self): + class SuccessError(Exception): + pass + + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + upload_call_count_box = [0] + + def mock_upload_once(): + upload_call_count_box[0] += 1 + tick_count = mock_rate_limiter.tick.call_count + self.assertEqual(tick_count, upload_call_count_box[0]) + if tick_count >= 3: + raise SuccessError() + + uploader = _create_uploader( + logdir=self.get_temp_dir(), logdir_poll_rate_limiter=mock_rate_limiter, + ) + uploader._upload_once = mock_upload_once + + uploader.create_experiment() + with self.assertRaises(SuccessError): + uploader.start_uploading() + + def test_upload_swallows_rpc_failure(self): + logdir = self.get_temp_dir() + with FileWriter(logdir) as writer: + writer.add_test_summary("foo") + mock_client = _create_mock_client() + uploader = _create_uploader(mock_client, logdir) + uploader.create_experiment() + error = _grpc_error(grpc.StatusCode.INTERNAL, "Failure") + mock_client.write_tensorboard_run_data.side_effect = error + uploader._upload_once() + mock_client.write_tensorboard_run_data.assert_called_once() + + def test_upload_full_logdir(self): + logdir = self.get_temp_dir() + mock_client = _create_mock_client() + uploader = _create_uploader(mock_client, logdir) + uploader.create_experiment() + + # Convenience helpers for 
constructing expected requests. + data = tensorboard_data.TimeSeriesData + point = tensorboard_data.TimeSeriesDataPoint + scalar = tensorboard_data.Scalar + + # First round + writer = FileWriter(logdir) + metadata = summary_pb2.SummaryMetadata( + plugin_data=summary_pb2.SummaryMetadata.PluginData( + plugin_name="scalars", content=b"12345" + ), + data_class=summary_pb2.DATA_CLASS_SCALAR, + ) + writer.add_test_summary("foo", simple_value=5.0, step=1) + writer.add_test_summary("foo", simple_value=6.0, step=2) + writer.add_test_summary("foo", simple_value=7.0, step=3) + writer.add_test_tensor_summary( + "bar", + tensor=tensor_pb2.TensorProto(dtype=types_pb2.DT_FLOAT, float_val=[8.0]), + step=3, + value_metadata=metadata, + ) + writer.flush() + writer_a = FileWriter(os.path.join(logdir, "a")) + writer_a.add_test_summary("qux", simple_value=9.0, step=2) + writer_a.flush() + uploader._upload_once() + self.assertEqual(3, mock_client.create_tensorboard_time_series.call_count) + call_args_list = mock_client.create_tensorboard_time_series.call_args_list + request = call_args_list[1][1]["tensorboard_time_series"] + self.assertEqual("scalars", request.plugin_name) + self.assertEqual(b"12345", request.plugin_data) + + self.assertEqual(2, mock_client.write_tensorboard_run_data.call_count) + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + request1, request2 = ( + call_args_list[0][1]["time_series_data"], + call_args_list[1][1]["time_series_data"], + ) + _clear_wall_times(request1) + _clear_wall_times(request2) + + expected_request1 = [ + data( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + point(step=1, scalar=scalar(value=5.0)), + point(step=2, scalar=scalar(value=6.0)), + point(step=3, scalar=scalar(value=7.0)), + ], + ), + data( + tensorboard_time_series_id="bar", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=3, scalar=scalar(value=8.0))], + ), + ] + expected_request2 = [ + data( + tensorboard_time_series_id="qux", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=2, scalar=scalar(value=9.0))], + ) + ] + self.assertProtoEquals(expected_request1[0], request1[0]) + self.assertProtoEquals(expected_request1[1], request1[1]) + self.assertProtoEquals(expected_request2[0], request2[0]) + + mock_client.write_tensorboard_run_data.reset_mock() + + # Second round + writer.add_test_summary("foo", simple_value=10.0, step=5) + writer.add_test_summary("baz", simple_value=11.0, step=1) + writer.flush() + writer_b = FileWriter(os.path.join(logdir, "b")) + writer_b.add_test_summary("xyz", simple_value=12.0, step=1) + writer_b.flush() + uploader._upload_once() + self.assertEqual(2, mock_client.write_tensorboard_run_data.call_count) + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + request3, request4 = ( + call_args_list[0][1]["time_series_data"], + call_args_list[1][1]["time_series_data"], + ) + _clear_wall_times(request3) + _clear_wall_times(request4) + expected_request3 = [ + data( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=5, scalar=scalar(value=10.0))], + ), + data( + tensorboard_time_series_id="baz", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=11.0))], + ), + ] + expected_request4 = [ + data( + 
tensorboard_time_series_id="xyz", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=12.0))], + ) + ] + self.assertProtoEquals(expected_request3[0], request3[0]) + self.assertProtoEquals(expected_request3[1], request3[1]) + self.assertProtoEquals(expected_request4[0], request4[0]) + mock_client.write_tensorboard_run_data.reset_mock() + + # Empty third round + uploader._upload_once() + mock_client.write_tensorboard_run_data.assert_not_called() + + def test_verbosity_zero_creates_upload_tracker_with_verbosity_zero(self): + mock_client = _create_mock_client() + mock_tracker = mock.MagicMock() + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ) as mock_constructor: + uploader = _create_uploader( + mock_client, + _TEST_LOG_DIR_NAME, + verbosity=0, # Explicitly set verbosity to 0. + ) + uploader.create_experiment() + + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat( + [_scalar_event("1.1", 5.0), _scalar_event("1.2", 5.0)] + ), + }, + AbortUploadError, + ] + + with mock.patch.object( + uploader, "_logdir_loader", mock_logdir_loader + ), self.assertRaises(AbortUploadError): + uploader.start_uploading() + + self.assertEqual(mock_constructor.call_count, 1) + self.assertEqual(mock_constructor.call_args[1], {"verbosity": 0}) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 1) + + def test_start_uploading_graphs(self): + mock_client = _create_mock_client() + mock_rate_limiter = mock.create_autospec(util.RateLimiter) + mock_bucket = mock.create_autospec(storage.Bucket) + mock_blob = mock.create_autospec(storage.Blob) + mock_bucket.blob.return_value = mock_blob + mock_tracker = mock.MagicMock() + + def create_time_series(tensorboard_time_series, parent=None): + return tensorboard_time_series_type.TensorboardTimeSeries( + name=_TEST_ONE_PLATFORM_TIME_SERIES_NAME, + display_name=tensorboard_time_series.display_name, + ) + + mock_client.create_tensorboard_time_series.side_effect = create_time_series + with mock.patch.object( + upload_tracker, "UploadTracker", return_value=mock_tracker + ): + uploader = _create_uploader( + writer_client=mock_client, + logdir=_TEST_LOG_DIR_NAME, + # Verify behavior with lots of small chunks + max_blob_request_size=100, + rpc_rate_limiter=mock_rate_limiter, + blob_storage_bucket=mock_bucket, + verbosity=1, # In order to test tracker. + ) + uploader.create_experiment() + + # Of course a real Event stream will never produce the same Event twice, + # but is this test context it's fine to reuse this one. 
+ graph_event = event_pb2.Event(graph_def=_create_example_graph_bytes(950)) + expected_graph_def = graph_pb2.GraphDef.FromString(graph_event.graph_def) + mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader) + mock_logdir_loader.get_run_events.side_effect = [ + { + "run 1": _apply_compat([graph_event, graph_event]), + "run 2": _apply_compat([graph_event, graph_event]), + }, + { + "run 3": _apply_compat([graph_event, graph_event]), + "run 4": _apply_compat([graph_event, graph_event]), + "run 5": _apply_compat([graph_event, graph_event]), + }, + AbortUploadError, + ] + + with mock.patch.object( + uploader, "_logdir_loader", mock_logdir_loader + ), self.assertRaises(AbortUploadError): + uploader.start_uploading() + + self.assertEqual(1, mock_client.create_tensorboard_experiment.call_count) + self.assertEqual(10, mock_bucket.blob.call_count) + + blob_ids = set() + for call in mock_bucket.blob.call_args_list: + request = call[0][0] + m = re.match( + "test_folder/tensorboard-.*/test-experiment/.*/{}/(.*)".format( + _TEST_TIME_SERIES_NAME + ), + request, + ) + self.assertIsNotNone(m) + blob_ids.add(m[1]) + + for call in mock_blob.upload_from_string.call_args_list: + request = call[0][0] + actual_graph_def = graph_pb2.GraphDef.FromString(request) + self.assertProtoEquals(expected_graph_def, actual_graph_def) + + for call in mock_client.write_tensorboard_run_data.call_args_list: + kargs = call[1] + time_series_data = kargs["time_series_data"] + self.assertEqual(len(time_series_data), 1) + self.assertEqual( + time_series_data[0].tensorboard_time_series_id, _TEST_TIME_SERIES_NAME + ) + self.assertEqual(len(time_series_data[0].values), 1) + blobs = time_series_data[0].values[0].blobs.values + self.assertEqual(len(blobs), 1) + self.assertIn(blobs[0].id, blob_ids) + + # Check upload tracker calls. + self.assertEqual(mock_tracker.send_tracker.call_count, 2) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 0) + self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) + self.assertEqual(mock_tracker.blob_tracker.call_count, 10) + self.assertLen(mock_tracker.blob_tracker.call_args[0], 1) + self.assertGreater(mock_tracker.blob_tracker.call_args[0][0], 0) + + def test_filter_graphs(self): + # Three graphs: one short, one long, one corrupt. 
+ bytes_0 = _create_example_graph_bytes(123) + bytes_1 = _create_example_graph_bytes(9999) + # invalid (truncated) proto: length-delimited field 1 (0x0a) of + # length 0x7f specified, but only len("bogus") = 5 bytes given + # + bytes_2 = b"\x0a\x7fbogus" + + logdir = self.get_temp_dir() + for (i, b) in enumerate([bytes_0, bytes_1, bytes_2]): + run_dir = os.path.join(logdir, "run_%04d" % i) + event = event_pb2.Event(step=0, wall_time=123 * i, graph_def=b) + with FileWriter(run_dir) as writer: + writer.add_event(event) + + limiter = mock.create_autospec(util.RateLimiter) + limiter.tick.side_effect = [None, AbortUploadError] + mock_bucket = mock.create_autospec(storage.Bucket) + mock_blob = mock.create_autospec(storage.Blob) + mock_bucket.blob.return_value = mock_blob + mock_client = _create_mock_client() + + def create_time_series(tensorboard_time_series, parent=None): + return tensorboard_time_series_type.TensorboardTimeSeries( + name=_TEST_ONE_PLATFORM_TIME_SERIES_NAME, + display_name=tensorboard_time_series.display_name, + ) + + mock_client.create_tensorboard_time_series.side_effect = create_time_series + uploader = _create_uploader( + mock_client, + logdir, + logdir_poll_rate_limiter=limiter, + blob_storage_bucket=mock_bucket, + ) + uploader.create_experiment() + + with self.assertRaises(AbortUploadError): + uploader.start_uploading() + + actual_blobs = [] + for call in mock_blob.upload_from_string.call_args_list: + requests = call[0][0] + actual_blobs.append(requests) + + actual_graph_defs = [] + for blob in actual_blobs: + try: + actual_graph_defs.append(graph_pb2.GraphDef.FromString(blob)) + except message.DecodeError: + actual_graph_defs.append(None) + + with self.subTest("graphs with small attr values should be unchanged"): + expected_graph_def_0 = graph_pb2.GraphDef.FromString(bytes_0) + self.assertEqual(actual_graph_defs[0], expected_graph_def_0) + + with self.subTest("large attr values should be filtered out"): + expected_graph_def_1 = graph_pb2.GraphDef.FromString(bytes_1) + del expected_graph_def_1.node[1].attr["large"] + expected_graph_def_1.node[1].attr["_too_large_attrs"].list.s.append( + b"large" + ) + self.assertEqual(actual_graph_defs[1], expected_graph_def_1) + + with self.subTest("corrupt graphs should be skipped"): + self.assertLen(actual_blobs, 2) + + +class BatchedRequestSenderTest(tf.test.TestCase): + def _populate_run_from_events( + self, n_scalar_events, events, allowed_plugins=_USE_DEFAULT + ): + mock_client = _create_mock_client() + builder = _create_request_sender( + experiment_resource_name="123", + api=mock_client, + allowed_plugins=allowed_plugins, + ) + builder.send_requests({"": _apply_compat(events)}) + scalar_requests = mock_client.write_tensorboard_run_data.call_args_list + if scalar_requests: + self.assertLen(scalar_requests, 1) + self.assertLen(scalar_requests[0][1]["time_series_data"], n_scalar_events) + return scalar_requests + + def test_empty_events(self): + call_args_list = self._populate_run_from_events(0, []) + self.assertProtoEquals(call_args_list, []) + + def test_scalar_events(self): + events = [ + event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), + event_pb2.Event(summary=scalar_v2_pb("scalar2", 5.0)), + ] + call_args_lists = self._populate_run_from_events(2, events) + scalar_tag_counts = _extract_tag_counts(call_args_lists) + self.assertEqual(scalar_tag_counts, {"scalar1": 1, "scalar2": 1}) + + def test_skips_non_scalar_events(self): + events = [ + event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), + 
event_pb2.Event(file_version="brain.Event:2"), + ] + call_args_list = self._populate_run_from_events(1, events) + scalar_tag_counts = _extract_tag_counts(call_args_list) + self.assertEqual(scalar_tag_counts, {"scalar1": 1}) + + def test_skips_non_scalar_events_in_scalar_time_series(self): + events = [ + event_pb2.Event(file_version="brain.Event:2"), + event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), + event_pb2.Event(summary=scalar_v2_pb("scalar2", 5.0)), + ] + call_args_list = self._populate_run_from_events(2, events) + scalar_tag_counts = _extract_tag_counts(call_args_list) + self.assertEqual(scalar_tag_counts, {"scalar1": 1, "scalar2": 1}) + + def test_skips_events_from_disallowed_plugins(self): + event = event_pb2.Event( + step=1, wall_time=123.456, summary=scalar_v2_pb("foo", 5.0) + ) + call_args_lists = self._populate_run_from_events( + 0, [event], allowed_plugins=frozenset("not-scalars"), + ) + self.assertEqual(call_args_lists, []) + + def test_remembers_first_metadata_in_time_series(self): + scalar_1 = event_pb2.Event(summary=scalar_v2_pb("loss", 4.0)) + scalar_2 = event_pb2.Event(summary=scalar_v2_pb("loss", 3.0)) + scalar_2.summary.value[0].ClearField("metadata") + events = [ + event_pb2.Event(file_version="brain.Event:2"), + scalar_1, + scalar_2, + ] + call_args_list = self._populate_run_from_events(1, events) + scalar_tag_counts = _extract_tag_counts(call_args_list) + self.assertEqual(scalar_tag_counts, {"loss": 2}) + + def test_expands_multiple_values_in_event(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + event.summary.value.add(tag="foo", simple_value=2.0) + event.summary.value.add(tag="foo", simple_value=3.0) + call_args_list = self._populate_run_from_events(1, [event]) + + time_series_data = tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=1.0), + ), + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=2.0), + ), + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=3.0), + ), + ], + ) + + self.assertProtoEquals( + time_series_data, call_args_list[0][1]["time_series_data"][0] + ) + + +class ScalarBatchedRequestSenderTest(tf.test.TestCase): + def _add_events(self, sender, events): + for event in events: + for value in event.summary.value: + sender.add_event(event, value, value.metadata) + + def _add_events_and_flush(self, events, expected_n_time_series): + mock_client = _create_mock_client() + sender = _create_scalar_request_sender( + run_resource_id=_TEST_RUN_NAME, api=mock_client, + ) + self._add_events(sender, events) + sender.flush() + + requests = mock_client.write_tensorboard_run_data.call_args_list + self.assertLen(requests, 1) + self.assertLen(requests[0][1]["time_series_data"], expected_n_time_series) + return requests[0] + + def test_aggregation_by_tag(self): + def make_event(step, wall_time, tag, value): + return event_pb2.Event( + step=step, wall_time=wall_time, summary=scalar_v2_pb(tag, value), + ) + + events = [ + make_event(1, 1.0, "one", 11.0), + make_event(1, 2.0, "two", 22.0), + make_event(2, 3.0, "one", 33.0), + make_event(2, 4.0, "two", 44.0), + make_event(1, 5.0, "one", 55.0), # Should 
preserve duplicate step=1. + make_event(1, 6.0, "three", 66.0), + ] + call_args = self._add_events_and_flush(events, 3) + ts_data = call_args[1]["time_series_data"] + tag_data = { + ts.tensorboard_time_series_id: [ + ( + value.step, + value.wall_time.timestamp_pb().ToSeconds(), + value.scalar.value, + ) + for value in ts.values + ] + for ts in ts_data + } + self.assertEqual( + tag_data, + { + "one": [(1, 1.0, 11.0), (2, 3.0, 33.0), (1, 5.0, 55.0)], + "two": [(1, 2.0, 22.0), (2, 4.0, 44.0)], + "three": [(1, 6.0, 66.0)], + }, + ) + + def test_v1_summary(self): + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=5.0) + call_args = self._add_events_and_flush(_apply_compat([event]), 1) + + expected_call_args = mock.call( + tensorboard_run=_TEST_RUN_NAME, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=5.0), + ) + ], + ) + ], + ) + self.assertEqual(expected_call_args, call_args) + + def test_v1_summary_tb_summary(self): + tf_summary = summary_v1.scalar_pb("foo", 5.0) + tb_summary = summary_pb2.Summary.FromString(tf_summary.SerializeToString()) + event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary) + call_args = self._add_events_and_flush(_apply_compat([event]), 1) + + expected_call_args = mock.call( + tensorboard_run=_TEST_RUN_NAME, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="scalar_summary", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=5.0), + ) + ], + ) + ], + ) + self.assertEqual(expected_call_args, call_args) + + def test_v2_summary(self): + event = event_pb2.Event( + step=1, wall_time=123.456, summary=scalar_v2_pb("foo", 5.0) + ) + call_args = self._add_events_and_flush(_apply_compat([event]), 1) + + expected_call_args = mock.call( + tensorboard_run=_TEST_RUN_NAME, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, + wall_time=_timestamp_pb(123456000000), + scalar=tensorboard_data.Scalar(value=5.0), + ) + ], + ) + ], + ) + + self.assertEqual(expected_call_args, call_args) + + def test_propagates_experiment_deletion(self): + event = event_pb2.Event(step=1) + event.summary.value.add(tag="foo", simple_value=1.0) + + mock_client = _create_mock_client() + sender = _create_scalar_request_sender("123", mock_client) + self._add_events(sender, _apply_compat([event])) + + error = _grpc_error(grpc.StatusCode.NOT_FOUND, "nope") + mock_client.write_tensorboard_run_data.side_effect = error + with self.assertRaises(uploader_lib.ExperimentNotFoundError): + sender.flush() + + def test_no_budget_for_base_request(self): + mock_client = _create_mock_client() + long_run_id = "A" * 12 + with self.assertRaises(uploader_lib._OutOfSpaceError) as cm: + _create_scalar_request_sender( + run_resource_id=long_run_id, api=mock_client, max_request_size=12, + ) + self.assertEqual(str(cm.exception), "Byte budget too small for base request") + + def test_no_room_for_single_point(self): + mock_client 
= _create_mock_client() + event = event_pb2.Event(step=1, wall_time=123.456) + event.summary.value.add(tag="foo", simple_value=1.0) + sender = _create_scalar_request_sender("123", mock_client, max_request_size=12) + with self.assertRaises(RuntimeError) as cm: + self._add_events(sender, [event]) + self.assertEqual(str(cm.exception), "add_event failed despite flush") + + def test_break_at_run_boundary(self): + mock_client = _create_mock_client() + # Choose run name sizes such that one run fits in a 1024 byte request, + # but not two. + long_run_1 = "A" * 768 + long_run_2 = "B" * 768 + event_1 = event_pb2.Event(step=1) + event_1.summary.value.add(tag="foo", simple_value=1.0) + event_2 = event_pb2.Event(step=2) + event_2.summary.value.add(tag="bar", simple_value=-2.0) + + sender_1 = _create_scalar_request_sender( + long_run_1, + mock_client, + # Set a limit to request size + max_request_size=1024, + ) + + sender_2 = _create_scalar_request_sender( + long_run_2, + mock_client, + # Set a limit to request size + max_request_size=1024, + ) + self._add_events(sender_1, _apply_compat([event_1])) + self._add_events(sender_2, _apply_compat([event_2])) + sender_1.flush() + sender_2.flush() + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + + for call_args in call_args_list: + _clear_wall_times(call_args[1]["time_series_data"]) + + # Expect two calls despite a single explicit call to flush(). + + expected = [ + mock.call( + tensorboard_run=long_run_1, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, scalar=tensorboard_data.Scalar(value=1.0) + ) + ], + ) + ], + ), + mock.call( + tensorboard_run=long_run_2, + time_series_data=[ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="bar", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=2, scalar=tensorboard_data.Scalar(value=-2.0) + ) + ], + ) + ], + ), + ] + + self.assertEqual(expected[0], call_args_list[0]) + self.assertEqual(expected[1], call_args_list[1]) + + def test_break_at_tag_boundary(self): + mock_client = _create_mock_client() + # Choose tag name sizes such that one tag fits in a 1024 byte request, + # but not two. Note that tag names appear in both `Tag.name` and the + # summary metadata. + long_tag_1 = "a" * 384 + long_tag_2 = "b" * 384 + event = event_pb2.Event(step=1) + event.summary.value.add(tag=long_tag_1, simple_value=1.0) + event.summary.value.add(tag=long_tag_2, simple_value=2.0) + + sender = _create_scalar_request_sender( + "train", + mock_client, + # Set a limit to request size + max_request_size=1024, + ) + self._add_events(sender, _apply_compat([event])) + sender.flush() + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + + request1 = call_args_list[0][1]["time_series_data"] + _clear_wall_times(request1) + + # Convenience helpers for constructing expected requests. 
+ data = tensorboard_data.TimeSeriesData + point = tensorboard_data.TimeSeriesDataPoint + scalar = tensorboard_data.Scalar + + expected_request1 = [ + data( + tensorboard_time_series_id=long_tag_1, + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=1.0))], + ), + data( + tensorboard_time_series_id=long_tag_2, + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=2.0))], + ), + ] + self.assertProtoEquals(expected_request1[0], request1[0]) + self.assertProtoEquals(expected_request1[1], request1[1]) + + def test_break_at_scalar_point_boundary(self): + mock_client = _create_mock_client() + point_count = 2000 # comfortably saturates a single 1024-byte request + events = [] + for step in range(point_count): + summary = scalar_v2_pb("loss", -2.0 * step) + if step > 0: + summary.value[0].ClearField("metadata") + events.append(event_pb2.Event(summary=summary, step=step)) + + sender = _create_scalar_request_sender( + "train", + mock_client, + # Set a limit to request size + max_request_size=1024, + ) + self._add_events(sender, _apply_compat(events)) + sender.flush() + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + + for call_args in call_args_list: + _clear_wall_times(call_args[1]["time_series_data"]) + + self.assertGreater(len(call_args_list), 1) + self.assertLess(len(call_args_list), point_count) + # This is the observed number of requests when running the test. There + # is no reasonable way to derive this value from just reading the code. + # The number of requests does not have to be 37 to be correct but if it + # changes it probably warrants some investigation or thought. + self.assertEqual(37, len(call_args_list)) + + total_points_in_result = 0 + for call_args in call_args_list: + self.assertLen(call_args[1]["time_series_data"], 1) + self.assertEqual(call_args[1]["tensorboard_run"], "train") + time_series_data = call_args[1]["time_series_data"][0] + self.assertEqual(time_series_data.tensorboard_time_series_id, "loss") + for point in time_series_data.values: + self.assertEqual(point.step, total_points_in_result) + self.assertEqual(point.scalar.value, -2.0 * point.step) + total_points_in_result += 1 + self.assertEqual(total_points_in_result, point_count) + + def test_prunes_tags_and_runs(self): + mock_client = _create_mock_client() + event_1 = event_pb2.Event(step=1) + event_1.summary.value.add(tag="foo", simple_value=1.0) + event_2 = event_pb2.Event(step=2) + event_2.summary.value.add(tag="bar", simple_value=-2.0) + + add_point_call_count_box = [0] + + def mock_add_point(byte_budget_manager_self, point): + # Simulate out-of-space error the first time that we try to store + # the second point. + add_point_call_count_box[0] += 1 + if add_point_call_count_box[0] == 2: + raise uploader_lib._OutOfSpaceError() + + with mock.patch.object( + uploader_lib._ByteBudgetManager, "add_point", mock_add_point, + ): + sender = _create_scalar_request_sender("123", mock_client) + self._add_events(sender, _apply_compat([event_1])) + self._add_events(sender, _apply_compat([event_2])) + sender.flush() + + call_args_list = mock_client.write_tensorboard_run_data.call_args_list + request1, request2 = ( + call_args_list[0][1]["time_series_data"], + call_args_list[1][1]["time_series_data"], + ) + _clear_wall_times(request1) + _clear_wall_times(request2) + + # Convenience helpers for constructing expected requests. 
+ data = tensorboard_data.TimeSeriesData + point = tensorboard_data.TimeSeriesDataPoint + scalar = tensorboard_data.Scalar + + expected_request1 = [ + data( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=1, scalar=scalar(value=1.0))], + ) + ] + + expected_request2 = [ + data( + tensorboard_time_series_id="bar", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[point(step=2, scalar=scalar(value=-2.0))], + ) + ] + self.assertProtoEquals(expected_request1[0], request1[0]) + self.assertProtoEquals(expected_request2[0], request2[0]) + + def test_wall_time_precision(self): + # Test a wall time that is exactly representable in float64 but has enough + # digits to incur error if converted to nanoseconds the naive way (* 1e9). + event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119) + event1.summary.value.add(tag="foo", simple_value=1.0) + # Test a wall time where as a float64, the fractional part on its own will + # introduce error if truncated to 9 decimal places instead of rounded. + event2 = event_pb2.Event(step=2, wall_time=1.000000002) + event2.summary.value.add(tag="foo", simple_value=2.0) + call_args = self._add_events_and_flush(_apply_compat([event1, event2]), 1) + self.assertEqual( + datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb( + _timestamp_pb(1567808404765432119) + ), + call_args[1]["time_series_data"][0].values[0].wall_time, + ) + self.assertEqual( + datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb( + _timestamp_pb(1000000002) + ), + call_args[1]["time_series_data"][0].values[1].wall_time, + ) + + +class VarintCostTest(tf.test.TestCase): + def test_varint_cost(self): + self.assertEqual(uploader_lib._varint_cost(0), 1) + self.assertEqual(uploader_lib._varint_cost(7), 1) + self.assertEqual(uploader_lib._varint_cost(127), 1) + self.assertEqual(uploader_lib._varint_cost(128), 2) + self.assertEqual(uploader_lib._varint_cost(128 * 128 - 1), 2) + self.assertEqual(uploader_lib._varint_cost(128 * 128), 3) + + +def _clear_wall_times(repeated_time_series_data): + """Clears the wall_time fields in a TimeSeriesData to be deterministic. + + Args: + repeated_time_series_data: Iterable of tensorboard_data.TimeSeriesData. 
+ """ + + for time_series_data in repeated_time_series_data: + for value in time_series_data.values: + value.wall_time = None + + +def _apply_compat(events): + initial_metadata = {} + for event in events: + event = data_compat.migrate_event(event) + events = dataclass_compat.migrate_event( + event, initial_metadata=initial_metadata + ) + for migrated_event in events: + yield migrated_event + + +def _extract_tag_counts(call_args_list): + return { + ts_data.tensorboard_time_series_id: len(ts_data.values) + for call_args in call_args_list + for ts_data in call_args[1]["time_series_data"] + } + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 3032475069..c5ce327db8 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -44,7 +44,8 @@ ("projects/123456/locations/us-central1/datasets/987654", True), ("projects/857392/locations/us-central1/trainingPipelines/347292", True), ("projects/acme-co-proj-1/locations/us-central1/datasets/123456", True), - ("projects/acme-co-proj-1/locations/us-central1/datasets/abcdef", False), + ("projects/acme-co-proj-1/locations/us-central1/datasets/abcdef", True), + ("projects/acme-co-proj-1/locations/us-central1/datasets/abc-def", True), ("project/123456/locations/us-central1/datasets/987654", False), ("project//locations//datasets/987654", False), ("locations/europe-west4/datasets/987654", False), @@ -101,6 +102,18 @@ def test_extract_fields_from_resource_name_with_extracted_fields( "batchPredictionJobs", False, ), + # Expects pattern "projects/.../locations/.../metadataStores/.../contexts/..." + ( + "projects/857392/locations/us-central1/metadataStores/default/contexts/123", + "metadataStores/default/contexts", + True, + ), + # Expects pattern "projects/.../locations/.../tensorboards/.../experiments/.../runs/.../timeSeries/..." + ( + "projects/857392/locations/us-central1/tensorboards/123/experiments/456/runs/789/timeSeries/1", + "tensorboards/123/experiments/456/runs/789/timeSeries", + True, + ), ], ) def test_extract_fields_from_resource_name_with_resource_noun( @@ -140,6 +153,18 @@ def test_invalid_region_does_not_raise_with_valid_region(): "us-west20", "projects/857392/locations/us-central1/trainingPipelines/347292", ), + ( + "metadataStores/default/contexts", + "123456", + "europe-west4", + "projects/857392/locations/us-central1/metadataStores/default/contexts/123", + ), + ( + "tensorboards/123/experiments/456/runs/789/timeSeries", + "857392", + "us-central1", + "projects/857392/locations/us-central1/tensorboards/123/experiments/456/runs/789/timeSeries/1", + ), ], ) def test_full_resource_name_with_full_name( @@ -174,6 +199,20 @@ def test_full_resource_name_with_full_name( "us-central1", "projects/857392/locations/us-central1/trainingPipelines/347292", ), + ( + "123", + "metadataStores/default/contexts", + "857392", + "us-central1", + "projects/857392/locations/us-central1/metadataStores/default/contexts/123", + ), + ( + "1", + "tensorboards/123/experiments/456/runs/789/timeSeries", + "857392", + "us-central1", + "projects/857392/locations/us-central1/tensorboards/123/experiments/456/runs/789/timeSeries/1", + ), ], ) def test_full_resource_name_with_partial_name(