Skip to content

Commit

Permalink
Added unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmkc committed Jan 25, 2022
1 parent 19e3914 commit ae6135d
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 0 deletions.
11 changes: 11 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5432,6 +5432,8 @@ def run(
test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
timestamp_split_column_name: Optional[str] = None,
timeout: Optional[int] = None,
restart_job_on_worker_restart: bool = False,
enable_web_access: bool = False,
tensorboard: Optional[str] = None,
sync=True,
Expand Down Expand Up @@ -5645,6 +5647,13 @@ def run(
that piece is ignored by the pipeline.
Supported only for tabular and time series Datasets.
timeout (int):
The maximum job running time in seconds. The default is 7 days.
restart_job_on_worker_restart (bool):
Restarts the entire CustomJob if a worker
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
enable_web_access (bool):
Whether you want Vertex AI to enable interactive shell access
to training containers.
Expand Down Expand Up @@ -5705,6 +5714,8 @@ def run(
predefined_split_column_name=predefined_split_column_name,
timestamp_split_column_name=timestamp_split_column_name,
bigquery_destination=bigquery_destination,
timeout=timeout,
restart_job_on_worker_restart=restart_job_on_worker_restart,
enable_web_access=enable_web_access,
tensorboard=tensorboard,
reduction_server_container_uri=reduction_server_container_uri
Expand Down
231 changes: 231 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@
# Encryption spec built from the test model KMS key name.
_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME
)

# Scheduling options passed to ``job.run()`` by the ``*_with_scheduling`` tests.
# The timeout is a number of seconds; the expected pipeline renders it as
# f"{_TEST_TIMEOUT}s" in its training task inputs.
_TEST_TIMEOUT = 1000
_TEST_RESTART_JOB_ON_WORKER_RESTART = True

# Web-access settings applied to the mocked backing custom-job protos.
_TEST_ENABLE_WEB_ACCESS = True
_TEST_WEB_ACCESS_URIS = {"workerpool0-0": "uri"}

Expand All @@ -201,6 +205,15 @@ def _get_custom_job_proto_with_enable_web_access(state=None, name=None, version=
custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS
return custom_job_proto

def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"):
    """Build the backing CustomJob proto served by the scheduling mocks.

    The result is a deep copy of ``_TEST_BASE_CUSTOM_JOB_PROTO`` with ``name``
    and ``state`` overwritten and ``job_spec.enable_web_access`` switched on.
    When ``state`` is ``JOB_STATE_RUNNING`` the web access URIs are attached
    as well.

    NOTE(review): despite its name this helper only sets web-access fields and
    never populates any scheduling field (timeout /
    restart_job_on_worker_restart) -- it appears to mirror the web-access
    helper; confirm whether scheduling should be set here. ``version`` is
    accepted but not used.

    Args:
        state: Job state to store on the proto.
        name: Resource name to store on the proto.
        version: Unused.

    Returns:
        The customised copy of the base custom job proto.
    """
    proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
    proto.name = name
    proto.state = state
    proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS

    is_running = state == gca_job_state.JobState.JOB_STATE_RUNNING
    if is_running:
        # URIs are only attached to a running job.
        proto.web_access_uris = _TEST_WEB_ACCESS_URIS
    return proto

def local_copy_method(path):
shutil.copy(path, ".")
Expand Down Expand Up @@ -306,6 +319,40 @@ def mock_get_backing_custom_job_with_enable_web_access():
yield get_custom_job_mock


@pytest.fixture
def mock_get_backing_custom_job_with_scheduling():
    """Patch ``JobServiceClient.get_custom_job`` with a scripted job lifecycle.

    Successive calls to the mock return custom-job protos whose states are,
    in order: PENDING, RUNNING (three times), SUCCEEDED (twice).

    Yields:
        The mock that replaces ``get_custom_job``.
    """
    job_states = (
        [gca_job_state.JobState.JOB_STATE_PENDING]
        + [gca_job_state.JobState.JOB_STATE_RUNNING] * 3
        + [gca_job_state.JobState.JOB_STATE_SUCCEEDED] * 2
    )
    with patch.object(
        job_service_client.JobServiceClient, "get_custom_job"
    ) as get_custom_job_mock:
        # One response per expected poll, consumed in order.
        get_custom_job_mock.side_effect = [
            _get_custom_job_proto_with_scheduling(
                name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
                state=job_state,
            )
            for job_state in job_states
        ]
        yield get_custom_job_mock


class TestTrainingScriptPythonPackagerHelpers:
def setup_method(self):
importlib.reload(initializer)
Expand Down Expand Up @@ -548,6 +595,22 @@ def make_training_pipeline_with_enable_web_access(state):
return training_pipeline


def make_training_pipeline_with_scheduling(state):
    """Build the TrainingPipeline expected when scheduling options are set.

    The pipeline's training task inputs carry the test timeout (formatted as
    ``"<seconds>s"``) and the restart-on-worker-restart flag. For a RUNNING
    pipeline the task metadata also references the backing custom job.

    Args:
        state: Pipeline state to store on the resource.

    Returns:
        The constructed ``TrainingPipeline``.
    """
    scheduling_inputs = {
        "timeout": f"{_TEST_TIMEOUT}s",
        "restart_job_on_worker_restart": _TEST_RESTART_JOB_ON_WORKER_RESTART,
    }
    pipeline = gca_training_pipeline.TrainingPipeline(
        name=_TEST_PIPELINE_RESOURCE_NAME,
        state=state,
        training_task_inputs=scheduling_inputs,
    )

    is_running = state == gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING
    if is_running:
        backing_job_metadata = {"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME}
        pipeline.training_task_metadata = backing_job_metadata
    return pipeline


@pytest.fixture
def mock_pipeline_service_get():
with mock.patch.object(
Expand Down Expand Up @@ -619,6 +682,35 @@ def mock_pipeline_service_get_with_enable_web_access():
yield mock_get_training_pipeline


@pytest.fixture
def mock_pipeline_service_get_with_scheduling():
    """Patch ``PipelineServiceClient.get_training_pipeline`` with a scripted lifecycle.

    Successive calls to the mock return scheduling pipelines whose states
    are, in order: PENDING, RUNNING (three times), SUCCEEDED (twice).

    Yields:
        The mock that replaces ``get_training_pipeline``.
    """
    pipeline_states = (
        [gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING]
        + [gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING] * 3
        + [gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED] * 2
    )
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "get_training_pipeline"
    ) as mock_get_training_pipeline:
        # One response per expected poll, consumed in order.
        mock_get_training_pipeline.side_effect = [
            make_training_pipeline_with_scheduling(state=pipeline_state)
            for pipeline_state in pipeline_states
        ]

        yield mock_get_training_pipeline


@pytest.fixture
def mock_pipeline_service_cancel():
with mock.patch.object(
Expand Down Expand Up @@ -650,6 +742,17 @@ def mock_pipeline_service_create_with_enable_web_access():
yield mock_create_training_pipeline


@pytest.fixture
def mock_pipeline_service_create_with_scheduling():
    """Patch ``PipelineServiceClient.create_training_pipeline`` for scheduling tests.

    The patched method returns a scheduling pipeline in the PENDING state.

    Yields:
        The mock that replaces ``create_training_pipeline``.
    """
    pending_pipeline = make_training_pipeline_with_scheduling(
        state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING,
    )
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "create_training_pipeline"
    ) as mock_create_training_pipeline:
        mock_create_training_pipeline.return_value = pending_pipeline
        yield mock_create_training_pipeline


@pytest.fixture
def mock_pipeline_service_get_with_no_model_to_upload():
with mock.patch.object(
Expand Down Expand Up @@ -1397,6 +1500,49 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.usefixtures(
    "mock_pipeline_service_create_with_scheduling",
    "mock_pipeline_service_get_with_scheduling",
    "mock_get_backing_custom_job_with_scheduling",
    "mock_python_package_to_gcs",
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
    """Run a CustomTrainingJob with scheduling options against mocked services.

    Passes ``timeout`` and ``restart_job_on_worker_restart`` to ``run()`` and
    checks that the job ends up holding the SUCCEEDED scheduling pipeline.
    """
    caplog.set_level(logging.INFO)

    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        credentials=_TEST_CREDENTIALS,
    )

    training_job = training_jobs.CustomTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        script_path=_TEST_LOCAL_SCRIPT_FILE_NAME,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
    )

    run_options = dict(
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        args=_TEST_RUN_ARGS,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        sync=sync,
    )
    training_job.run(**run_options)

    # In async mode, block until the run has finished before asserting.
    if not sync:
        training_job.wait()

    captured_logs = caplog.text
    print(captured_logs)
    # "workerpool0-0" is the key of the web access URIs on the mocked custom job.
    assert "workerpool0-0" in captured_logs
    assert training_job._gca_resource == make_training_pipeline_with_scheduling(
        gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
    )

@pytest.mark.usefixtures(
"mock_pipeline_service_create_with_no_model_to_upload",
"mock_pipeline_service_get_with_no_model_to_upload",
Expand Down Expand Up @@ -2803,6 +2949,48 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.usefixtures(
    "mock_pipeline_service_create_with_scheduling",
    "mock_pipeline_service_get_with_scheduling",
    "mock_get_backing_custom_job_with_scheduling",
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
    """Run a CustomContainerTrainingJob with scheduling options against mocked services.

    Passes ``timeout`` and ``restart_job_on_worker_restart`` to ``run()`` and
    checks that the job ends up holding the SUCCEEDED scheduling pipeline
    served by the ``*_with_scheduling`` fixtures.
    """
    caplog.set_level(logging.INFO)

    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        credentials=_TEST_CREDENTIALS,
    )

    job = training_jobs.CustomContainerTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
        command=_TEST_TRAINING_CONTAINER_CMD,
    )

    job.run(
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        args=_TEST_RUN_ARGS,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        sync=sync,
    )

    # In async mode, block until the run has finished before asserting.
    if not sync:
        job.wait()

    # "workerpool0-0" is the key of the web access URIs on the mocked custom job.
    assert "workerpool0-0" in caplog.text
    # Compare against the scheduling pipeline (the one the fixtures return),
    # not the enable_web_access pipeline.
    assert job._gca_resource == make_training_pipeline_with_scheduling(
        gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
    )

@pytest.mark.parametrize("sync", [True, False])
def test_run_returns_none_if_no_model_to_upload(
self,
Expand Down Expand Up @@ -4481,6 +4669,49 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.usefixtures(
    "mock_pipeline_service_create_with_scheduling",
    "mock_pipeline_service_get_with_scheduling",
    "mock_get_backing_custom_job_with_scheduling",
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
    """Run a CustomPythonPackageTrainingJob with scheduling options against mocked services.

    Passes ``timeout`` and ``restart_job_on_worker_restart`` to ``run()`` and
    checks that the job ends up holding the SUCCEEDED scheduling pipeline
    served by the ``*_with_scheduling`` fixtures.
    """
    caplog.set_level(logging.INFO)

    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        credentials=_TEST_CREDENTIALS,
    )

    job = training_jobs.CustomPythonPackageTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH,
        python_module_name=_TEST_PYTHON_MODULE_NAME,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
    )

    job.run(
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        args=_TEST_RUN_ARGS,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        sync=sync,
    )

    # In async mode, block until the run has finished before asserting.
    if not sync:
        job.wait()

    # "workerpool0-0" is the key of the web access URIs on the mocked custom job.
    assert "workerpool0-0" in caplog.text
    # Compare against the scheduling pipeline (the one the fixtures return),
    # not the enable_web_access pipeline.
    assert job._gca_resource == make_training_pipeline_with_scheduling(
        gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
    )

@pytest.mark.usefixtures(
"mock_pipeline_service_create_with_no_model_to_upload",
"mock_pipeline_service_get_with_no_model_to_upload",
Expand Down

0 comments on commit ae6135d

Please sign in to comment.