From 6fefbcae4e19491d597d2f14b8454e8a2808c5d5 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Tue, 25 Jan 2022 16:12:28 -0500 Subject: [PATCH] Fixed tests --- google/cloud/aiplatform/training_jobs.py | 2 ++ tests/unit/aiplatform/test_training_jobs.py | 9 +++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 90dc8afeb91..8a2986fe652 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -2106,6 +2106,8 @@ def run( test_filter_split=test_filter_split, predefined_split_column_name=predefined_split_column_name, timestamp_split_column_name=timestamp_split_column_name, + timeout=timeout, + restart_job_on_worker_restart=restart_job_on_worker_restart, enable_web_access=enable_web_access, tensorboard=tensorboard, reduction_server_container_uri=reduction_server_container_uri diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index d60fc5c93fe..648086c56ae 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -205,6 +205,7 @@ def _get_custom_job_proto_with_enable_web_access(state=None, name=None, version= custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS return custom_job_proto + def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"): custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO) custom_job_proto.name = name @@ -215,6 +216,7 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"): custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS return custom_job_proto + def local_copy_method(path): shutil.copy(path, ".") return pathlib.Path(path).name @@ -600,6 +602,7 @@ def make_training_pipeline_with_scheduling(state): name=_TEST_PIPELINE_RESOURCE_NAME, state=state, training_task_inputs={ + # "enable_web_access": _TEST_ENABLE_WEB_ACCESS, "timeout": f"{_TEST_TIMEOUT}s", "restart_job_on_worker_restart": _TEST_RESTART_JOB_ON_WORKER_RESTART, }, @@ -1537,8 +1540,6 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): if not sync: job.wait() - print(caplog.text) - assert "workerpool0-0" in caplog.text assert job._gca_resource == make_training_pipeline_with_scheduling( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED ) @@ -2985,8 +2986,6 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): if not sync: job.wait() - print(caplog.text) - assert "workerpool0-0" in caplog.text assert job._gca_resource == make_training_pipeline_with_enable_web_access( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED ) @@ -4706,8 +4705,6 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): if not sync: job.wait() - print(caplog.text) - assert "workerpool0-0" in caplog.text assert job._gca_resource == make_training_pipeline_with_enable_web_access( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED )