Commit

chore: detect current Experiment and ExperimentRun from env variable
PiperOrigin-RevId: 518628773
jaycee-li authored and copybara-github committed Mar 22, 2023
1 parent 41cd943 commit a361948
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 35 deletions.
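In effect, the experiment tracker now falls back to the AIP_EXPERIMENT_NAME and AIP_EXPERIMENT_RUN_NAME environment variables when no Experiment or ExperimentRun has been set explicitly. A minimal sketch of the intended usage, assuming the project, location, and experiment names below (all placeholders) point at existing resources:

    import os

    from google.cloud import aiplatform

    # Placeholder names; the experiment "my-experiment" is assumed to already exist.
    os.environ["AIP_EXPERIMENT_NAME"] = "my-experiment"

    # No experiment is passed to init(); the tracker consults the env variable lazily,
    # the first time its experiment / experiment_run properties are read.
    aiplatform.init(project="my-project", location="us-central1")

    # start_run() now resolves its parent experiment from AIP_EXPERIMENT_NAME.
    run = aiplatform.start_run("my-run")
    aiplatform.end_run()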
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-# Copyright 2022 Google LLC
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -275,11 +275,11 @@ def create_run(
         """

         self._vertex_experiment = (
-            aiplatform.metadata.metadata._experiment_tracker._experiment
+            aiplatform.metadata.metadata._experiment_tracker.experiment
         )

         currently_active_run = (
-            aiplatform.metadata.metadata._experiment_tracker._experiment_run
+            aiplatform.metadata.metadata._experiment_tracker.experiment_run
         )

         parent_run_id = None
@@ -374,7 +374,7 @@ def update_run_info(
             self._run_map[run_id].autocreate
             and run_status in _MLFLOW_TERMINAL_RUN_STATES
             and self._run_map[run_id].experiment_run
-            is aiplatform.metadata.metadata._experiment_tracker._experiment_run
+            is aiplatform.metadata.metadata._experiment_tracker.experiment_run
         ):
             aiplatform.metadata.metadata._experiment_tracker.end_run(
                 state=execution_v1.Execution.State.COMPLETE
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/metadata/constants.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-# Copyright 2022 Google LLC
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -79,3 +79,6 @@
 _VERTEX_EXPERIMENT_TB_EXPERIMENT_LABEL = {
     "vertex_tensorboard_experiment_source": "vertex_experiment"
 }
+
+ENV_EXPERIMENT_KEY = "AIP_EXPERIMENT_NAME"
+ENV_EXPERIMENT_RUN_KEY = "AIP_EXPERIMENT_RUN_NAME"
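The environment variable names are also exposed as constants, so code that sets or reads them can avoid hard-coding the strings. A small illustrative sketch (resource names are placeholders):

    import os

    from google.cloud.aiplatform.metadata import constants

    # Illustrative only: reference the new constants instead of literal "AIP_*" strings.
    os.environ[constants.ENV_EXPERIMENT_KEY] = "my-experiment"
    os.environ[constants.ENV_EXPERIMENT_RUN_KEY] = "my-run"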
73 changes: 43 additions & 30 deletions google/cloud/aiplatform/metadata/metadata.py
@@ -16,6 +16,7 @@
 #

 import logging
+import os
 from typing import Dict, Union, Optional, Any, List

 from google.api_core import exceptions
@@ -212,19 +213,34 @@ def reset(self):
     @property
     def experiment_name(self) -> Optional[str]:
         """Return the currently set experiment name, if experiment is not set, return None"""
-        if self._experiment:
-            return self._experiment.name
+        if self.experiment:
+            return self.experiment.name
         return None

     @property
     def experiment(self) -> Optional[experiment_resources.Experiment]:
-        "Returns the currently set Experiment."
-        return self._experiment
+        """Returns the currently set Experiment or Experiment set via env variable AIP_EXPERIMENT_NAME."""
+        if self._experiment:
+            return self._experiment
+        if os.getenv(constants.ENV_EXPERIMENT_KEY):
+            self._experiment = experiment_resources.Experiment.get(
+                os.getenv(constants.ENV_EXPERIMENT_KEY)
+            )
+            return self._experiment
+        return None

     @property
     def experiment_run(self) -> Optional[experiment_run_resource.ExperimentRun]:
-        """Returns the currently set experiment run."""
-        return self._experiment_run
+        """Returns the currently set experiment run or experiment run set via env variable AIP_EXPERIMENT_RUN_NAME."""
+        if self._experiment_run:
+            return self._experiment_run
+        if os.getenv(constants.ENV_EXPERIMENT_RUN_KEY):
+            self._experiment_run = experiment_run_resource.ExperimentRun.get(
+                os.getenv(constants.ENV_EXPERIMENT_RUN_KEY),
+                experiment=self.experiment,
+            )
+            return self._experiment_run
+        return None

     def set_experiment(
         self,
@@ -384,18 +400,18 @@ def start_run(
                 but with a different schema.
         """

-        if not self._experiment:
+        if not self.experiment:
             raise ValueError(
                 "No experiment set for this run. Make sure to call aiplatform.init(experiment='my-experiment') "
                 "before invoking start_run. "
             )

-        if self._experiment_run:
+        if self.experiment_run:
             self.end_run()

         if resume:
             self._experiment_run = experiment_run_resource.ExperimentRun(
-                run_name=run, experiment=self._experiment
+                run_name=run, experiment=self.experiment
             )
             if tensorboard:
                 self._experiment_run.assign_backing_tensorboard(tensorboard=tensorboard)
@@ -406,7 +422,7 @@

         else:
             self._experiment_run = experiment_run_resource.ExperimentRun.create(
-                run_name=run, experiment=self._experiment, tensorboard=tensorboard
+                run_name=run, experiment=self.experiment, tensorboard=tensorboard
             )

         return self._experiment_run
@@ -426,10 +442,10 @@ def end_run(
         """
         self._validate_experiment_and_run(method_name="end_run")
         try:
-            self._experiment_run.end_run(state=state)
+            self.experiment_run.end_run(state=state)
         except exceptions.NotFound:
             _LOGGER.warning(
-                f"Experiment run {self._experiment_run.name} was not found."
+                f"Experiment run {self.experiment_run.name} was not found."
                 "It may have been deleted"
             )
         finally:
@@ -481,7 +497,7 @@ def autolog(self, disable=False):
             logging.getLogger("mlflow.utils.autologging_utils").removeFilter(
                 _MLFlowLogFilter()
             )
-        elif not self._experiment:
+        elif not self.experiment:
             raise ValueError(
                 "No experiment set. Make sure to call aiplatform.init(experiment='my-experiment') "
                 "before calling aiplatform.autolog()."
@@ -516,7 +532,7 @@ def log_params(self, params: Dict[str, Union[float, int, str]]):

         self._validate_experiment_and_run(method_name="log_params")
         # query the latest run execution resource before logging.
-        self._experiment_run.log_params(params=params)
+        self.experiment_run.log_params(params=params)

     def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
         """Log single or multiple Metrics with specified key and value pairs.
@@ -535,7 +551,7 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):

         self._validate_experiment_and_run(method_name="log_metrics")
         # query the latest metrics artifact resource before logging.
-        self._experiment_run.log_metrics(metrics=metrics)
+        self.experiment_run.log_metrics(metrics=metrics)

     def log_classification_metrics(
         self,
@@ -584,7 +600,7 @@ def log_classification_metrics(

         self._validate_experiment_and_run(method_name="log_classification_metrics")
         # query the latest metrics artifact resource before logging.
-        return self._experiment_run.log_classification_metrics(
+        return self.experiment_run.log_classification_metrics(
             display_name=display_name,
             labels=labels,
             matrix=matrix,
@@ -666,7 +682,7 @@ def log_model(
             ValueError: if model type is not supported.
         """
         self._validate_experiment_and_run(method_name="log_model")
-        self._experiment_run.log_model(
+        self.experiment_run.log_model(
             model=model,
             artifact_id=artifact_id,
             uri=uri,
@@ -688,12 +704,12 @@ def _validate_experiment_and_run(self, method_name: str):
             ValueError: If Experiment or Run are not set.
         """

-        if not self._experiment:
+        if not self.experiment:
             raise ValueError(
                 f"No experiment set. Make sure to call aiplatform.init(experiment='my-experiment') "
                 f"before trying to {method_name}. "
             )
-        if not self._experiment_run:
+        if not self.experiment_run:
             raise ValueError(
                 f"No run set. Make sure to call aiplatform.start_run('my-run') before trying to {method_name}. "
             )
@@ -737,7 +753,7 @@ def get_experiment_df(
         """

         if not experiment:
-            experiment = self._experiment
+            experiment = self.experiment
         else:
             experiment = experiment_resources.Experiment(experiment)

@@ -762,7 +778,7 @@ def log(
                 Optional. Vertex PipelineJob to associate to this Experiment Run.
         """
         self._validate_experiment_and_run(method_name="log")
-        self._experiment_run.log(pipeline_job=pipeline_job)
+        self.experiment_run.log(pipeline_job=pipeline_job)

     def log_time_series_metrics(
         self,
@@ -806,7 +822,7 @@ def log_time_series_metrics(
             RuntimeError: If current experiment run doesn't have a backing Tensorboard resource.
         """
         self._validate_experiment_and_run(method_name="log_time_series_metrics")
-        self._experiment_run.log_time_series_metrics(
+        self.experiment_run.log_time_series_metrics(
             metrics=metrics, step=step, wall_time=wall_time
         )

@@ -882,18 +898,15 @@ def start_execution(
             ValueError: If creating a new executin and schema_title is not provided.
         """

-        if (
-            self._experiment_run
-            and not self._experiment_run._is_legacy_experiment_run()
-        ):
-            if project and project != self._experiment_run.project:
+        if self.experiment_run and not self.experiment_run._is_legacy_experiment_run():
+            if project and project != self.experiment_run.project:
                 raise ValueError(
-                    f"Currently set Experiment run project {self._experiment_run.project} must"
+                    f"Currently set Experiment run project {self.experiment_run.project} must"
                     f"match provided project {project}"
                 )
-            if location and location != self._experiment_run.location:
+            if location and location != self.experiment_run.location:
                 raise ValueError(
-                    f"Currently set Experiment run location {self._experiment_run.location} must"
+                    f"Currently set Experiment run location {self.experiment_run.location} must"
                     f"match provided location {project}"
                 )

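One detail worth noting about the properties above: an Experiment or ExperimentRun that was set explicitly (for example via aiplatform.init(experiment=...) or aiplatform.start_run(...)) still takes precedence; the environment variables are consulted only when nothing has been set, and the looked-up resource is then cached on the tracker. A simplified, illustrative sketch of that lookup order (not the SDK's internal API):

    import os
    from typing import Optional

    def resolve_experiment_name(explicitly_set: Optional[str]) -> Optional[str]:
        """Illustrative helper mirroring the tracker's resolution order."""
        # 1. Anything set explicitly on the tracker wins.
        if explicitly_set:
            return explicitly_set
        # 2. Otherwise fall back to the AIP_EXPERIMENT_NAME environment variable.
        return os.getenv("AIP_EXPERIMENT_NAME")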
72 changes: 72 additions & 0 deletions tests/unit/aiplatform/test_metadata.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import os
 import copy
 from importlib import reload
 from unittest import mock
@@ -1154,6 +1156,76 @@ def test_init_experiment_wrong_schema(self):
                 experiment=_TEST_EXPERIMENT,
             )

+    @pytest.mark.usefixtures("get_metadata_store_mock", "get_experiment_mock")
+    def test_init_experiment_from_env(self):
+        os.environ["AIP_EXPERIMENT_NAME"] = _TEST_EXPERIMENT
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+
+        exp = metadata._experiment_tracker.experiment
+        assert exp.name == _TEST_EXPERIMENT
+
+        del os.environ["AIP_EXPERIMENT_NAME"]
+
+    @pytest.mark.usefixtures(
+        "get_metadata_store_mock",
+    )
+    def test_start_run_from_env_experiment(
+        self,
+        get_experiment_mock,
+        create_experiment_run_context_mock,
+        add_context_children_mock,
+    ):
+        os.environ["AIP_EXPERIMENT_NAME"] = _TEST_EXPERIMENT
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+
+        aiplatform.start_run(_TEST_RUN)
+
+        get_experiment_mock.assert_called_with(
+            name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY
+        )
+
+        _TRUE_CONTEXT = copy.deepcopy(_EXPERIMENT_RUN_MOCK)
+        _TRUE_CONTEXT.name = None
+
+        create_experiment_run_context_mock.assert_called_with(
+            parent=_TEST_METADATASTORE,
+            context=_TRUE_CONTEXT,
+            context_id=_EXPERIMENT_RUN_MOCK.name.split("/")[-1],
+        )
+
+        add_context_children_mock.assert_called_with(
+            context=_EXPERIMENT_MOCK.name, child_contexts=[_EXPERIMENT_RUN_MOCK.name]
+        )
+
+        del os.environ["AIP_EXPERIMENT_NAME"]
+
+    @pytest.mark.usefixtures(
+        "get_metadata_store_mock",
+        "get_experiment_run_mock",
+        "get_tensorboard_run_artifact_not_found_mock",
+    )
+    def test_init_experiment_run_from_env(self):
+        os.environ["AIP_EXPERIMENT_RUN_NAME"] = _TEST_RUN
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            experiment=_TEST_EXPERIMENT,
+        )
+
+        run = metadata._experiment_tracker.experiment_run
+        assert run.name == _TEST_RUN
+
+        del os.environ["AIP_EXPERIMENT_RUN_NAME"]
+
     def test_get_experiment(self, get_experiment_mock):
         aiplatform.init(
             project=_TEST_PROJECT,
