diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 5afeec4d26..8f620c2371 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -83,6 +83,12 @@ def init( Example tensorboard resource name format: "projects/123/locations/us-central1/tensorboards/456" + + If `experiment_tensorboard` is provided and `experiment` is not, + the provided `experiment_tensorboard` will be set as the global Tensorboard. + Any subsequent calls to aiplatform.init() with `experiment` and without + `experiment_tensorboard` will automatically assign the global Tensorboard + to the `experiment`. staging_bucket (str): The default staging bucket to use to stage artifacts when making API calls. In the form gs://... credentials (google.auth.credentials.Credentials): The default custom @@ -106,7 +112,6 @@ def init( Raises: ValueError: If experiment_description is provided but experiment is not. - If experiment_tensorboard is provided but experiment is not. """ if experiment_description and experiment is None: @@ -114,9 +119,12 @@ def init( "Experiment needs to be set in `init` in order to add experiment descriptions." ) - if experiment_tensorboard and experiment is None: - raise ValueError( - "Experiment needs to be set in `init` in order to add experiment_tensorboard." + if experiment_tensorboard: + metadata._experiment_tracker.set_tensorboard( + tensorboard=experiment_tensorboard, + project=project, + location=location, + credentials=credentials, ) # reset metadata_service config if project or location is updated. diff --git a/google/cloud/aiplatform/metadata/experiment_resources.py b/google/cloud/aiplatform/metadata/experiment_resources.py index 75c7854adc..7cee1a17c0 100644 --- a/google/cloud/aiplatform/metadata/experiment_resources.py +++ b/google/cloud/aiplatform/metadata/experiment_resources.py @@ -326,6 +326,13 @@ def resource_name(self) -> str: """The Metadata context resource name of this experiment.""" return self._metadata_context.resource_name + @property + def backing_tensorboard_resource_name(self) -> Optional[str]: + """The Tensorboard resource associated with this Experiment if there is one.""" + return self._metadata_context.metadata.get( + constants._BACKING_TENSORBOARD_RESOURCE_KEY + ) + def delete(self, *, delete_backing_tensorboard_runs: bool = False): """Deletes this experiment all the experiment runs under this experiment diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index 92c484f34f..8245fcd738 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -186,6 +186,7 @@ class _ExperimentTracker: def __init__(self): self._experiment: Optional[experiment_resources.Experiment] = None self._experiment_run: Optional[experiment_run_resource.ExperimentRun] = None + self._global_tensorboard: Optional[tensorboard_resource.Tensorboard] = None def reset(self): """Resets this experiment tracker, clearing the current experiment and run.""" @@ -235,11 +236,47 @@ def set_experiment( experiment_name=experiment, description=description ) - if backing_tensorboard: + backing_tb = backing_tensorboard or self._global_tensorboard + + current_backing_tb = experiment.backing_tensorboard_resource_name + + if not current_backing_tb and backing_tb: experiment.assign_backing_tensorboard(tensorboard=backing_tensorboard) self._experiment = experiment + def set_tensorboard( + self, + tensorboard: Union[ + tensorboard_resource.Tensorboard, + str, + ], + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Sets the global Tensorboard resource for this session. + + Args: + tensorboard (Union[str, aiplatform.Tensorboard]): + Required. The Tensorboard resource to set as the global Tensorboard. + project (str): + Optional. Project associated with this Tensorboard resource. + location (str): + Optional. Location associated with this Tensorboard resource. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to set this Tensorboard resource. + """ + if isinstance(tensorboard, str): + tensorboard = tensorboard_resource.Tensorboard( + tensorboard, + project=project, + location=location, + credentials=credentials, + ) + + self._global_tensorboard = tensorboard + def start_run( self, run: str, diff --git a/tests/system/aiplatform/test_experiments.py b/tests/system/aiplatform/test_experiments.py index 83d96d945e..5ed7db309a 100644 --- a/tests/system/aiplatform/test_experiments.py +++ b/tests/system/aiplatform/test_experiments.py @@ -416,3 +416,49 @@ def test_delete_experiment(self): with pytest.raises(exceptions.NotFound): aiplatform.Experiment(experiment_name=self._experiment_name) + + def test_init_associates_global_tensorboard_to_experiment(self, shared_state): + + tensorboard = aiplatform.Tensorboard.create( + project=e2e_base._PROJECT, + location=e2e_base._LOCATION, + display_name=self._make_display_name("")[:64], + ) + + shared_state["resources"] = [tensorboard] + + aiplatform.init( + project=e2e_base._PROJECT, + location=e2e_base._LOCATION, + experiment_tensorboard=tensorboard, + ) + + assert ( + aiplatform.metadata.metadata._experiment_tracker._global_tensorboard + == tensorboard + ) + + new_experiment_name = self._make_display_name("")[:64] + new_experiment_resource = aiplatform.Experiment.create( + experiment_name=new_experiment_name + ) + + shared_state["resources"].append(new_experiment_resource) + + aiplatform.init( + project=e2e_base._PROJECT, + location=e2e_base._LOCATION, + experiment=new_experiment_name, + ) + + assert ( + new_experiment_resource._lookup_backing_tensorboard().resource_name + == tensorboard.resource_name + ) + + assert ( + new_experiment_resource._metadata_context.metadata.get( + aiplatform.metadata.constants._BACKING_TENSORBOARD_RESOURCE_KEY + ) + == tensorboard.resource_name + ) diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index ae051594ff..d4bf108cea 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -44,6 +44,10 @@ _TEST_STAGING_BUCKET = "test-bucket" _TEST_NETWORK = "projects/12345/global/networks/myVPC" +# tensorboard +_TEST_TENSORBOARD_ID = "1028944691210842416" +_TEST_TENSORBOARD_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/tensorboards/{_TEST_TENSORBOARD_ID}" + @pytest.mark.usefixtures("google_auth_mock") class TestInit: @@ -115,6 +119,59 @@ def test_init_experiment_sets_experiment_with_description( backing_tensorboard=None, ) + @patch.object(_experiment_tracker, "set_tensorboard") + def test_init_with_experiment_tensorboard_id_sets_global_tensorboard( + self, set_tensorboard_mock + ): + creds = credentials.AnonymousCredentials() + initializer.global_config.init( + experiment_tensorboard=_TEST_TENSORBOARD_ID, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=creds, + ) + + set_tensorboard_mock.assert_called_once_with( + tensorboard=_TEST_TENSORBOARD_ID, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=creds, + ) + + @patch.object(_experiment_tracker, "set_tensorboard") + def test_init_with_experiment_tensorboard_resource_sets_global_tensorboard( + self, set_tensorboard_mock + ): + initializer.global_config.init(experiment_tensorboard=_TEST_TENSORBOARD_NAME) + + set_tensorboard_mock.assert_called_once_with( + tensorboard=_TEST_TENSORBOARD_NAME, + project=None, + location=None, + credentials=None, + ) + + @patch.object(_experiment_tracker, "set_tensorboard") + @patch.object(_experiment_tracker, "set_experiment") + def test_init_experiment_without_tensorboard_uses_global_tensorboard( + self, + set_tensorboard_mock, + set_experiment_mock, + ): + + initializer.global_config.init(experiment_tensorboard=_TEST_TENSORBOARD_NAME) + + initializer.global_config.init( + experiment=_TEST_EXPERIMENT, + ) + + set_experiment_mock.assert_called_once_with( + tensorboard=_TEST_TENSORBOARD_NAME, + project=None, + location=None, + credentials=None, + ) + def test_init_experiment_description_fail_without_experiment(self): with pytest.raises(ValueError): initializer.global_config.init(experiment_description=_TEST_DESCRIPTION)