Skip to content

Commit

Permalink
feat(PipelineJob): allow PipelineSpec as param (#774)
Browse files Browse the repository at this point in the history
* feat(PipelineJob): accept pipelineSpec as param

* edit

* address comments

* address comments
  • Loading branch information
ji-yaqi authored Oct 19, 2021
1 parent 208889b commit f90a1bd
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 37 deletions.
47 changes: 31 additions & 16 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ def __init__(
display_name (str):
Required. The user-defined name of this Pipeline.
template_path (str):
Required. The path of PipelineJob JSON file. It can be a local path or a
Google Cloud Storage URI. Example: "gs://project.name"
Required. The path of PipelineJob or PipelineSpec JSON file. It
can be a local path or a Google Cloud Storage URI.
Example: "gs://project.name"
job_id (str):
Optional. The unique ID of the job run.
If not specified, pipeline name + timestamp will be used.
Expand Down Expand Up @@ -165,14 +166,37 @@ def __init__(
self._parent = initializer.global_config.common_location_path(
project=project, location=location
)
pipeline_job = json_utils.load_json(
pipeline_json = json_utils.load_json(
template_path, self.project, self.credentials
)
pipeline_root = (
pipeline_root
or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
or initializer.global_config.staging_bucket
# Pipeline_json can be either PipelineJob or PipelineSpec.
if pipeline_json.get("pipelineSpec") is not None:
pipeline_job = pipeline_json
pipeline_root = (
pipeline_root
or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
or initializer.global_config.staging_bucket
)
else:
pipeline_job = {
"pipelineSpec": pipeline_json,
"runtimeConfig": {},
}
pipeline_root = (
pipeline_root
or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
or initializer.global_config.staging_bucket
)
builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
pipeline_job
)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
runtime_config_dict = builder.build()

runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(runtime_config_dict, runtime_config)

pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
self.job_id = job_id or "{pipeline_name}-{timestamp}".format(
Expand All @@ -188,15 +212,6 @@ def __init__(
'"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
)

builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
pipeline_job
)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
runtime_config_dict = builder.build()
runtime_config = gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(runtime_config_dict, runtime_config)

if enable_caching is not None:
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)

Expand Down
121 changes: 100 additions & 21 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,17 @@
_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"

_TEST_PIPELINE_PARAMETER_VALUES = {"name_param": "hello"}
_TEST_PIPELINE_JOB_SPEC = {
"runtimeConfig": {},
"pipelineSpec": {
"pipelineInfo": {"name": "my-pipeline"},
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {"parameters": {"name_param": {"type": "STRING"}}},
},
"components": {},
_TEST_PIPELINE_SPEC = {
"pipelineInfo": {"name": "my-pipeline"},
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {"parameters": {"name_param": {"type": "STRING"}}},
},
"components": {},
}
_TEST_PIPELINE_JOB = {
"runtimeConfig": {},
"pipelineSpec": _TEST_PIPELINE_SPEC,
}

_TEST_PIPELINE_GET_METHOD_NAME = "get_fake_pipeline_job"
Expand Down Expand Up @@ -175,10 +176,23 @@ def mock_pipeline_service_list():


@pytest.fixture
def mock_load_json():
with patch.object(storage.Blob, "download_as_bytes") as mock_load_json:
mock_load_json.return_value = json.dumps(_TEST_PIPELINE_JOB_SPEC).encode()
yield mock_load_json
def mock_load_pipeline_job_json():
    """Patch ``storage.Blob.download_as_bytes`` to serve the PipelineJob JSON.

    The patched method returns ``_TEST_PIPELINE_JOB`` (a document with both
    ``runtimeConfig`` and ``pipelineSpec`` keys) serialized to JSON bytes.

    Yields:
        The mock that replaces ``storage.Blob.download_as_bytes`` while the
        fixture is active.
    """
    job_payload = json.dumps(_TEST_PIPELINE_JOB).encode()
    with patch.object(storage.Blob, "download_as_bytes") as mock_download:
        mock_download.return_value = job_payload
        yield mock_download


@pytest.fixture
def mock_load_pipeline_spec_json():
    """Patch ``storage.Blob.download_as_bytes`` to serve a bare PipelineSpec.

    The patched method returns ``_TEST_PIPELINE_SPEC`` serialized to JSON
    bytes, i.e. a template without the surrounding PipelineJob wrapper.

    Yields:
        The mock that replaces ``storage.Blob.download_as_bytes`` while the
        fixture is active.
    """
    spec_payload = json.dumps(_TEST_PIPELINE_SPEC).encode()
    with patch.object(storage.Blob, "download_as_bytes") as mock_download:
        mock_download.return_value = spec_payload
        yield mock_download


class TestPipelineJob:
Expand All @@ -199,9 +213,68 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@pytest.mark.usefixtures("mock_load_json")
@pytest.mark.usefixtures("mock_load_pipeline_job_json")
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_pipeline_job_create(
    self, mock_pipeline_service_create, mock_pipeline_service_get, sync,
):
    """Run a job whose template is a full PipelineJob JSON and verify the request.

    Builds a ``PipelineJob`` from the mocked template, runs it (synchronously
    or not, per ``sync``), and checks that the pipeline service was asked to
    create the expected job and that the final resource reports success.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )
    job.run(
        service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync,
    )

    # An asynchronous run returns immediately; block until it has finished
    # before inspecting the mocks.
    if not sync:
        job.wait()

    # No pipeline root is passed to the constructor, so the expected output
    # directory is the staging bucket given to aiplatform.init().
    expected_runtime_config = (
        gca_pipeline_job_v1beta1.PipelineJob.RuntimeConfig()._pb
    )
    json_format.ParseDict(
        {
            "gcs_output_directory": _TEST_GCS_BUCKET_NAME,
            "parameters": {"name_param": {"stringValue": "hello"}},
        },
        expected_runtime_config,
    )

    # Construct expected request
    template_spec = _TEST_PIPELINE_JOB["pipelineSpec"]
    expected_gapic_pipeline_job = gca_pipeline_job_v1beta1.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        pipeline_spec={
            "components": {},
            "pipelineInfo": template_spec["pipelineInfo"],
            "root": template_spec["root"],
        },
        runtime_config=expected_runtime_config,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=_TEST_PARENT,
        pipeline_job=expected_gapic_pipeline_job,
        pipeline_job_id=_TEST_PIPELINE_JOB_ID,
    )
    mock_pipeline_service_get.assert_called_with(name=_TEST_PIPELINE_JOB_NAME)

    succeeded_job = make_pipeline_job(
        gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
    )
    assert job._gca_resource == succeeded_job

@pytest.mark.usefixtures("mock_load_pipeline_spec_json")
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create(
def test_run_call_pipeline_service_pipeline_spec_create(
self, mock_pipeline_service_create, mock_pipeline_service_get, sync,
):
aiplatform.init(
Expand Down Expand Up @@ -238,8 +311,8 @@ def test_run_call_pipeline_service_create(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
"pipelineInfo": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["pipelineInfo"],
"root": _TEST_PIPELINE_JOB_SPEC["pipelineSpec"]["root"],
"pipelineInfo": _TEST_PIPELINE_JOB["pipelineSpec"]["pipelineInfo"],
"root": _TEST_PIPELINE_JOB["pipelineSpec"]["root"],
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
Expand Down Expand Up @@ -267,7 +340,9 @@ def test_get_pipeline_job(self, mock_pipeline_service_get):
assert isinstance(job, pipeline_jobs.PipelineJob)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_load_pipeline_job_json",
)
def test_cancel_pipeline_job(
self, mock_pipeline_service_cancel,
Expand All @@ -292,7 +367,9 @@ def test_cancel_pipeline_job(
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_load_pipeline_job_json",
)
def test_list_pipeline_job(self, mock_pipeline_service_list):
aiplatform.init(
Expand All @@ -315,7 +392,9 @@ def test_list_pipeline_job(self, mock_pipeline_service_list):
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_load_pipeline_job_json",
)
def test_cancel_pipeline_job_without_running(
self, mock_pipeline_service_cancel,
Expand All @@ -340,7 +419,7 @@ def test_cancel_pipeline_job_without_running(
@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get_with_fail",
"mock_load_json",
"mock_load_pipeline_job_json",
)
@pytest.mark.parametrize("sync", [True, False])
def test_pipeline_failure_raises(self, sync):
Expand Down

0 comments on commit f90a1bd

Please sign in to comment.