diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index ef23f86275..90e7a8471f 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -1629,6 +1629,7 @@ def run( tensorboard: Optional[str] = None, sync: bool = True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> None: """Run this configured CustomJob. @@ -1686,6 +1687,10 @@ def run( will unblock and it will be executed in a concurrent Future. create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. """ network = network or initializer.global_config.network @@ -1700,6 +1705,7 @@ def run( tensorboard=tensorboard, sync=sync, create_request_timeout=create_request_timeout, + disable_retries=disable_retries, ) @base.optional_sync() @@ -1715,6 +1721,7 @@ def _run( tensorboard: Optional[str] = None, sync: bool = True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> None: """Helper method to ensure network synchronization and to run the configured CustomJob. @@ -1770,6 +1777,10 @@ def _run( will unblock and it will be executed in a concurrent Future. create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. """ self.submit( service_account=service_account, @@ -1781,6 +1792,7 @@ def _run( experiment_run=experiment_run, tensorboard=tensorboard, create_request_timeout=create_request_timeout, + disable_retries=disable_retries, ) self._block_until_complete() @@ -1797,6 +1809,7 @@ def submit( experiment_run: Optional[Union["aiplatform.ExperimentRun", str]] = None, tensorboard: Optional[str] = None, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> None: """Submit the configured CustomJob. @@ -1849,6 +1862,10 @@ def submit( https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Raises: ValueError: @@ -1869,11 +1886,12 @@ def submit( if network: self._gca_resource.job_spec.network = network - if timeout or restart_job_on_worker_restart: + if timeout or restart_job_on_worker_restart or disable_retries: timeout = duration_pb2.Duration(seconds=timeout) if timeout else None self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, + disable_retries=disable_retries, ) if enable_web_access: @@ -2287,6 +2305,7 @@ def run( tensorboard: Optional[str] = None, sync: bool = True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> None: """Run this configured CustomJob. @@ -2331,6 +2350,10 @@ def run( will unblock and it will be executed in a concurrent Future. create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. """ network = network or initializer.global_config.network @@ -2343,6 +2366,7 @@ def run( tensorboard=tensorboard, sync=sync, create_request_timeout=create_request_timeout, + disable_retries=disable_retries, ) @base.optional_sync() @@ -2356,6 +2380,7 @@ def _run( tensorboard: Optional[str] = None, sync: bool = True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> None: """Helper method to ensure network synchronization and to run the configured CustomJob. @@ -2398,6 +2423,10 @@ def _run( will unblock and it will be executed in a concurrent Future. create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. """ if service_account: self._gca_resource.trial_job_spec.service_account = service_account @@ -2405,12 +2434,13 @@ def _run( if network: self._gca_resource.trial_job_spec.network = network - if timeout or restart_job_on_worker_restart: + if timeout or restart_job_on_worker_restart or disable_retries: duration = duration_pb2.Duration(seconds=timeout) if timeout else None self._gca_resource.trial_job_spec.scheduling = ( gca_custom_job_compat.Scheduling( timeout=duration, restart_job_on_worker_restart=restart_job_on_worker_restart, + disable_retries=disable_retries, ) ) diff --git a/google/cloud/aiplatform/preview/jobs.py b/google/cloud/aiplatform/preview/jobs.py index 35e611f802..7ba408db95 100644 --- a/google/cloud/aiplatform/preview/jobs.py +++ b/google/cloud/aiplatform/preview/jobs.py @@ -238,6 +238,7 @@ def submit( experiment_run: Optional[Union["aiplatform.ExperimentRun", str]] = None, tensorboard: Optional[str] = None, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> None: """Submit the configured CustomJob. @@ -290,6 +291,10 @@ def submit( https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Raises: ValueError: @@ -310,11 +315,12 @@ def submit( if network: self._gca_resource.job_spec.network = network - if timeout or restart_job_on_worker_restart: + if timeout or restart_job_on_worker_restart or disable_retries: timeout = duration_pb2.Duration(seconds=timeout) if timeout else None self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, + disable_retries=disable_retries, ) if enable_web_access: diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 0cfb28c462..7af003d185 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1488,6 +1488,7 @@ def _prepare_training_task_inputs_and_output_dir( enable_web_access: bool = False, enable_dashboard_access: bool = False, tensorboard: Optional[str] = None, + disable_retries: bool = False, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1534,6 +1535,10 @@ def _prepare_training_task_inputs_and_output_dir( `service_account` is required with provided `tensorboard`. For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: Training task inputs and Output directory for custom job. """ @@ -1561,11 +1566,12 @@ def _prepare_training_task_inputs_and_output_dir( if enable_dashboard_access: training_task_inputs["enable_dashboard_access"] = enable_dashboard_access - if timeout or restart_job_on_worker_restart: + if timeout or restart_job_on_worker_restart or disable_retries: timeout = f"{timeout}s" if timeout else None scheduling = { "timeout": timeout, "restart_job_on_worker_restart": restart_job_on_worker_restart, + "disable_retries": disable_retries, } training_task_inputs["scheduling"] = scheduling @@ -2923,6 +2929,7 @@ def run( tensorboard: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> Optional[models.Model]: """Runs the custom training job. @@ -3206,6 +3213,10 @@ def run( Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -3266,6 +3277,7 @@ def run( else None, sync=sync, create_request_timeout=create_request_timeout, + disable_retries=disable_retries, ) def submit( @@ -3316,6 +3328,7 @@ def submit( tensorboard: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> Optional[models.Model]: """Submits the custom training job without blocking until completion. @@ -3599,6 +3612,10 @@ def submit( Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -3660,6 +3677,7 @@ def submit( sync=sync, create_request_timeout=create_request_timeout, block=False, + disable_retries=disable_retries, ) @base.optional_sync(construct_object_on_arg="managed_model") @@ -3705,6 +3723,7 @@ def _run( sync=True, create_request_timeout: Optional[float] = None, block: Optional[bool] = True, + disable_retries: bool = False, ) -> Optional[models.Model]: """Packages local script and launches training_job. @@ -3890,6 +3909,10 @@ def _run( Optional. The timeout for the create request in seconds block (bool): Optional. If True, block until complete. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -3942,6 +3965,7 @@ def _run( enable_web_access=enable_web_access, enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, + disable_retries=disable_retries, ) model = self._run_job( @@ -4263,6 +4287,7 @@ def run( tensorboard: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> Optional[models.Model]: """Runs the custom training job. @@ -4539,6 +4564,10 @@ def run( be immediately returned and synced when the Future has completed. create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -4598,6 +4627,7 @@ def run( else None, sync=sync, create_request_timeout=create_request_timeout, + disable_retries=disable_retries, ) def submit( @@ -4648,6 +4678,7 @@ def submit( tensorboard: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> Optional[models.Model]: """Submits the custom training job without blocking until completion. @@ -4924,6 +4955,10 @@ def submit( be immediately returned and synced when the Future has completed. create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -4984,6 +5019,7 @@ def submit( sync=sync, create_request_timeout=create_request_timeout, block=False, + disable_retries=disable_retries, ) @base.optional_sync(construct_object_on_arg="managed_model") @@ -5028,6 +5064,7 @@ def _run( sync=True, create_request_timeout: Optional[float] = None, block: Optional[bool] = True, + disable_retries: bool = False, ) -> Optional[models.Model]: """Packages local script and launches training_job. Args: @@ -5209,6 +5246,10 @@ def _run( Optional. The timeout for the create request in seconds. block (bool): Optional. If True, block until complete. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -5255,6 +5296,7 @@ def _run( enable_web_access=enable_web_access, enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, + disable_retries=disable_retries, ) model = self._run_job( @@ -7172,6 +7214,7 @@ def run( tensorboard: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> Optional[models.Model]: """Runs the custom training job. @@ -7448,6 +7491,10 @@ def run( be immediately returned and synced when the Future has completed. create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -7502,6 +7549,7 @@ def run( else None, sync=sync, create_request_timeout=create_request_timeout, + disable_retries=disable_retries, ) @base.optional_sync(construct_object_on_arg="managed_model") @@ -7545,6 +7593,7 @@ def _run( reduction_server_container_uri: Optional[str] = None, sync=True, create_request_timeout: Optional[float] = None, + disable_retries: bool = False, ) -> Optional[models.Model]: """Packages local script and launches training_job. @@ -7711,6 +7760,10 @@ def _run( be immediately returned and synced when the Future has completed. create_request_timeout (float): Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -7757,6 +7810,7 @@ def _run( enable_web_access=enable_web_access, enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, + disable_retries=disable_retries, ) model = self._run_job( diff --git a/tests/system/aiplatform/test_e2e_tabular.py b/tests/system/aiplatform/test_e2e_tabular.py index f6700fa099..20b49999d3 100644 --- a/tests/system/aiplatform/test_e2e_tabular.py +++ b/tests/system/aiplatform/test_e2e_tabular.py @@ -106,6 +106,7 @@ def test_end_to_end_tabular(self, shared_state): enable_web_access=True, sync=False, create_request_timeout=None, + disable_retries=True, ) automl_model = automl_job.run( diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py index 4ca82c7746..80faded765 100644 --- a/tests/unit/aiplatform/constants.py +++ b/tests/unit/aiplatform/constants.py @@ -125,6 +125,7 @@ class TrainingJobConstants: ) _TEST_TIMEOUT = 8000 _TEST_RESTART_JOB_ON_WORKER_RESTART = True + _TEST_DISABLE_RETRIES = True _TEST_BASE_CUSTOM_JOB_PROTO = custom_job.CustomJob( display_name=_TEST_DISPLAY_NAME, @@ -136,6 +137,7 @@ class TrainingJobConstants: scheduling=custom_job.Scheduling( timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT), restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + disable_retries=_TEST_DISABLE_RETRIES, ), service_account=ProjectConstants._TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, diff --git a/tests/unit/aiplatform/test_custom_job.py b/tests/unit/aiplatform/test_custom_job.py index e91f90cefd..ea43c42a4f 100644 --- a/tests/unit/aiplatform/test_custom_job.py +++ b/tests/unit/aiplatform/test_custom_job.py @@ -126,6 +126,7 @@ _TEST_RESTART_JOB_ON_WORKER_RESTART = ( test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART ) +_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES _TEST_LABELS = test_constants.ProjectConstants._TEST_LABELS @@ -421,6 +422,7 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() @@ -465,6 +467,7 @@ def test_submit_custom_job(self, create_custom_job_mock, get_custom_job_mock): timeout=_TEST_TIMEOUT, restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() @@ -516,6 +519,7 @@ def test_submit_custom_job_with_experiments( create_request_timeout=None, experiment=_TEST_EXPERIMENT, experiment_run=_TEST_EXPERIMENT_RUN, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() @@ -569,6 +573,7 @@ def test_create_custom_job_with_timeout( restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=180.0, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() @@ -610,6 +615,7 @@ def test_create_custom_job_with_timeout_not_explicitly_set( timeout=_TEST_TIMEOUT, restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() @@ -656,6 +662,7 @@ def test_run_custom_job_with_fail_raises( restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait() @@ -696,6 +703,7 @@ def test_run_custom_job_with_fail_at_creation(self): timeout=_TEST_TIMEOUT, restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=False, + disable_retries=_TEST_DISABLE_RETRIES, ) with pytest.raises(RuntimeError) as e: @@ -1012,6 +1020,7 @@ def test_create_custom_job_with_enable_web_access( restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() @@ -1083,6 +1092,7 @@ def test_create_custom_job_with_tensorboard( restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait() @@ -1149,6 +1159,7 @@ def test_check_custom_job_availability(self): network=_TEST_NETWORK, timeout=_TEST_TIMEOUT, restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() diff --git a/tests/unit/aiplatform/test_custom_job_persistent_resource.py b/tests/unit/aiplatform/test_custom_job_persistent_resource.py index 3405feb9da..3b23c05fcd 100644 --- a/tests/unit/aiplatform/test_custom_job_persistent_resource.py +++ b/tests/unit/aiplatform/test_custom_job_persistent_resource.py @@ -71,6 +71,7 @@ _TEST_RESTART_JOB_ON_WORKER_RESTART = ( test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART ) +_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES _TEST_LABELS = test_constants.ProjectConstants._TEST_LABELS @@ -87,6 +88,7 @@ scheduling=custom_job_v1beta1.Scheduling( timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT), restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + disable_retries=_TEST_DISABLE_RETRIES, ), service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, @@ -175,6 +177,7 @@ def test_create_custom_job_with_persistent_resource( restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() @@ -222,6 +225,7 @@ def test_submit_custom_job_with_persistent_resource( timeout=_TEST_TIMEOUT, restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait_for_resource_creation() diff --git a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py index 911115bb59..c625b7b442 100644 --- a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py +++ b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py @@ -66,6 +66,7 @@ _TEST_RESTART_JOB_ON_WORKER_RESTART = ( test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART ) +_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES _TEST_METRIC_SPEC_KEY = "test-metric" _TEST_METRIC_SPEC_VALUE = "maximize" @@ -448,6 +449,7 @@ def test_create_hyperparameter_tuning_job( restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait() @@ -519,6 +521,7 @@ def test_create_hyperparameter_tuning_job_with_timeout( restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=180.0, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait() @@ -586,6 +589,7 @@ def test_run_hyperparameter_tuning_job_with_fail_raises( restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait() @@ -647,6 +651,7 @@ def test_run_hyperparameter_tuning_job_with_fail_at_creation(self): timeout=_TEST_TIMEOUT, restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=False, + disable_retries=_TEST_DISABLE_RETRIES, ) with pytest.raises(RuntimeError) as e: @@ -783,6 +788,7 @@ def test_create_hyperparameter_tuning_job_with_tensorboard( tensorboard=test_constants.TensorboardConstants._TEST_TENSORBOARD_NAME, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait() @@ -860,6 +866,7 @@ def test_create_hyperparameter_tuning_job_with_enable_web_access( enable_web_access=test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) job.wait() diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 80a5169e8c..a35e644b46 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -233,6 +233,7 @@ test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART ) +_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES _TEST_ENABLE_WEB_ACCESS = test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS _TEST_ENABLE_DASHBOARD_ACCESS = True _TEST_WEB_ACCESS_URIS = test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS @@ -278,6 +279,7 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"): custom_job_proto.job_spec.scheduling.restart_job_on_worker_restart = ( _TEST_RESTART_JOB_ON_WORKER_RESTART ) + custom_job_proto.job_spec.scheduling.disable_retries = _TEST_DISABLE_RETRIES return custom_job_proto @@ -730,6 +732,7 @@ def make_training_pipeline_with_scheduling(state): training_task_inputs={ "timeout": f"{_TEST_TIMEOUT}s", "restart_job_on_worker_restart": _TEST_RESTART_JOB_ON_WORKER_RESTART, + "disable_retries": _TEST_DISABLE_RETRIES, }, ) if state == gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING: @@ -2251,6 +2254,7 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) if not sync: @@ -2269,6 +2273,10 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): job._gca_resource.training_task_inputs["restart_job_on_worker_restart"] == _TEST_RESTART_JOB_ON_WORKER_RESTART ) + assert ( + job._gca_resource.training_task_inputs["disable_retries"] + == _TEST_DISABLE_RETRIES + ) @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) @@ -4250,6 +4258,7 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) if not sync: @@ -4268,6 +4277,10 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): job._gca_resource.training_task_inputs["restart_job_on_worker_restart"] == _TEST_RESTART_JOB_ON_WORKER_RESTART ) + assert ( + job._gca_resource.training_task_inputs["disable_retries"] + == _TEST_DISABLE_RETRIES + ) @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) @@ -6525,6 +6538,7 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, sync=sync, create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, ) if not sync: @@ -6543,6 +6557,10 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): job._gca_resource.training_task_inputs["restart_job_on_worker_restart"] == _TEST_RESTART_JOB_ON_WORKER_RESTART ) + assert ( + job._gca_resource.training_task_inputs["disable_retries"] + == _TEST_DISABLE_RETRIES + ) @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)