From 7b9db49a855fb6f3f356032afbab5ccc5dc3c3e8 Mon Sep 17 00:00:00 2001 From: ivanmkc Date: Mon, 14 Feb 2022 14:00:41 -0500 Subject: [PATCH] Simplfied tests --- tests/unit/aiplatform/test_training_jobs.py | 55 +++------------------ 1 file changed, 8 insertions(+), 47 deletions(-) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index e14c3b4684..6486e3b15e 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -69,7 +69,7 @@ from google.cloud import storage from google.protobuf import json_format from google.protobuf import struct_pb2 - +from google.protobuf import duration_pb2 # type: ignore _TEST_BUCKET_NAME = "test-bucket" _TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder" @@ -211,9 +211,13 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"): custom_job_proto.name = name custom_job_proto.state = state - custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS - if state == gca_job_state.JobState.JOB_STATE_RUNNING: - custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS + custom_job_proto.job_spec.scheduling.timeout = duration_pb2.Duration( + seconds=_TEST_TIMEOUT + ) + custom_job_proto.job_spec.scheduling.restart_job_on_worker_restart = ( + _TEST_RESTART_JOB_ON_WORKER_RESTART + ) + return custom_job_proto @@ -321,40 +325,6 @@ def mock_get_backing_custom_job_with_enable_web_access(): yield get_custom_job_mock -@pytest.fixture -def mock_get_backing_custom_job_with_scheduling(): - with patch.object( - job_service_client.JobServiceClient, "get_custom_job" - ) as get_custom_job_mock: - get_custom_job_mock.side_effect = [ - _get_custom_job_proto_with_scheduling( - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, - state=gca_job_state.JobState.JOB_STATE_PENDING, - ), - _get_custom_job_proto_with_scheduling( - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, - state=gca_job_state.JobState.JOB_STATE_RUNNING, - ), - _get_custom_job_proto_with_scheduling( - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, - state=gca_job_state.JobState.JOB_STATE_RUNNING, - ), - _get_custom_job_proto_with_scheduling( - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, - state=gca_job_state.JobState.JOB_STATE_RUNNING, - ), - _get_custom_job_proto_with_scheduling( - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, - state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, - ), - _get_custom_job_proto_with_scheduling( - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, - state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, - ), - ] - yield get_custom_job_mock - - class TestTrainingScriptPythonPackagerHelpers: def setup_method(self): importlib.reload(initializer) @@ -1505,14 +1475,11 @@ def test_run_call_pipeline_service_create_with_enable_web_access( @pytest.mark.usefixtures( "mock_pipeline_service_create_with_scheduling", "mock_pipeline_service_get_with_scheduling", - "mock_get_backing_custom_job_with_scheduling", "mock_python_package_to_gcs", ) @pytest.mark.parametrize("sync", [True, False]) def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): - caplog.set_level(logging.INFO) - aiplatform.init( project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, @@ -2952,13 +2919,10 @@ def test_run_call_pipeline_service_create_with_enable_web_access( @pytest.mark.usefixtures( "mock_pipeline_service_create_with_scheduling", "mock_pipeline_service_get_with_scheduling", - "mock_get_backing_custom_job_with_scheduling", ) @pytest.mark.parametrize("sync", [True, False]) def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): - caplog.set_level(logging.INFO) - aiplatform.init( project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, @@ -4670,13 +4634,10 @@ def test_run_call_pipeline_service_create_with_enable_web_access( @pytest.mark.usefixtures( "mock_pipeline_service_create_with_scheduling", "mock_pipeline_service_get_with_scheduling", - "mock_get_backing_custom_job_with_scheduling", ) @pytest.mark.parametrize("sync", [True, False]) def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): - caplog.set_level(logging.INFO) - aiplatform.init( project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME,