Skip to content

Commit

Permalink
fix key to parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
ji-yaqi committed Nov 4, 2021
1 parent 8f50d38 commit 6c689d2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 56 deletions.
44 changes: 18 additions & 26 deletions google/cloud/aiplatform/utils/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,10 @@ def from_job_spec_json(
A PipelineRuntimeConfigBuilder object.
"""
runtime_config_spec = job_spec["runtimeConfig"]
input_definitions = (
job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {}
)
parameter_input_definitions = (
input_definitions.get("parameterValues")
or input_definitions.get("parameters")
or {}
job_spec["pipelineSpec"]["root"]
.get("inputDefinitions", {})
.get("parameters", {})
)
schema_version = job_spec["pipelineSpec"].get("schemaVersion")

Expand Down Expand Up @@ -141,7 +138,7 @@ def build(self) -> Dict[str, Any]:

def _get_vertex_value(
self, name: str, value: Union[int, float, str, bool, list, dict]
) -> Dict[str, Any]:
) -> Union[Dict[str, Any], int, float, str, bool, list, dict]:
"""Converts primitive values into Vertex pipeline Value proto message.
Args:
Expand All @@ -166,26 +163,21 @@ def _get_vertex_value(
"pipeline job input definitions.".format(name)
)

result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
elif self._parameter_types[name] == "BOOLEAN":
result["boolValue"] = value
elif self._parameter_types[name] == "NUMBER_DOUBLE":
result["numberValue"] = value
elif self._parameter_types[name] == "NUMBER_INTEGER":
result["numberValue"] = value
elif self._parameter_types[name] == "LIST":
result["listValue"] = value
elif self._parameter_types[name] == "STRUCT":
result["structValue"] = value
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
return value


def _parse_runtime_parameters(
Expand Down
34 changes: 4 additions & 30 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {
"parameterValues": {
"parameters": {
"string_param": {"parameterType": "STRING"},
"bool_param": {"parameterType": "BOOLEAN"},
"double_param": {"parameterType": "NUMBER_DOUBLE"},
Expand All @@ -98,17 +98,7 @@
"pipelineSpec": _TEST_PIPELINE_SPEC_LEGACY,
}
_TEST_PIPELINE_JOB = {
"runtimeConfig": {
"parameterValues": {
"string_param": "lorem ipsum",
"bool_param": True,
"double_param": 12.34,
"int_param": 5678,
"list_int_param": [123, 456, 789],
"list_string_param": ["lorem", "ipsum"],
"struct_param": {"key1": 12345, "key2": 67890},
},
},
"runtimeConfig": {"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES},
"pipelineSpec": _TEST_PIPELINE_SPEC,
}

Expand Down Expand Up @@ -282,15 +272,7 @@ def test_run_call_pipeline_service_create(

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": {
"bool_param": {"boolValue": True},
"double_param": {"numberValue": 12.34},
"int_param": {"numberValue": 5678},
"list_int_param": {"listValue": [123, 456, 789]},
"list_string_param": {"listValue": ["lorem", "ipsum"]},
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
"string_param": {"stringValue": "hello world"},
},
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
}
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
Expand Down Expand Up @@ -425,15 +407,7 @@ def test_submit_call_pipeline_service_pipeline_job_create(

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": {
"bool_param": {"boolValue": True},
"double_param": {"numberValue": 12.34},
"int_param": {"numberValue": 5678},
"list_int_param": {"listValue": [123, 456, 789]},
"list_string_param": {"listValue": ["lorem", "ipsum"]},
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
"string_param": {"stringValue": "hello world"},
},
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
}
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
Expand Down

0 comments on commit 6c689d2

Please sign in to comment.