Skip to content

Commit

Permalink
fix key to parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
ji-yaqi committed Nov 4, 2021
1 parent 8f50d38 commit 3581bf1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 62 deletions.
56 changes: 24 additions & 32 deletions google/cloud/aiplatform/utils/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@ class PipelineRuntimeConfigBuilder(object):
def __init__(
self,
pipeline_root: str,
schema_version: str,
parameter_types: Mapping[str, str],
parameter_values: Optional[Dict[str, Any]] = None,
schema_version: Optional[str] = None,
):
"""Creates a PipelineRuntimeConfigBuilder object.
Args:
pipeline_root (str):
Required. The root of the pipeline outputs.
schema_version (str):
Required. Schema version of the IR. This field determines the fields supported in current version of IR.
parameter_types (Mapping[str, str]):
Required. The mapping from pipeline parameter name to its type.
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter name to its value.
schema_version (str):
Optional. Schema version of the IR. This field determines the fields supported in current version of IR.
"""
self._pipeline_root = pipeline_root
self._schema_version = schema_version
self._parameter_types = parameter_types
self._parameter_values = copy.deepcopy(parameter_values or {})
self._schema_version = schema_version or "2.0.0"

@classmethod
def from_job_spec_json(
Expand All @@ -64,15 +64,12 @@ def from_job_spec_json(
A PipelineRuntimeConfigBuilder object.
"""
runtime_config_spec = job_spec["runtimeConfig"]
input_definitions = (
job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {}
)
parameter_input_definitions = (
input_definitions.get("parameterValues")
or input_definitions.get("parameters")
or {}
job_spec["pipelineSpec"]["root"]
.get("inputDefinitions", {})
.get("parameters", {})
)
schema_version = job_spec["pipelineSpec"].get("schemaVersion")
schema_version = job_spec["pipelineSpec"]["schemaVersion"]

# 'type' is deprecated in IR and change to 'parameterType'.
parameter_types = {
Expand All @@ -82,7 +79,7 @@ def from_job_spec_json(

pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
parameter_values = _parse_runtime_parameters(runtime_config_spec)
return cls(pipeline_root, parameter_types, parameter_values, schema_version)
return cls(pipeline_root, schema_version, parameter_types, parameter_values)

def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
"""Updates pipeline_root value.
Expand Down Expand Up @@ -141,7 +138,7 @@ def build(self) -> Dict[str, Any]:

def _get_vertex_value(
self, name: str, value: Union[int, float, str, bool, list, dict]
) -> Dict[str, Any]:
) -> Union[Dict[str, Any], int, float, str, bool, list, dict]:
"""Converts primitive values into Vertex pipeline Value proto message.
Args:
Expand All @@ -166,26 +163,21 @@ def _get_vertex_value(
"pipeline job input definitions.".format(name)
)

result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
elif self._parameter_types[name] == "BOOLEAN":
result["boolValue"] = value
elif self._parameter_types[name] == "NUMBER_DOUBLE":
result["numberValue"] = value
elif self._parameter_types[name] == "NUMBER_INTEGER":
result["numberValue"] = value
elif self._parameter_types[name] == "LIST":
result["listValue"] = value
elif self._parameter_types[name] == "STRUCT":
result["structValue"] = value
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
return value


def _parse_runtime_parameters(
Expand Down
34 changes: 4 additions & 30 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"root": {
"dag": {"tasks": {}},
"inputDefinitions": {
"parameterValues": {
"parameters": {
"string_param": {"parameterType": "STRING"},
"bool_param": {"parameterType": "BOOLEAN"},
"double_param": {"parameterType": "NUMBER_DOUBLE"},
Expand All @@ -98,17 +98,7 @@
"pipelineSpec": _TEST_PIPELINE_SPEC_LEGACY,
}
_TEST_PIPELINE_JOB = {
"runtimeConfig": {
"parameterValues": {
"string_param": "lorem ipsum",
"bool_param": True,
"double_param": 12.34,
"int_param": 5678,
"list_int_param": [123, 456, 789],
"list_string_param": ["lorem", "ipsum"],
"struct_param": {"key1": 12345, "key2": 67890},
},
},
"runtimeConfig": {"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES},
"pipelineSpec": _TEST_PIPELINE_SPEC,
}

Expand Down Expand Up @@ -282,15 +272,7 @@ def test_run_call_pipeline_service_create(

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": {
"bool_param": {"boolValue": True},
"double_param": {"numberValue": 12.34},
"int_param": {"numberValue": 5678},
"list_int_param": {"listValue": [123, 456, 789]},
"list_string_param": {"listValue": ["lorem", "ipsum"]},
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
"string_param": {"stringValue": "hello world"},
},
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
}
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
Expand Down Expand Up @@ -425,15 +407,7 @@ def test_submit_call_pipeline_service_pipeline_job_create(

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": {
"bool_param": {"boolValue": True},
"double_param": {"numberValue": 12.34},
"int_param": {"numberValue": 5678},
"list_int_param": {"listValue": [123, 456, 789]},
"list_string_param": {"listValue": ["lorem", "ipsum"]},
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
"string_param": {"stringValue": "hello world"},
},
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
}
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
Expand Down

0 comments on commit 3581bf1

Please sign in to comment.