From 3581bf189b62deef42599b3156a856350717fae2 Mon Sep 17 00:00:00 2001 From: Yaqi Date: Thu, 4 Nov 2021 11:37:48 -0700 Subject: [PATCH] fix key to parameters --- .../cloud/aiplatform/utils/pipeline_utils.py | 56 ++++++++----------- tests/unit/aiplatform/test_pipeline_jobs.py | 34 ++--------- 2 files changed, 28 insertions(+), 62 deletions(-) diff --git a/google/cloud/aiplatform/utils/pipeline_utils.py b/google/cloud/aiplatform/utils/pipeline_utils.py index f4d6a674cf..506dbdff26 100644 --- a/google/cloud/aiplatform/utils/pipeline_utils.py +++ b/google/cloud/aiplatform/utils/pipeline_utils.py @@ -29,26 +29,26 @@ class PipelineRuntimeConfigBuilder(object): def __init__( self, pipeline_root: str, + schema_version: str, parameter_types: Mapping[str, str], parameter_values: Optional[Dict[str, Any]] = None, - schema_version: Optional[str] = None, ): """Creates a PipelineRuntimeConfigBuilder object. Args: pipeline_root (str): Required. The root of the pipeline outputs. + schema_version (str): + Required. Schema version of the IR. This field determines the fields supported in current version of IR. parameter_types (Mapping[str, str]): Required. The mapping from pipeline parameter name to its type. parameter_values (Dict[str, Any]): Optional. The mapping from runtime parameter name to its value. - schema_version (str): - Optional. Schema version of the IR. This field determines the fields supported in current version of IR. """ self._pipeline_root = pipeline_root + self._schema_version = schema_version self._parameter_types = parameter_types self._parameter_values = copy.deepcopy(parameter_values or {}) - self._schema_version = schema_version or "2.0.0" @classmethod def from_job_spec_json( @@ -64,15 +64,12 @@ def from_job_spec_json( A PipelineRuntimeConfigBuilder object. """ runtime_config_spec = job_spec["runtimeConfig"] - input_definitions = ( - job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {} - ) parameter_input_definitions = ( - input_definitions.get("parameterValues") - or input_definitions.get("parameters") - or {} + job_spec["pipelineSpec"]["root"] + .get("inputDefinitions", {}) + .get("parameters", {}) ) - schema_version = job_spec["pipelineSpec"].get("schemaVersion") + schema_version = job_spec["pipelineSpec"]["schemaVersion"] # 'type' is deprecated in IR and change to 'parameterType'. parameter_types = { @@ -82,7 +79,7 @@ def from_job_spec_json( pipeline_root = runtime_config_spec.get("gcsOutputDirectory") parameter_values = _parse_runtime_parameters(runtime_config_spec) - return cls(pipeline_root, parameter_types, parameter_values, schema_version) + return cls(pipeline_root, schema_version, parameter_types, parameter_values) def update_pipeline_root(self, pipeline_root: Optional[str]) -> None: """Updates pipeline_root value. @@ -141,7 +138,7 @@ def build(self) -> Dict[str, Any]: def _get_vertex_value( self, name: str, value: Union[int, float, str, bool, list, dict] - ) -> Dict[str, Any]: + ) -> Union[Dict[str, Any], int, float, str, bool, list, dict]: """Converts primitive values into Vertex pipeline Value proto message. Args: @@ -166,26 +163,21 @@ def _get_vertex_value( "pipeline job input definitions.".format(name) ) - result = {} - if self._parameter_types[name] == "INT": - result["intValue"] = value - elif self._parameter_types[name] == "DOUBLE": - result["doubleValue"] = value - elif self._parameter_types[name] == "STRING": - result["stringValue"] = value - elif self._parameter_types[name] == "BOOLEAN": - result["boolValue"] = value - elif self._parameter_types[name] == "NUMBER_DOUBLE": - result["numberValue"] = value - elif self._parameter_types[name] == "NUMBER_INTEGER": - result["numberValue"] = value - elif self._parameter_types[name] == "LIST": - result["listValue"] = value - elif self._parameter_types[name] == "STRUCT": - result["structValue"] = value + if packaging.version.parse(self._schema_version) <= packaging.version.parse( + "2.0.0" + ): + result = {} + if self._parameter_types[name] == "INT": + result["intValue"] = value + elif self._parameter_types[name] == "DOUBLE": + result["doubleValue"] = value + elif self._parameter_types[name] == "STRING": + result["stringValue"] = value + else: + raise TypeError("Got unknown type of value: {}".format(value)) + return result else: - raise TypeError("Got unknown type of value: {}".format(value)) - return result + return value def _parse_runtime_parameters( diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py index e2ba8c62bb..81f14a7ead 100644 --- a/tests/unit/aiplatform/test_pipeline_jobs.py +++ b/tests/unit/aiplatform/test_pipeline_jobs.py @@ -78,7 +78,7 @@ "root": { "dag": {"tasks": {}}, "inputDefinitions": { - "parameterValues": { + "parameters": { "string_param": {"parameterType": "STRING"}, "bool_param": {"parameterType": "BOOLEAN"}, "double_param": {"parameterType": "NUMBER_DOUBLE"}, @@ -98,17 +98,7 @@ "pipelineSpec": _TEST_PIPELINE_SPEC_LEGACY, } _TEST_PIPELINE_JOB = { - "runtimeConfig": { - "parameterValues": { - "string_param": "lorem ipsum", - "bool_param": True, - "double_param": 12.34, - "int_param": 5678, - "list_int_param": [123, 456, 789], - "list_string_param": ["lorem", "ipsum"], - "struct_param": {"key1": 12345, "key2": 67890}, - }, - }, + "runtimeConfig": {"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES}, "pipelineSpec": _TEST_PIPELINE_SPEC, } @@ -282,15 +272,7 @@ def test_run_call_pipeline_service_create( expected_runtime_config_dict = { "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, - "parameterValues": { - "bool_param": {"boolValue": True}, - "double_param": {"numberValue": 12.34}, - "int_param": {"numberValue": 5678}, - "list_int_param": {"listValue": [123, 456, 789]}, - "list_string_param": {"listValue": ["lorem", "ipsum"]}, - "struct_param": {"structValue": {"key1": 12345, "key2": 67890}}, - "string_param": {"stringValue": "hello world"}, - }, + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, } runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb json_format.ParseDict(expected_runtime_config_dict, runtime_config) @@ -425,15 +407,7 @@ def test_submit_call_pipeline_service_pipeline_job_create( expected_runtime_config_dict = { "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, - "parameterValues": { - "bool_param": {"boolValue": True}, - "double_param": {"numberValue": 12.34}, - "int_param": {"numberValue": 5678}, - "list_int_param": {"listValue": [123, 456, 789]}, - "list_string_param": {"listValue": ["lorem", "ipsum"]}, - "struct_param": {"structValue": {"key1": 12345, "key2": 67890}}, - "string_param": {"stringValue": "hello world"}, - }, + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, } runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb json_format.ParseDict(expected_runtime_config_dict, runtime_config)