Skip to content

Commit

Permalink
feat: expose base_output_dir for custom job
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Aug 2, 2021
1 parent 6a99b12 commit b3f5acc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 7 deletions.
18 changes: 17 additions & 1 deletion google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
study as gca_study_compat,
)

from google.cloud.aiplatform.utils import _timestamped_gcs_dir

_LOGGER = base.Logger(__name__)

_JOB_COMPLETE_STATES = (
Expand Down Expand Up @@ -930,6 +932,7 @@ def __init__(
self,
display_name: str,
worker_pool_specs: Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]],
base_output_dir: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
Expand Down Expand Up @@ -977,6 +980,9 @@ def __init__(
worker_pool_specs (Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]]):
Required. The spec of the worker pools including machine type and Docker image.
Can provided as a list of dictionaries or list of WorkerPoolSpec proto messages.
base_output_dir (str):
Optional. GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
project (str):
Optional.Project to run the custom job in. Overrides project set in aiplatform.init.
location (str):
Expand Down Expand Up @@ -1008,12 +1014,17 @@ def __init__(
"should be set using aiplatform.init(staging_bucket='gs://my-bucket')"
)

# default directory if not given
base_output_dir = base_output_dir or _timestamped_gcs_dir(
staging_bucket, "aiplatform-custom-job"
)

self._gca_resource = gca_custom_job_compat.CustomJob(
display_name=display_name,
job_spec=gca_custom_job_compat.CustomJobSpec(
worker_pool_specs=worker_pool_specs,
base_output_directory=gca_io_compat.GcsDestination(
output_uri_prefix=staging_bucket
output_uri_prefix=base_output_dir
),
),
encryption_spec=initializer.global_config.get_encryption_spec(
Expand Down Expand Up @@ -1049,6 +1060,7 @@ def from_local_script(
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
base_output_dir: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
Expand Down Expand Up @@ -1105,6 +1117,9 @@ def from_local_script(
NVIDIA_TESLA_T4
accelerator_count (int):
Optional. The number of accelerators to attach to a worker replica.
base_output_dir (str):
Optional. GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
project (str):
Optional. Project to run the custom job in. Overrides project set in aiplatform.init.
location (str):
Expand Down Expand Up @@ -1170,6 +1185,7 @@ def from_local_script(
return cls(
display_name=display_name,
worker_pool_specs=worker_pool_specs,
base_output_dir=base_output_dir,
project=project,
location=location,
credentials=credentials,
Expand Down
41 changes: 35 additions & 6 deletions tests/unit/aiplatform/test_custom_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
]

_TEST_STAGING_BUCKET = "gs://test-staging-bucket"
_TEST_BASE_OUTPUT_DIR = f"{_TEST_STAGING_BUCKET}/{_TEST_DISPLAY_NAME}"

# CMEK encryption
_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default"
Expand All @@ -91,7 +92,7 @@
job_spec=gca_custom_job_compat.CustomJobSpec(
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
base_output_directory=gca_io_compat.GcsDestination(
output_uri_prefix=_TEST_STAGING_BUCKET
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
),
scheduling=gca_custom_job_compat.Scheduling(
timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
Expand Down Expand Up @@ -224,7 +225,9 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
)

job = aiplatform.CustomJob(
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
display_name=_TEST_DISPLAY_NAME,
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
)

job.run(
Expand Down Expand Up @@ -265,7 +268,9 @@ def test_run_custom_job_with_fail_raises(
)

job = aiplatform.CustomJob(
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
display_name=_TEST_DISPLAY_NAME,
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
)

with pytest.raises(RuntimeError) as e:
Expand Down Expand Up @@ -306,7 +311,9 @@ def test_run_custom_job_with_fail_at_creation(self):
)

job = aiplatform.CustomJob(
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
display_name=_TEST_DISPLAY_NAME,
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
)

job.run(
Expand Down Expand Up @@ -342,7 +349,9 @@ def test_custom_job_get_state_raises_without_run(self):
)

job = aiplatform.CustomJob(
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
display_name=_TEST_DISPLAY_NAME,
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
)

with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -385,6 +394,7 @@ def test_create_from_local_script(
display_name=_TEST_DISPLAY_NAME,
script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME,
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
)

job.run(sync=sync)
Expand Down Expand Up @@ -428,7 +438,9 @@ def test_create_custom_job_with_tensorboard(
)

job = aiplatform.CustomJob(
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
display_name=_TEST_DISPLAY_NAME,
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
base_output_dir=_TEST_BASE_OUTPUT_DIR,
)

job.run(
Expand All @@ -454,3 +466,20 @@ def test_create_custom_job_with_tensorboard(
assert (
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
)

def test_create_custom_job_without_base_output_dir(self,):

aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
staging_bucket=_TEST_STAGING_BUCKET,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

job = aiplatform.CustomJob(
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC,
)

assert job.job_spec.base_output_directory.output_uri_prefix.startswith(
f"{_TEST_STAGING_BUCKET}/aiplatform-custom-job"
)

0 comments on commit b3f5acc

Please sign in to comment.