Skip to content

Commit

Permalink
feat: support new protobuf value param types for Pipeline Job client (#…
Browse files Browse the repository at this point in the history
…797)

* feat: update PipelineJob to accept protobuf value

* fix tests

* address comments
  • Loading branch information
ji-yaqi authored Oct 28, 2021
1 parent 7ab05d5 commit 2fc05ca
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 133 deletions.
54 changes: 34 additions & 20 deletions google/cloud/aiplatform/utils/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,13 @@ def from_job_spec_json(
.get("inputDefinitions", {})
.get("parameters", {})
)
parameter_types = {k: v["type"] for k, v in parameter_input_definitions.items()}
# 'type' is deprecated in IR and change to 'parameterType'.
parameter_types = {
k: v.get("parameterType") or v.get("type")
for k, v in parameter_input_definitions.items()
}

pipeline_root = runtime_config_spec.get("gcs_output_directory")
pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
parameter_values = _parse_runtime_parameters(runtime_config_spec)
return cls(pipeline_root, parameter_types, parameter_values)

Expand Down Expand Up @@ -108,7 +112,7 @@ def build(self) -> Dict[str, Any]:
"compile time, or when calling the service."
)
return {
"gcs_output_directory": self._pipeline_root,
"gcsOutputDirectory": self._pipeline_root,
"parameters": {
k: self._get_vertex_value(k, v)
for k, v in self._parameter_values.items()
Expand All @@ -117,14 +121,14 @@ def build(self) -> Dict[str, Any]:
}

def _get_vertex_value(
self, name: str, value: Union[int, float, str]
self, name: str, value: Union[int, float, str, bool, list, dict]
) -> Dict[str, Any]:
"""Converts primitive values into Vertex pipeline Value proto message.
Args:
name (str):
Required. The name of the pipeline parameter.
value (Union[int, float, str]):
value (Union[int, float, str, bool, list, dict]):
Required. The value of the pipeline parameter.
Returns:
Expand All @@ -150,6 +154,16 @@ def _get_vertex_value(
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
elif self._parameter_types[name] == "BOOLEAN":
result["boolValue"] = value
elif self._parameter_types[name] == "NUMBER_DOUBLE":
result["numberValue"] = value
elif self._parameter_types[name] == "NUMBER_INTEGER":
result["numberValue"] = value
elif self._parameter_types[name] == "LIST":
result["listValue"] = value
elif self._parameter_types[name] == "STRUCT":
result["structValue"] = value
else:
raise TypeError("Got unknown type of value: {}".format(value))

Expand All @@ -164,19 +178,19 @@ def _parse_runtime_parameters(
Raises:
TypeError: if the parameter type is not one of 'INT', 'DOUBLE', 'STRING'.
"""
runtime_parameters = runtime_config_spec.get("parameters")
if not runtime_parameters:
return None

result = {}
for name, value in runtime_parameters.items():
if "intValue" in value:
result[name] = int(value["intValue"])
elif "doubleValue" in value:
result[name] = float(value["doubleValue"])
elif "stringValue" in value:
result[name] = value["stringValue"]
else:
raise TypeError("Got unknown type of value: {}".format(value))
# 'parameters' are deprecated in IR and changed to 'parameterValues'.
if runtime_config_spec.get("parameterValues") is not None:
return runtime_config_spec.get("parameterValues")

return result
if runtime_config_spec.get("parameters") is not None:
result = {}
for name, value in runtime_config_spec.get("parameters").items():
if "intValue" in value:
result[name] = int(value["intValue"])
elif "doubleValue" in value:
result[name] = float(value["doubleValue"])
elif "stringValue" in value:
result[name] = value["stringValue"]
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
Loading

0 comments on commit 2fc05ca

Please sign in to comment.