Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Jan 25, 2022
1 parent ae6135d commit 6fefbca
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,6 +2106,8 @@ def run(
test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
timestamp_split_column_name=timestamp_split_column_name,
timeout=timeout,
restart_job_on_worker_restart=restart_job_on_worker_restart,
enable_web_access=enable_web_access,
tensorboard=tensorboard,
reduction_server_container_uri=reduction_server_container_uri
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _get_custom_job_proto_with_enable_web_access(state=None, name=None, version=
custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS
return custom_job_proto


def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"):
custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
custom_job_proto.name = name
Expand All @@ -215,6 +216,7 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"):
custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS
return custom_job_proto


def local_copy_method(path):
shutil.copy(path, ".")
return pathlib.Path(path).name
Expand Down Expand Up @@ -600,6 +602,7 @@ def make_training_pipeline_with_scheduling(state):
name=_TEST_PIPELINE_RESOURCE_NAME,
state=state,
training_task_inputs={
# "enable_web_access": _TEST_ENABLE_WEB_ACCESS,
"timeout": f"{_TEST_TIMEOUT}s",
"restart_job_on_worker_restart": _TEST_RESTART_JOB_ON_WORKER_RESTART,
},
Expand Down Expand Up @@ -1537,8 +1540,6 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
if not sync:
job.wait()

print(caplog.text)
assert "workerpool0-0" in caplog.text
assert job._gca_resource == make_training_pipeline_with_scheduling(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)
Expand Down Expand Up @@ -2985,8 +2986,6 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
if not sync:
job.wait()

print(caplog.text)
assert "workerpool0-0" in caplog.text
assert job._gca_resource == make_training_pipeline_with_enable_web_access(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)
Expand Down Expand Up @@ -4706,8 +4705,6 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
if not sync:
job.wait()

print(caplog.text)
assert "workerpool0-0" in caplog.text
assert job._gca_resource == make_training_pipeline_with_enable_web_access(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)
Expand Down

0 comments on commit 6fefbca

Please sign in to comment.