Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
ji-yaqi committed Nov 3, 2021
1 parent 41e9fe3 commit 8f50d38
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 59 deletions.
33 changes: 22 additions & 11 deletions google/cloud/aiplatform/utils/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Dict, Mapping, Optional, Union
import packaging.version


class PipelineRuntimeConfigBuilder(object):
"""Pipeline RuntimeConfig builder.
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
self._pipeline_root = pipeline_root
self._parameter_types = parameter_types
self._parameter_values = copy.deepcopy(parameter_values or {})
self._schema_version = schema_version or '2.0.0'
self._schema_version = schema_version or "2.0.0"

@classmethod
def from_job_spec_json(
Expand All @@ -63,9 +64,15 @@ def from_job_spec_json(
A PipelineRuntimeConfigBuilder object.
"""
runtime_config_spec = job_spec["runtimeConfig"]
input_definitions = job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {}
parameter_input_definitions = input_definitions.get("parameter_values") or input_definitions.get("parameters") or {}
schema_version = job_spec.get('schemaVersion')
input_definitions = (
job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {}
)
parameter_input_definitions = (
input_definitions.get("parameterValues")
or input_definitions.get("parameters")
or {}
)
schema_version = job_spec["pipelineSpec"].get("schemaVersion")

# 'type' is deprecated in IR and change to 'parameterType'.
parameter_types = {
Expand Down Expand Up @@ -98,9 +105,12 @@ def update_runtime_parameters(
"""
if parameter_values:
parameters = dict(parameter_values)
for k, v in parameter_values.items():
if isinstance(v, (dict, list, bool)):
parameters[k] = json.dumps(v)
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
for k, v in parameter_values.items():
if isinstance(v, (dict, list, bool)):
parameters[k] = json.dumps(v)
self._parameter_values.update(parameters)

def build(self) -> Dict[str, Any]:
Expand All @@ -114,10 +124,12 @@ def build(self) -> Dict[str, Any]:
"Pipeline root must be specified, either during "
"compile time, or when calling the service."
)
if packaging.version.parse(self._schema_version) >= packaging.version.parse("2.1.0"):
parameter_values_key = 'parameter_values'
if packaging.version.parse(self._schema_version) > packaging.version.parse(
"2.0.0"
):
parameter_values_key = "parameterValues"
else:
parameter_values_key = 'parameters'
parameter_values_key = "parameters"
return {
"gcsOutputDirectory": self._pipeline_root,
parameter_values_key: {
Expand Down Expand Up @@ -173,7 +185,6 @@ def _get_vertex_value(
result["structValue"] = value
else:
raise TypeError("Got unknown type of value: {}".format(value))

return result


Expand Down
232 changes: 184 additions & 48 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,53 +53,50 @@

_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"

_TEST_PIPELINE_PARAMETER_VALUES = {"string_param": "hello"}
_TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param": "hello"}
_TEST_PIPELINE_PARAMETER_VALUES = {
"string_param": "hello world",
"bool_param": True,
"double_param": 12.34,
"int_param": 5678,
"list_int_param": [123, 456, 789],
"list_string_param": ["lorem", "ipsum"],
"struct_param": {"key1": 12345, "key2": 67890},
}

_TEST_PIPELINE_SPEC_LEGACY = {
"pipelineInfo": {"name": "my-pipeline"},
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {"parameters": {"string_param": {"type": "STRING"}}},
},
"schema_version": "2.0.0",
"schemaVersion": "2.0.0",
"components": {},
}
_TEST_PIPELINE_SPEC = {
"pipelineInfo": {"name": "my-pipeline"},
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {
"parameter_values": {
"parameterValues": {
"string_param": {"parameterType": "STRING"},
"bool_param": {
"parameterType": "BOOLEAN"
},
"double_param": {
"parameterType": "NUMBER_DOUBLE"
},
"int_param": {
"parameterType": "NUMBER_INTEGER"
},
"list_int_param": {
"parameterType": "LIST"
},
"list_string_param": {
"parameterType": "LIST"
},
"struct_param": {
"parameterType": "STRUCT"
}
"bool_param": {"parameterType": "BOOLEAN"},
"double_param": {"parameterType": "NUMBER_DOUBLE"},
"int_param": {"parameterType": "NUMBER_INTEGER"},
"list_int_param": {"parameterType": "LIST"},
"list_string_param": {"parameterType": "LIST"},
"struct_param": {"parameterType": "STRUCT"},
}
},
},
"schema_version":"2.1.0",
"schemaVersion": "2.1.0",
"components": {},
}

_TEST_PIPELINE_JOB_LEGACY = {
"runtimeConfig": {},
"pipelineSpec": _TEST_PIPELINE_SPEC_LEGACY,
}

_TEST_PIPELINE_JOB = {
"runtimeConfig": {
"parameterValues": {
Expand All @@ -109,7 +106,7 @@
"int_param": 5678,
"list_int_param": [123, 456, 789],
"list_string_param": ["lorem", "ipsum"],
"struct_param": { "key1": 12345, "key2": 67890}
"struct_param": {"key1": 12345, "key2": 67890},
},
},
"pipelineSpec": _TEST_PIPELINE_SPEC,
Expand Down Expand Up @@ -250,13 +247,7 @@ def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@pytest.mark.parametrize(
"job_spec_json",
[
_TEST_PIPELINE_SPEC,
_TEST_PIPELINE_JOB,
_TEST_PIPELINE_SPEC_LEGACY,
_TEST_PIPELINE_JOB_LEGACY,
],
"job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create(
Expand Down Expand Up @@ -291,7 +282,15 @@ def test_run_call_pipeline_service_create(

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameter_values": {"string_param": {"stringValue": "hello"}},
"parameterValues": {
"bool_param": {"boolValue": True},
"double_param": {"numberValue": 12.34},
"int_param": {"numberValue": 5678},
"list_int_param": {"listValue": [123, 456, 789]},
"list_string_param": {"listValue": ["lorem", "ipsum"]},
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
"string_param": {"stringValue": "hello world"},
},
}
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
Expand All @@ -305,6 +304,7 @@ def test_run_call_pipeline_service_create(
"components": {},
"pipelineInfo": pipeline_spec["pipelineInfo"],
"root": pipeline_spec["root"],
"schemaVersion": "2.1.0",
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
Expand All @@ -326,13 +326,78 @@ def test_run_call_pipeline_service_create(
)

@pytest.mark.parametrize(
"job_spec_json",
[
_TEST_PIPELINE_SPEC,
_TEST_PIPELINE_JOB,
_TEST_PIPELINE_SPEC_LEGACY,
_TEST_PIPELINE_JOB_LEGACY,
],
"job_spec_json", [_TEST_PIPELINE_SPEC_LEGACY, _TEST_PIPELINE_JOB_LEGACY],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_legacy(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    job_spec_json,
    mock_load_json,
    sync,
):
    """Legacy-schema pipeline run builds the old-style runtime config.

    For specs compiled with schemaVersion 2.0.0, ``PipelineJob.run`` must
    populate the runtime config under the ``parameters`` key (proto
    ``Value`` wrappers such as ``stringValue``), not the newer
    ``parameterValues`` key used from schema 2.1.0 onward, and the
    resulting create request must be sent verbatim to the service.

    Parametrized over legacy spec/job fixtures (``job_spec_json``) and
    both sync and async execution (``sync``).
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES_LEGACY,
        enable_caching=True,
    )

    job.run(
        service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync,
    )

    # In async mode run() returns immediately; block until the mocked
    # service calls have happened so the assertions below are valid.
    if not sync:
        job.wait()

    # Legacy (2.0.0) runtime config: parameters carried as proto Value
    # wrappers under the "parameters" key.
    expected_runtime_config_dict = {
        "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
        "parameters": {"string_param": {"stringValue": "hello"}},
    }
    runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
    json_format.ParseDict(expected_runtime_config_dict, runtime_config)

    # The fixture is either a full job payload (with a "pipelineSpec"
    # wrapper) or a bare pipeline spec; normalize to the spec itself.
    pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json

    # Construct expected request
    expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        pipeline_spec={
            "components": {},
            "pipelineInfo": pipeline_spec["pipelineInfo"],
            "root": pipeline_spec["root"],
            "schemaVersion": "2.0.0",
        },
        runtime_config=runtime_config,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=_TEST_PARENT,
        pipeline_job=expected_gapic_pipeline_job,
        pipeline_job_id=_TEST_PIPELINE_JOB_ID,
    )

    mock_pipeline_service_get.assert_called_with(
        name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
    )

    # The polled resource should reflect the terminal SUCCEEDED state.
    assert job._gca_resource == make_pipeline_job(
        gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
    )

@pytest.mark.parametrize(
"job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
)
def test_submit_call_pipeline_service_pipeline_job_create(
self,
Expand All @@ -359,8 +424,84 @@ def test_submit_call_pipeline_service_pipeline_job_create(
job.submit(service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK)

expected_runtime_config_dict = {
"gcs_output_directory": _TEST_GCS_BUCKET_NAME,
"parameter_values": {"string_param": {"stringValue": "hello"}},
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": {
"bool_param": {"boolValue": True},
"double_param": {"numberValue": 12.34},
"int_param": {"numberValue": 5678},
"list_int_param": {"listValue": [123, 456, 789]},
"list_string_param": {"listValue": ["lorem", "ipsum"]},
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
"string_param": {"stringValue": "hello world"},
},
}
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
"pipelineInfo": pipeline_spec["pipelineInfo"],
"root": pipeline_spec["root"],
"schemaVersion": "2.1.0",
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
)

mock_pipeline_service_create.assert_called_once_with(
parent=_TEST_PARENT,
pipeline_job=expected_gapic_pipeline_job,
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
)

assert not mock_pipeline_service_get.called

job.wait()

mock_pipeline_service_get.assert_called_with(
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
)

assert job._gca_resource == make_pipeline_job(
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
"job_spec_json", [_TEST_PIPELINE_SPEC_LEGACY, _TEST_PIPELINE_JOB_LEGACY],
)
def test_submit_call_pipeline_service_pipeline_job_create_legacy(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
job_spec_json,
mock_load_json,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES_LEGACY,
enable_caching=True,
)

job.submit(service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK)

expected_runtime_config_dict = {
"parameters": {"string_param": {"stringValue": "hello"}},
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
}
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
Expand All @@ -374,6 +515,7 @@ def test_submit_call_pipeline_service_pipeline_job_create(
"components": {},
"pipelineInfo": pipeline_spec["pipelineInfo"],
"root": pipeline_spec["root"],
"schemaVersion": "2.0.0",
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
Expand Down Expand Up @@ -508,13 +650,7 @@ def test_cancel_pipeline_job_without_running(
"mock_pipeline_service_create", "mock_pipeline_service_get_with_fail",
)
@pytest.mark.parametrize(
"job_spec_json",
[
_TEST_PIPELINE_SPEC,
_TEST_PIPELINE_JOB,
_TEST_PIPELINE_SPEC_LEGACY,
_TEST_PIPELINE_JOB_LEGACY,
],
"job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
def test_pipeline_failure_raises(self, mock_load_json, sync):
Expand Down

0 comments on commit 8f50d38

Please sign in to comment.