-
Notifications
You must be signed in to change notification settings - Fork 354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add get_associated_experiment method to pipeline_jobs #1476
Changes from all commits
39cdbcf
480359a
10efff1
bb20348
4d5c10e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,10 @@ | |
from google.cloud import aiplatform | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please also update the experiment system tests to test this functinality. Add a test right after this test: https://github.com/googleapis/python-aiplatform/blob/main/tests/system/aiplatform/test_experiments.py#L263 It should be able to get the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a call to |
||
from google.cloud.aiplatform import base | ||
from google.cloud.aiplatform import initializer | ||
from google.cloud.aiplatform_v1 import Context as GapicContext | ||
from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore | ||
from google.cloud.aiplatform.metadata import constants | ||
from google.cloud.aiplatform_v1 import MetadataServiceClient | ||
from google.cloud.aiplatform import pipeline_jobs | ||
from google.cloud.aiplatform.compat.types import pipeline_failure_policy | ||
from google.cloud import storage | ||
|
@@ -188,6 +192,22 @@ | |
) | ||
_TEST_PIPELINE_CREATE_TIME = datetime.now() | ||
|
||
# experiments | ||
_TEST_EXPERIMENT = "test-experiment" | ||
|
||
_TEST_METADATASTORE = ( | ||
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" | ||
) | ||
_TEST_CONTEXT_ID = _TEST_EXPERIMENT | ||
_TEST_CONTEXT_NAME = f"{_TEST_METADATASTORE}/contexts/{_TEST_CONTEXT_ID}" | ||
|
||
_EXPERIMENT_MOCK = GapicContext( | ||
name=_TEST_CONTEXT_NAME, | ||
schema_title=constants.SYSTEM_EXPERIMENT, | ||
schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT], | ||
metadata={**constants.EXPERIMENT_METADATA}, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def mock_pipeline_service_create(): | ||
|
@@ -303,6 +323,90 @@ def mock_request_urlopen(job_spec): | |
yield mock_urlopen | ||
|
||
|
||
# experiment mocks | ||
@pytest.fixture | ||
def get_metadata_store_mock(): | ||
with patch.object( | ||
MetadataServiceClient, "get_metadata_store" | ||
) as get_metadata_store_mock: | ||
get_metadata_store_mock.return_value = GapicMetadataStore( | ||
name=_TEST_METADATASTORE, | ||
) | ||
yield get_metadata_store_mock | ||
|
||
|
||
@pytest.fixture | ||
def get_experiment_mock(): | ||
with patch.object(MetadataServiceClient, "get_context") as get_context_mock: | ||
get_context_mock.return_value = _EXPERIMENT_MOCK | ||
yield get_context_mock | ||
|
||
|
||
@pytest.fixture | ||
def add_context_children_mock(): | ||
with patch.object( | ||
MetadataServiceClient, "add_context_children" | ||
) as add_context_children_mock: | ||
yield add_context_children_mock | ||
|
||
|
||
@pytest.fixture | ||
def list_contexts_mock(): | ||
with patch.object(MetadataServiceClient, "list_contexts") as list_contexts_mock: | ||
list_contexts_mock.return_value = [_EXPERIMENT_MOCK] | ||
yield list_contexts_mock | ||
|
||
|
||
@pytest.fixture | ||
def create_experiment_run_context_mock(): | ||
with patch.object(MetadataServiceClient, "create_context") as create_context_mock: | ||
create_context_mock.side_effect = [_EXPERIMENT_MOCK] | ||
yield create_context_mock | ||
|
||
|
||
def make_pipeline_job_with_experiment(state): | ||
return gca_pipeline_job.PipelineJob( | ||
name=_TEST_PIPELINE_JOB_NAME, | ||
state=state, | ||
create_time=_TEST_PIPELINE_CREATE_TIME, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
network=_TEST_NETWORK, | ||
job_detail=gca_pipeline_job.PipelineJobDetail( | ||
pipeline_run_context=gca_context.Context( | ||
name=_TEST_PIPELINE_JOB_NAME, | ||
parent_contexts=[_TEST_CONTEXT_NAME], | ||
), | ||
), | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def mock_create_pipeline_job_with_experiment(): | ||
with mock.patch.object( | ||
pipeline_service_client.PipelineServiceClient, "create_pipeline_job" | ||
) as mock_pipeline_with_experiment: | ||
mock_pipeline_with_experiment.return_value = make_pipeline_job_with_experiment( | ||
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED | ||
) | ||
yield mock_pipeline_with_experiment | ||
|
||
|
||
@pytest.fixture | ||
def mock_get_pipeline_job_with_experiment(): | ||
with mock.patch.object( | ||
pipeline_service_client.PipelineServiceClient, "get_pipeline_job" | ||
) as mock_pipeline_with_experiment: | ||
mock_pipeline_with_experiment.side_effect = [ | ||
make_pipeline_job_with_experiment( | ||
gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING | ||
), | ||
make_pipeline_job_with_experiment( | ||
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED | ||
), | ||
] | ||
yield mock_pipeline_with_experiment | ||
|
||
|
||
@pytest.mark.usefixtures("google_auth_mock") | ||
class TestPipelineJob: | ||
def setup_method(self): | ||
|
@@ -1384,3 +1488,90 @@ def test_clone_pipeline_job_with_all_args( | |
assert cloned._gca_resource == make_pipeline_job( | ||
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
"job_spec", | ||
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], | ||
) | ||
def test_get_associated_experiment_from_pipeline_returns_none_without_experiment( | ||
self, | ||
mock_pipeline_service_create, | ||
mock_pipeline_service_get, | ||
job_spec, | ||
mock_load_yaml_and_json, | ||
): | ||
aiplatform.init( | ||
project=_TEST_PROJECT, | ||
staging_bucket=_TEST_GCS_BUCKET_NAME, | ||
location=_TEST_LOCATION, | ||
credentials=_TEST_CREDENTIALS, | ||
) | ||
|
||
job = pipeline_jobs.PipelineJob( | ||
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, | ||
template_path=_TEST_TEMPLATE_PATH, | ||
job_id=_TEST_PIPELINE_JOB_ID, | ||
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, | ||
enable_caching=True, | ||
) | ||
|
||
job.submit( | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
network=_TEST_NETWORK, | ||
create_request_timeout=None, | ||
) | ||
|
||
job.wait() | ||
|
||
test_experiment = job.get_associated_experiment() | ||
|
||
assert test_experiment is None | ||
|
||
@pytest.mark.parametrize( | ||
"job_spec", | ||
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], | ||
) | ||
def test_get_associated_experiment_from_pipeline_returns_experiment( | ||
self, | ||
job_spec, | ||
mock_load_yaml_and_json, | ||
add_context_children_mock, | ||
get_experiment_mock, | ||
create_experiment_run_context_mock, | ||
get_metadata_store_mock, | ||
mock_create_pipeline_job_with_experiment, | ||
mock_get_pipeline_job_with_experiment, | ||
): | ||
aiplatform.init( | ||
project=_TEST_PROJECT, | ||
staging_bucket=_TEST_GCS_BUCKET_NAME, | ||
location=_TEST_LOCATION, | ||
credentials=_TEST_CREDENTIALS, | ||
) | ||
|
||
test_experiment = aiplatform.Experiment(_TEST_EXPERIMENT) | ||
|
||
job = pipeline_jobs.PipelineJob( | ||
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, | ||
template_path=_TEST_TEMPLATE_PATH, | ||
job_id=_TEST_PIPELINE_JOB_ID, | ||
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, | ||
enable_caching=True, | ||
) | ||
|
||
assert get_experiment_mock.call_count == 1 | ||
|
||
job.submit( | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
network=_TEST_NETWORK, | ||
create_request_timeout=None, | ||
experiment=test_experiment, | ||
) | ||
|
||
job.wait() | ||
|
||
associated_experiment = job.get_associated_experiment() | ||
|
||
assert associated_experiment.resource_name == _TEST_CONTEXT_NAME | ||
|
||
assert add_context_children_mock.call_count == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can locally filter out the
pipeline_context
to avoid an additional GET: https://github.com/googleapis/python-aiplatform/blob/main/google/cloud/aiplatform_v1/types/pipeline_job.py#L309ie: