35
35
)
36
36
from google .cloud .aiplatform .compat .services import job_service_client
37
37
38
- import test_custom_job
38
+ import constants as test_constants
39
39
40
40
_TEST_PROJECT = "test-project"
41
41
_TEST_LOCATION = "us-central1"
44
44
45
45
_TEST_PARENT = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } "
46
46
47
- _TEST_STAGING_BUCKET = test_custom_job ._TEST_STAGING_BUCKET
48
- _TEST_BASE_OUTPUT_DIR = test_custom_job ._TEST_BASE_OUTPUT_DIR
47
+ _TEST_STAGING_BUCKET = test_constants . TrainingJobConstants ._TEST_STAGING_BUCKET
48
+ _TEST_BASE_OUTPUT_DIR = test_constants . TrainingJobConstants ._TEST_BASE_OUTPUT_DIR
49
49
50
50
_TEST_HYPERPARAMETERTUNING_JOB_NAME = (
51
51
f"{ _TEST_PARENT } /hyperparameterTuningJobs/{ _TEST_ID } "
152
152
parallel_trial_count = _TEST_PARALLEL_TRIAL_COUNT ,
153
153
max_trial_count = _TEST_MAX_TRIAL_COUNT ,
154
154
max_failed_trial_count = _TEST_MAX_FAILED_TRIAL_COUNT ,
155
- trial_job_spec = test_custom_job ._TEST_BASE_CUSTOM_JOB_PROTO .job_spec ,
155
+ trial_job_spec = test_constants . TrainingJobConstants ._TEST_BASE_CUSTOM_JOB_PROTO .job_spec ,
156
156
labels = _TEST_LABELS ,
157
157
encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
158
158
)
@@ -176,7 +176,9 @@ def _get_trial_proto(id=None, state=None):
176
176
trial_proto .id = id
177
177
trial_proto .state = state
178
178
if state == gca_study_compat .Trial .State .ACTIVE :
179
- trial_proto .web_access_uris = test_custom_job ._TEST_WEB_ACCESS_URIS
179
+ trial_proto .web_access_uris = (
180
+ test_constants .TrainingJobConstants ._TEST_WEB_ACCESS_URIS
181
+ )
180
182
return trial_proto
181
183
182
184
@@ -189,7 +191,7 @@ def _get_hyperparameter_tuning_job_proto_with_enable_web_access(
189
191
error = error ,
190
192
)
191
193
hyperparameter_tuning_job_proto .trial_job_spec .enable_web_access = (
192
- test_custom_job ._TEST_ENABLE_WEB_ACCESS
194
+ test_constants . TrainingJobConstants ._TEST_ENABLE_WEB_ACCESS
193
195
)
194
196
if state == gca_job_state_compat .JobState .JOB_STATE_RUNNING :
195
197
hyperparameter_tuning_job_proto .trials = trials
@@ -372,7 +374,7 @@ def create_hyperparameter_tuning_job_mock_with_tensorboard():
372
374
state = gca_job_state_compat .JobState .JOB_STATE_PENDING ,
373
375
)
374
376
hyperparameter_tuning_job_proto .trial_job_spec .tensorboard = (
375
- test_custom_job ._TEST_TENSORBOARD_NAME
377
+ test_constants . TensorboardConstants ._TEST_TENSORBOARD_NAME
376
378
)
377
379
create_hyperparameter_tuning_job_mock .return_value = (
378
380
hyperparameter_tuning_job_proto
@@ -405,9 +407,9 @@ def test_create_hyperparameter_tuning_job(
405
407
)
406
408
407
409
custom_job = aiplatform .CustomJob (
408
- display_name = test_custom_job ._TEST_DISPLAY_NAME ,
409
- worker_pool_specs = test_custom_job ._TEST_WORKER_POOL_SPEC ,
410
- base_output_dir = test_custom_job ._TEST_BASE_OUTPUT_DIR ,
410
+ display_name = test_constants . TrainingJobConstants ._TEST_DISPLAY_NAME ,
411
+ worker_pool_specs = test_constants . TrainingJobConstants ._TEST_WORKER_POOL_SPEC ,
412
+ base_output_dir = test_constants . TrainingJobConstants ._TEST_BASE_OUTPUT_DIR ,
411
413
)
412
414
413
415
job = aiplatform .HyperparameterTuningJob (
@@ -476,9 +478,9 @@ def test_create_hyperparameter_tuning_job_with_timeout(
476
478
)
477
479
478
480
custom_job = aiplatform .CustomJob (
479
- display_name = test_custom_job ._TEST_DISPLAY_NAME ,
480
- worker_pool_specs = test_custom_job ._TEST_WORKER_POOL_SPEC ,
481
- base_output_dir = test_custom_job ._TEST_BASE_OUTPUT_DIR ,
481
+ display_name = test_constants . TrainingJobConstants ._TEST_DISPLAY_NAME ,
482
+ worker_pool_specs = test_constants . TrainingJobConstants ._TEST_WORKER_POOL_SPEC ,
483
+ base_output_dir = test_constants . TrainingJobConstants ._TEST_BASE_OUTPUT_DIR ,
482
484
)
483
485
484
486
job = aiplatform .HyperparameterTuningJob (
@@ -542,9 +544,9 @@ def test_run_hyperparameter_tuning_job_with_fail_raises(
542
544
)
543
545
544
546
custom_job = aiplatform .CustomJob (
545
- display_name = test_custom_job ._TEST_DISPLAY_NAME ,
546
- worker_pool_specs = test_custom_job ._TEST_WORKER_POOL_SPEC ,
547
- base_output_dir = test_custom_job ._TEST_BASE_OUTPUT_DIR ,
547
+ display_name = test_constants . TrainingJobConstants ._TEST_DISPLAY_NAME ,
548
+ worker_pool_specs = test_constants . TrainingJobConstants ._TEST_WORKER_POOL_SPEC ,
549
+ base_output_dir = test_constants . TrainingJobConstants ._TEST_BASE_OUTPUT_DIR ,
548
550
)
549
551
550
552
job = aiplatform .HyperparameterTuningJob (
@@ -606,9 +608,9 @@ def test_run_hyperparameter_tuning_job_with_fail_at_creation(self):
606
608
)
607
609
608
610
custom_job = aiplatform .CustomJob (
609
- display_name = test_custom_job ._TEST_DISPLAY_NAME ,
610
- worker_pool_specs = test_custom_job ._TEST_WORKER_POOL_SPEC ,
611
- base_output_dir = test_custom_job ._TEST_BASE_OUTPUT_DIR ,
611
+ display_name = test_constants . TrainingJobConstants ._TEST_DISPLAY_NAME ,
612
+ worker_pool_specs = test_constants . TrainingJobConstants ._TEST_WORKER_POOL_SPEC ,
613
+ base_output_dir = test_constants . TrainingJobConstants ._TEST_BASE_OUTPUT_DIR ,
612
614
)
613
615
614
616
job = aiplatform .HyperparameterTuningJob (
@@ -676,9 +678,9 @@ def test_hyperparameter_tuning_job_get_state_raises_without_run(self):
676
678
)
677
679
678
680
custom_job = aiplatform .CustomJob (
679
- display_name = test_custom_job ._TEST_DISPLAY_NAME ,
680
- worker_pool_specs = test_custom_job ._TEST_WORKER_POOL_SPEC ,
681
- base_output_dir = test_custom_job ._TEST_BASE_OUTPUT_DIR ,
681
+ display_name = test_constants . TrainingJobConstants ._TEST_DISPLAY_NAME ,
682
+ worker_pool_specs = test_constants . TrainingJobConstants ._TEST_WORKER_POOL_SPEC ,
683
+ base_output_dir = test_constants . TrainingJobConstants ._TEST_BASE_OUTPUT_DIR ,
682
684
)
683
685
684
686
job = aiplatform .HyperparameterTuningJob (
@@ -739,9 +741,9 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
739
741
)
740
742
741
743
custom_job = aiplatform .CustomJob (
742
- display_name = test_custom_job ._TEST_DISPLAY_NAME ,
743
- worker_pool_specs = test_custom_job ._TEST_WORKER_POOL_SPEC ,
744
- base_output_dir = test_custom_job ._TEST_BASE_OUTPUT_DIR ,
744
+ display_name = test_constants . TrainingJobConstants ._TEST_DISPLAY_NAME ,
745
+ worker_pool_specs = test_constants . TrainingJobConstants ._TEST_WORKER_POOL_SPEC ,
746
+ base_output_dir = test_constants . TrainingJobConstants ._TEST_BASE_OUTPUT_DIR ,
745
747
)
746
748
747
749
job = aiplatform .HyperparameterTuningJob (
@@ -776,7 +778,7 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
776
778
network = _TEST_NETWORK ,
777
779
timeout = _TEST_TIMEOUT ,
778
780
restart_job_on_worker_restart = _TEST_RESTART_JOB_ON_WORKER_RESTART ,
779
- tensorboard = test_custom_job ._TEST_TENSORBOARD_NAME ,
781
+ tensorboard = test_constants . TensorboardConstants ._TEST_TENSORBOARD_NAME ,
780
782
sync = sync ,
781
783
create_request_timeout = None ,
782
784
)
@@ -785,7 +787,7 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
785
787
786
788
expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto ()
787
789
expected_hyperparameter_tuning_job .trial_job_spec .tensorboard = (
788
- test_custom_job ._TEST_TENSORBOARD_NAME
790
+ test_constants . TensorboardConstants ._TEST_TENSORBOARD_NAME
789
791
)
790
792
791
793
create_hyperparameter_tuning_job_mock_with_tensorboard .assert_called_once_with (
@@ -816,9 +818,9 @@ def test_create_hyperparameter_tuning_job_with_enable_web_access(
816
818
)
817
819
818
820
custom_job = aiplatform .CustomJob (
819
- display_name = test_custom_job ._TEST_DISPLAY_NAME ,
820
- worker_pool_specs = test_custom_job ._TEST_WORKER_POOL_SPEC ,
821
- base_output_dir = test_custom_job ._TEST_BASE_OUTPUT_DIR ,
821
+ display_name = test_constants . TrainingJobConstants ._TEST_DISPLAY_NAME ,
822
+ worker_pool_specs = test_constants . TrainingJobConstants ._TEST_WORKER_POOL_SPEC ,
823
+ base_output_dir = test_constants . TrainingJobConstants ._TEST_BASE_OUTPUT_DIR ,
822
824
)
823
825
824
826
job = aiplatform .HyperparameterTuningJob (
@@ -853,7 +855,7 @@ def test_create_hyperparameter_tuning_job_with_enable_web_access(
853
855
network = _TEST_NETWORK ,
854
856
timeout = _TEST_TIMEOUT ,
855
857
restart_job_on_worker_restart = _TEST_RESTART_JOB_ON_WORKER_RESTART ,
856
- enable_web_access = test_custom_job ._TEST_ENABLE_WEB_ACCESS ,
858
+ enable_web_access = test_constants . TrainingJobConstants ._TEST_ENABLE_WEB_ACCESS ,
857
859
sync = sync ,
858
860
create_request_timeout = None ,
859
861
)
@@ -888,5 +890,5 @@ def test_log_enable_web_access_after_get_hyperparameter_tuning_job(
888
890
)
889
891
hp_job ._block_until_complete ()
890
892
assert hp_job ._logged_web_access_uris == set (
891
- test_custom_job ._TEST_WEB_ACCESS_URIS .values ()
893
+ test_constants . TrainingJobConstants ._TEST_WEB_ACCESS_URIS .values ()
892
894
)
0 commit comments