Skip to content

Commit d83bc79

Browse files
sararobcopybara-github
authored andcommitted
chore: remove test_custom_job dependency from test_hyperparameter_tuning_job
PiperOrigin-RevId: 506077656
1 parent 89968f6 commit d83bc79

File tree

2 files changed

+97
-33
lines changed

2 files changed

+97
-33
lines changed

tests/unit/aiplatform/constants.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import dataclasses
2020

21-
from google.protobuf import timestamp_pb2
21+
from google.protobuf import timestamp_pb2, duration_pb2
2222

2323
from google.cloud.aiplatform.utils import source_utils
2424
from google.cloud.aiplatform import explain
@@ -28,8 +28,10 @@
2828
)
2929

3030
from google.cloud.aiplatform.compat.types import (
31+
custom_job,
3132
encryption_spec,
3233
endpoint,
34+
io,
3335
model,
3436
)
3537

@@ -44,6 +46,9 @@ class ProjectConstants:
4446
_TEST_ENCRYPTION_SPEC = encryption_spec.EncryptionSpec(
4547
kms_key_name=_TEST_ENCRYPTION_KEY_NAME
4648
)
49+
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
50+
_TEST_SERVICE_ACCOUNT = "[email protected]"
51+
_TEST_LABELS = {"my_key": "my_value"}
4752

4853

4954
@dataclasses.dataclass(frozen=True)
@@ -70,6 +75,55 @@ class TrainingJobConstants:
7075
_TEST_REDUCTION_SERVER_CONTAINER_URI = (
7176
"us-docker.pkg.dev/vertex-ai-restricted/training/reductionserver:latest"
7277
)
78+
_TEST_STAGING_BUCKET = "gs://test-staging-bucket"
79+
_TEST_DISPLAY_NAME = "my_job_1234"
80+
_TEST_BASE_OUTPUT_DIR = f"{_TEST_STAGING_BUCKET}/{_TEST_DISPLAY_NAME}"
81+
_TEST_ENABLE_WEB_ACCESS = True
82+
_TEST_WEB_ACCESS_URIS = {"workerpool0-0": "uri"}
83+
_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image"
84+
85+
_TEST_RUN_ARGS = ["-v", "0.1", "--test=arg"]
86+
87+
_TEST_WORKER_POOL_SPEC = [
88+
{
89+
"machine_spec": {
90+
"machine_type": "n1-standard-4",
91+
"accelerator_type": "NVIDIA_TESLA_K80",
92+
"accelerator_count": 1,
93+
},
94+
"replica_count": 1,
95+
"disk_spec": {"boot_disk_type": "pd-ssd", "boot_disk_size_gb": 100},
96+
"container_spec": {
97+
"image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
98+
"command": [],
99+
"args": _TEST_RUN_ARGS,
100+
},
101+
}
102+
]
103+
_TEST_ID = "1028944691210842416"
104+
_TEST_NETWORK = (
105+
f"projects/{ProjectConstants._TEST_PROJECT}/global/networks/{_TEST_ID}"
106+
)
107+
_TEST_TIMEOUT = 8000
108+
_TEST_RESTART_JOB_ON_WORKER_RESTART = True
109+
110+
_TEST_BASE_CUSTOM_JOB_PROTO = custom_job.CustomJob(
111+
display_name=_TEST_DISPLAY_NAME,
112+
job_spec=custom_job.CustomJobSpec(
113+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
114+
base_output_directory=io.GcsDestination(
115+
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
116+
),
117+
scheduling=custom_job.Scheduling(
118+
timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
119+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
120+
),
121+
service_account=ProjectConstants._TEST_SERVICE_ACCOUNT,
122+
network=_TEST_NETWORK,
123+
),
124+
labels=ProjectConstants._TEST_LABELS,
125+
encryption_spec=ProjectConstants._TEST_ENCRYPTION_SPEC,
126+
)
73127

74128

75129
@dataclasses.dataclass(frozen=True)
@@ -120,3 +174,11 @@ class EndpointConstants:
120174
endpoint.DeployedModel(id=_TEST_ID_3, display_name=_TEST_DISPLAY_NAME_3),
121175
]
122176
_TEST_TRAFFIC_SPLIT = {_TEST_ID: 0, _TEST_ID_2: 100, _TEST_ID_3: 0}
177+
178+
179+
@dataclasses.dataclass(frozen=True)
180+
class TensorboardConstants:
181+
"""Defines constants used by tests that create Tensorboard resources."""
182+
183+
_TEST_ID = "1028944691210842416"
184+
_TEST_TENSORBOARD_NAME = f"{ProjectConstants._TEST_PARENT}/tensorboards/{_TEST_ID}"

tests/unit/aiplatform/test_hyperparameter_tuning_job.py

+34-32
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636
from google.cloud.aiplatform.compat.services import job_service_client
3737

38-
import test_custom_job
38+
import constants as test_constants
3939

4040
_TEST_PROJECT = "test-project"
4141
_TEST_LOCATION = "us-central1"
@@ -44,8 +44,8 @@
4444

4545
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
4646

47-
_TEST_STAGING_BUCKET = test_custom_job._TEST_STAGING_BUCKET
48-
_TEST_BASE_OUTPUT_DIR = test_custom_job._TEST_BASE_OUTPUT_DIR
47+
_TEST_STAGING_BUCKET = test_constants.TrainingJobConstants._TEST_STAGING_BUCKET
48+
_TEST_BASE_OUTPUT_DIR = test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR
4949

5050
_TEST_HYPERPARAMETERTUNING_JOB_NAME = (
5151
f"{_TEST_PARENT}/hyperparameterTuningJobs/{_TEST_ID}"
@@ -152,7 +152,7 @@
152152
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
153153
max_trial_count=_TEST_MAX_TRIAL_COUNT,
154154
max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
155-
trial_job_spec=test_custom_job._TEST_BASE_CUSTOM_JOB_PROTO.job_spec,
155+
trial_job_spec=test_constants.TrainingJobConstants._TEST_BASE_CUSTOM_JOB_PROTO.job_spec,
156156
labels=_TEST_LABELS,
157157
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
158158
)
@@ -176,7 +176,9 @@ def _get_trial_proto(id=None, state=None):
176176
trial_proto.id = id
177177
trial_proto.state = state
178178
if state == gca_study_compat.Trial.State.ACTIVE:
179-
trial_proto.web_access_uris = test_custom_job._TEST_WEB_ACCESS_URIS
179+
trial_proto.web_access_uris = (
180+
test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS
181+
)
180182
return trial_proto
181183

182184

@@ -189,7 +191,7 @@ def _get_hyperparameter_tuning_job_proto_with_enable_web_access(
189191
error=error,
190192
)
191193
hyperparameter_tuning_job_proto.trial_job_spec.enable_web_access = (
192-
test_custom_job._TEST_ENABLE_WEB_ACCESS
194+
test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS
193195
)
194196
if state == gca_job_state_compat.JobState.JOB_STATE_RUNNING:
195197
hyperparameter_tuning_job_proto.trials = trials
@@ -372,7 +374,7 @@ def create_hyperparameter_tuning_job_mock_with_tensorboard():
372374
state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
373375
)
374376
hyperparameter_tuning_job_proto.trial_job_spec.tensorboard = (
375-
test_custom_job._TEST_TENSORBOARD_NAME
377+
test_constants.TensorboardConstants._TEST_TENSORBOARD_NAME
376378
)
377379
create_hyperparameter_tuning_job_mock.return_value = (
378380
hyperparameter_tuning_job_proto
@@ -405,9 +407,9 @@ def test_create_hyperparameter_tuning_job(
405407
)
406408

407409
custom_job = aiplatform.CustomJob(
408-
display_name=test_custom_job._TEST_DISPLAY_NAME,
409-
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
410-
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
410+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
411+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
412+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
411413
)
412414

413415
job = aiplatform.HyperparameterTuningJob(
@@ -476,9 +478,9 @@ def test_create_hyperparameter_tuning_job_with_timeout(
476478
)
477479

478480
custom_job = aiplatform.CustomJob(
479-
display_name=test_custom_job._TEST_DISPLAY_NAME,
480-
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
481-
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
481+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
482+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
483+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
482484
)
483485

484486
job = aiplatform.HyperparameterTuningJob(
@@ -542,9 +544,9 @@ def test_run_hyperparameter_tuning_job_with_fail_raises(
542544
)
543545

544546
custom_job = aiplatform.CustomJob(
545-
display_name=test_custom_job._TEST_DISPLAY_NAME,
546-
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
547-
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
547+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
548+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
549+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
548550
)
549551

550552
job = aiplatform.HyperparameterTuningJob(
@@ -606,9 +608,9 @@ def test_run_hyperparameter_tuning_job_with_fail_at_creation(self):
606608
)
607609

608610
custom_job = aiplatform.CustomJob(
609-
display_name=test_custom_job._TEST_DISPLAY_NAME,
610-
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
611-
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
611+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
612+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
613+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
612614
)
613615

614616
job = aiplatform.HyperparameterTuningJob(
@@ -676,9 +678,9 @@ def test_hyperparameter_tuning_job_get_state_raises_without_run(self):
676678
)
677679

678680
custom_job = aiplatform.CustomJob(
679-
display_name=test_custom_job._TEST_DISPLAY_NAME,
680-
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
681-
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
681+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
682+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
683+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
682684
)
683685

684686
job = aiplatform.HyperparameterTuningJob(
@@ -739,9 +741,9 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
739741
)
740742

741743
custom_job = aiplatform.CustomJob(
742-
display_name=test_custom_job._TEST_DISPLAY_NAME,
743-
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
744-
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
744+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
745+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
746+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
745747
)
746748

747749
job = aiplatform.HyperparameterTuningJob(
@@ -776,7 +778,7 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
776778
network=_TEST_NETWORK,
777779
timeout=_TEST_TIMEOUT,
778780
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
779-
tensorboard=test_custom_job._TEST_TENSORBOARD_NAME,
781+
tensorboard=test_constants.TensorboardConstants._TEST_TENSORBOARD_NAME,
780782
sync=sync,
781783
create_request_timeout=None,
782784
)
@@ -785,7 +787,7 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
785787

786788
expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto()
787789
expected_hyperparameter_tuning_job.trial_job_spec.tensorboard = (
788-
test_custom_job._TEST_TENSORBOARD_NAME
790+
test_constants.TensorboardConstants._TEST_TENSORBOARD_NAME
789791
)
790792

791793
create_hyperparameter_tuning_job_mock_with_tensorboard.assert_called_once_with(
@@ -816,9 +818,9 @@ def test_create_hyperparameter_tuning_job_with_enable_web_access(
816818
)
817819

818820
custom_job = aiplatform.CustomJob(
819-
display_name=test_custom_job._TEST_DISPLAY_NAME,
820-
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
821-
base_output_dir=test_custom_job._TEST_BASE_OUTPUT_DIR,
821+
display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
822+
worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
823+
base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
822824
)
823825

824826
job = aiplatform.HyperparameterTuningJob(
@@ -853,7 +855,7 @@ def test_create_hyperparameter_tuning_job_with_enable_web_access(
853855
network=_TEST_NETWORK,
854856
timeout=_TEST_TIMEOUT,
855857
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
856-
enable_web_access=test_custom_job._TEST_ENABLE_WEB_ACCESS,
858+
enable_web_access=test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS,
857859
sync=sync,
858860
create_request_timeout=None,
859861
)
@@ -888,5 +890,5 @@ def test_log_enable_web_access_after_get_hyperparameter_tuning_job(
888890
)
889891
hp_job._block_until_complete()
890892
assert hp_job._logged_web_access_uris == set(
891-
test_custom_job._TEST_WEB_ACCESS_URIS.values()
893+
test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS.values()
892894
)

0 commit comments

Comments
 (0)