Skip to content

Commit

Permalink
fix: show logs when TFX pipelines are submitted (#976)
Browse files Browse the repository at this point in the history
* Fix for showing logs when TFX pipelines are submitted

* Add sdkVersion and TFX spec to tests

* Remove params from tfx pipeline spec

* Add new tests for TFX pipelines

* Update tests after linting
  • Loading branch information
sararob authored Feb 14, 2022
1 parent 1c34154 commit c10923b
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
5 changes: 5 additions & 0 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

import datetime
import logging
import time
import re
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -273,6 +274,10 @@ def submit(
if network:
self._gca_resource.network = network

# Prevents logs from being supressed on TFX pipelines
if self._gca_resource.pipeline_spec.get("sdkVersion", "").startswith("tfx"):
_LOGGER.setLevel(logging.INFO)

_LOGGER.log_create_with_lro(self.__class__)

self._gca_resource = self.api_client.create_pipeline_job(
Expand Down
86 changes: 86 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@
"schemaVersion": "2.1.0",
"components": {},
}
_TEST_TFX_PIPELINE_SPEC = {
"pipelineInfo": {"name": "my-pipeline"},
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {"parameters": {"string_param": {"type": "STRING"}}},
},
"schemaVersion": "2.0.0",
"sdkVersion": "tfx-1.4.0",
"components": {},
}

_TEST_PIPELINE_JOB_LEGACY = {
"runtimeConfig": {},
Expand All @@ -101,6 +111,10 @@
"runtimeConfig": {"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES},
"pipelineSpec": _TEST_PIPELINE_SPEC,
}
_TEST_PIPELINE_JOB_TFX = {
"runtimeConfig": {},
"pipelineSpec": _TEST_TFX_PIPELINE_SPEC,
}

_TEST_PIPELINE_GET_METHOD_NAME = "get_fake_pipeline_job"
_TEST_PIPELINE_LIST_METHOD_NAME = "list_fake_pipeline_jobs"
Expand Down Expand Up @@ -378,6 +392,78 @@ def test_run_call_pipeline_service_create_legacy(
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
"job_spec_json", [_TEST_TFX_PIPELINE_SPEC, _TEST_PIPELINE_JOB_TFX],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_tfx(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
job_spec_json,
mock_load_json,
sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES_LEGACY,
enable_caching=True,
)

job.run(
service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync,
)

if not sync:
job.wait()

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameters": {"string_param": {"stringValue": "hello"}},
}
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
"pipelineInfo": pipeline_spec["pipelineInfo"],
"root": pipeline_spec["root"],
"schemaVersion": "2.0.0",
"sdkVersion": "tfx-1.4.0",
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
)

mock_pipeline_service_create.assert_called_once_with(
parent=_TEST_PARENT,
pipeline_job=expected_gapic_pipeline_job,
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
)

mock_pipeline_service_get.assert_called_with(
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
)

assert job._gca_resource == make_pipeline_job(
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
"job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
)
Expand Down

0 comments on commit c10923b

Please sign in to comment.