From c10923b47b9b9941d14ae2c5398348d971a23f9d Mon Sep 17 00:00:00 2001
From: Sara Robinson
Date: Mon, 14 Feb 2022 12:44:11 -0500
Subject: [PATCH] fix: show logs when TFX pipelines are submitted (#976)

* Fix for showing logs when TFX pipelines are submitted

* Add sdkVersion and TFX spec to tests

* Remove params from tfx pipeline spec

* Add new tests for TFX pipelines

* Update tests after linting
---
 google/cloud/aiplatform/pipeline_jobs.py    |  5 ++
 tests/unit/aiplatform/test_pipeline_jobs.py | 86 +++++++++++++++++++++
 2 files changed, 91 insertions(+)

diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py
index c756589513..d91a711c65 100644
--- a/google/cloud/aiplatform/pipeline_jobs.py
+++ b/google/cloud/aiplatform/pipeline_jobs.py
@@ -16,6 +16,7 @@
 #
 
 import datetime
+import logging
 import time
 import re
 from typing import Any, Dict, List, Optional
@@ -273,6 +274,10 @@ def submit(
         if network:
             self._gca_resource.network = network
 
+        # Prevents logs from being suppressed on TFX pipelines
+        if self._gca_resource.pipeline_spec.get("sdkVersion", "").startswith("tfx"):
+            _LOGGER.setLevel(logging.INFO)
+
         _LOGGER.log_create_with_lro(self.__class__)
 
         self._gca_resource = self.api_client.create_pipeline_job(
diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py
index 81f14a7ead..ac2fca8ba8 100644
--- a/tests/unit/aiplatform/test_pipeline_jobs.py
+++ b/tests/unit/aiplatform/test_pipeline_jobs.py
@@ -92,6 +92,16 @@
     "schemaVersion": "2.1.0",
     "components": {},
 }
+_TEST_TFX_PIPELINE_SPEC = {
+    "pipelineInfo": {"name": "my-pipeline"},
+    "root": {
+        "dag": {"tasks": {}},
+        "inputDefinitions": {"parameters": {"string_param": {"type": "STRING"}}},
+    },
+    "schemaVersion": "2.0.0",
+    "sdkVersion": "tfx-1.4.0",
+    "components": {},
+}
 
 _TEST_PIPELINE_JOB_LEGACY = {
     "runtimeConfig": {},
@@ -101,6 +111,10 @@
     "runtimeConfig": {"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES},
     "pipelineSpec": _TEST_PIPELINE_SPEC,
 }
+_TEST_PIPELINE_JOB_TFX = {
+    "runtimeConfig": {},
+    "pipelineSpec": _TEST_TFX_PIPELINE_SPEC,
+}
 
 _TEST_PIPELINE_GET_METHOD_NAME = "get_fake_pipeline_job"
 _TEST_PIPELINE_LIST_METHOD_NAME = "list_fake_pipeline_jobs"
@@ -378,6 +392,78 @@ def test_run_call_pipeline_service_create_legacy(
             gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
         )
 
+    @pytest.mark.parametrize(
+        "job_spec_json", [_TEST_TFX_PIPELINE_SPEC, _TEST_PIPELINE_JOB_TFX],
+    )
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_call_pipeline_service_create_tfx(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        job_spec_json,
+        mock_load_json,
+        sync,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES_LEGACY,
+            enable_caching=True,
+        )
+
+        job.run(
+            service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync,
+        )
+
+        if not sync:
+            job.wait()
+
+        expected_runtime_config_dict = {
+            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
+            "parameters": {"string_param": {"stringValue": "hello"}},
+        }
+        runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(expected_runtime_config_dict, runtime_config)
+
+        pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json
+
+        # Construct expected request
+        expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            pipeline_spec={
+                "components": {},
+                "pipelineInfo": pipeline_spec["pipelineInfo"],
+                "root": pipeline_spec["root"],
+                "schemaVersion": "2.0.0",
+                "sdkVersion": "tfx-1.4.0",
+            },
+            runtime_config=runtime_config,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+        )
+
+        mock_pipeline_service_create.assert_called_once_with(
+            parent=_TEST_PARENT,
+            pipeline_job=expected_gapic_pipeline_job,
+            pipeline_job_id=_TEST_PIPELINE_JOB_ID,
+        )
+
+        mock_pipeline_service_get.assert_called_with(
+            name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
+        )
+
+        assert job._gca_resource == make_pipeline_job(
+            gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
+
     @pytest.mark.parametrize(
         "job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
     )