From 6c689d2d7093abde45a4ef944abf462c1dd960d3 Mon Sep 17 00:00:00 2001 From: Yaqi Date: Thu, 4 Nov 2021 11:37:48 -0700 Subject: [PATCH] fix key to parameters --- .../cloud/aiplatform/utils/pipeline_utils.py | 44 ++++++++----------- tests/unit/aiplatform/test_pipeline_jobs.py | 34 ++------------ 2 files changed, 22 insertions(+), 56 deletions(-) diff --git a/google/cloud/aiplatform/utils/pipeline_utils.py b/google/cloud/aiplatform/utils/pipeline_utils.py index f4d6a674cfe..0f9fe269682 100644 --- a/google/cloud/aiplatform/utils/pipeline_utils.py +++ b/google/cloud/aiplatform/utils/pipeline_utils.py @@ -64,13 +64,10 @@ def from_job_spec_json( A PipelineRuntimeConfigBuilder object. """ runtime_config_spec = job_spec["runtimeConfig"] - input_definitions = ( - job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {} - ) parameter_input_definitions = ( - input_definitions.get("parameterValues") - or input_definitions.get("parameters") - or {} + job_spec["pipelineSpec"]["root"] + .get("inputDefinitions", {}) + .get("parameters", {}) ) schema_version = job_spec["pipelineSpec"].get("schemaVersion") @@ -141,7 +138,7 @@ def build(self) -> Dict[str, Any]: def _get_vertex_value( self, name: str, value: Union[int, float, str, bool, list, dict] - ) -> Dict[str, Any]: + ) -> Union[Dict[str, Any], int, float, str, bool, list, dict]: """Converts primitive values into Vertex pipeline Value proto message. Args: @@ -166,26 +163,21 @@ def _get_vertex_value( "pipeline job input definitions.".format(name) ) - result = {} - if self._parameter_types[name] == "INT": - result["intValue"] = value - elif self._parameter_types[name] == "DOUBLE": - result["doubleValue"] = value - elif self._parameter_types[name] == "STRING": - result["stringValue"] = value - elif self._parameter_types[name] == "BOOLEAN": - result["boolValue"] = value - elif self._parameter_types[name] == "NUMBER_DOUBLE": - result["numberValue"] = value - elif self._parameter_types[name] == "NUMBER_INTEGER": - result["numberValue"] = value - elif self._parameter_types[name] == "LIST": - result["listValue"] = value - elif self._parameter_types[name] == "STRUCT": - result["structValue"] = value + if packaging.version.parse(self._schema_version) <= packaging.version.parse( + "2.0.0" + ): + result = {} + if self._parameter_types[name] == "INT": + result["intValue"] = value + elif self._parameter_types[name] == "DOUBLE": + result["doubleValue"] = value + elif self._parameter_types[name] == "STRING": + result["stringValue"] = value + else: + raise TypeError("Got unknown type of value: {}".format(value)) + return result else: - raise TypeError("Got unknown type of value: {}".format(value)) - return result + return value def _parse_runtime_parameters( diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py index e2ba8c62bb1..81f14a7ead0 100644 --- a/tests/unit/aiplatform/test_pipeline_jobs.py +++ b/tests/unit/aiplatform/test_pipeline_jobs.py @@ -78,7 +78,7 @@ "root": { "dag": {"tasks": {}}, "inputDefinitions": { - "parameterValues": { + "parameters": { "string_param": {"parameterType": "STRING"}, "bool_param": {"parameterType": "BOOLEAN"}, "double_param": {"parameterType": "NUMBER_DOUBLE"}, @@ -98,17 +98,7 @@ "pipelineSpec": _TEST_PIPELINE_SPEC_LEGACY, } _TEST_PIPELINE_JOB = { - "runtimeConfig": { - "parameterValues": { - "string_param": "lorem ipsum", - "bool_param": True, - "double_param": 12.34, - "int_param": 5678, - "list_int_param": [123, 456, 789], - "list_string_param": ["lorem", "ipsum"], - "struct_param": {"key1": 12345, "key2": 67890}, - }, - }, + "runtimeConfig": {"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES}, "pipelineSpec": _TEST_PIPELINE_SPEC, } @@ -282,15 +272,7 @@ def test_run_call_pipeline_service_create( expected_runtime_config_dict = { "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, - "parameterValues": { - "bool_param": {"boolValue": True}, - "double_param": {"numberValue": 12.34}, - "int_param": {"numberValue": 5678}, - "list_int_param": {"listValue": [123, 456, 789]}, - "list_string_param": {"listValue": ["lorem", "ipsum"]}, - "struct_param": {"structValue": {"key1": 12345, "key2": 67890}}, - "string_param": {"stringValue": "hello world"}, - }, + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, } runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb json_format.ParseDict(expected_runtime_config_dict, runtime_config) @@ -425,15 +407,7 @@ def test_submit_call_pipeline_service_pipeline_job_create( expected_runtime_config_dict = { "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, - "parameterValues": { - "bool_param": {"boolValue": True}, - "double_param": {"numberValue": 12.34}, - "int_param": {"numberValue": 5678}, - "list_int_param": {"listValue": [123, 456, 789]}, - "list_string_param": {"listValue": ["lorem", "ipsum"]}, - "struct_param": {"structValue": {"key1": 12345, "key2": 67890}}, - "string_param": {"stringValue": "hello world"}, - }, + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, } runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb json_format.ParseDict(expected_runtime_config_dict, runtime_config)