From d81d464996f85012ee3deebe408e981a4df473f4 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Tue, 7 Jul 2020 09:39:38 -0700 Subject: [PATCH 1/2] change: handle image_uri rename in Airflow model config functions in v2 migration tool This commit also adds some more utility functions for parsing and checking AST nodes. --- .../cli/compatibility/v2/ast_transformer.py | 1 + .../cli/compatibility/v2/modifiers/airflow.py | 31 ++++++- .../v2/modifiers/framework_version.py | 20 ++--- .../compatibility/v2/modifiers/matching.py | 16 ++++ .../cli/compatibility/v2/modifiers/parsing.py | 50 +++++++++++ .../v2/modifiers/renamed_params.py | 30 ++----- .../v2/modifiers/test_airflow.py | 90 ++++++++++++------- .../v2/modifiers/test_matching.py | 5 ++ .../v2/modifiers/test_parsing.py | 51 +++++++++++ 9 files changed, 223 insertions(+), 71 deletions(-) create mode 100644 src/sagemaker/cli/compatibility/v2/modifiers/parsing.py create mode 100644 tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 09bd63d792..f45821858f 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -25,6 +25,7 @@ modifiers.tfs.TensorFlowServingConstructorRenamer(), modifiers.predictors.PredictorConstructorRefactor(), modifiers.airflow.ModelConfigArgModifier(), + modifiers.airflow.ModelConfigImageURIRenamer(), modifiers.renamed_params.DistributionParameterRenamer(), modifiers.renamed_params.S3SessionRenamer(), ] diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py b/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py index 0c4768342a..f69f519468 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py @@ -15,7 +15,7 @@ import ast -from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers import matching, renamed_params from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier FUNCTION_NAMES = ("model_config", "model_config_from_estimator") @@ -61,3 +61,32 @@ def modify_node(self, node): """ instance_type = node.args.pop(0) node.keywords.append(ast.keyword(arg="instance_type", value=instance_type)) + + +class ModelConfigImageURIRenamer(renamed_params.ParamRenamer): + """A class to rename the ``image`` attribute to ``image_uri`` in Airflow model config functions. + + This looks for the following formats: + + - ``model_config`` + - ``airflow.model_config`` + - ``workflow.airflow.model_config`` + - ``sagemaker.workflow.airflow.model_config`` + + where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``. + """ + + @property + def calls_to_modify(self): + """A dictionary mapping Airflow model config functions to their respective namespaces.""" + return FUNCTIONS + + @property + def old_param_name(self): + """The previous name for the image URI argument.""" + return "image" + + @property + def new_param_name(self): + """The new name for the image URI argument.""" + return "image_uri" diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py index 9480a4af3d..2eb0bb28ea 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py @@ -15,7 +15,7 @@ import ast -from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers import matching, parsing from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier FRAMEWORK_ARG = "framework_version" @@ -98,13 +98,13 @@ def modify_node(self, node): framework, is_model = _framework_from_node(node) # if framework_version is not supplied, get default and append keyword - framework_version = _arg_value(node, FRAMEWORK_ARG) + framework_version = parsing.arg_value(node, FRAMEWORK_ARG) if framework_version is None: framework_version = FRAMEWORK_DEFAULTS[framework] node.keywords.append(ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version))) # if py_version is not supplied, get a conditional default, and if not None, append keyword - py_version = _arg_value(node, PY_ARG) + py_version = parsing.arg_value(node, PY_ARG) if py_version is None: py_version = _py_version_defaults(framework, framework_version, is_model) if py_version: @@ -175,12 +175,12 @@ def _version_args_needed(node, image_arg): Applies similar logic as ``validate_version_or_image_args`` """ # if image_arg is present, no need to supply version arguments - image_name = _arg_value(node, image_arg) + image_name = parsing.arg_value(node, image_arg) if image_name: return False # if framework_version is None, need args - framework_version = _arg_value(node, FRAMEWORK_ARG) + framework_version = parsing.arg_value(node, FRAMEWORK_ARG) if framework_version is None: return True @@ -188,15 +188,7 @@ def _version_args_needed(node, image_arg): framework, is_model = _framework_from_node(node) expecting_py_version = _py_version_defaults(framework, framework_version, is_model) if expecting_py_version: - py_version = _arg_value(node, PY_ARG) + py_version = parsing.arg_value(node, PY_ARG) return py_version is None return False - - -def _arg_value(node, arg): - """Gets the value associated with the arg keyword, if present""" - for kw in node.keywords: - if kw.arg == arg and kw.value: - return kw.value.s - return None diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/matching.py b/src/sagemaker/cli/compatibility/v2/modifiers/matching.py index feb4060dfc..91eff05b55 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/matching.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/matching.py @@ -15,6 +15,8 @@ import ast +from sagemaker.cli.compatibility.v2.modifiers import parsing + def matches_any(node, name_to_namespaces_dict): """Determines if the ``ast.Call`` node matches any of the provided names and namespaces. @@ -101,3 +103,17 @@ def matches_namespace(node, namespace): name, value = names.pop(), value.value return isinstance(value, ast.Name) and value.id == name + + +def has_arg(node, arg): + """Checks if the call has the given argument. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + arg (str): the name of the argument. + + Returns: + bool: if the node has the given argument. + """ + return parsing.arg_value(node, arg) is not None diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py b/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py new file mode 100644 index 0000000000..932c2bea14 --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py @@ -0,0 +1,50 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Functions for parsing AST nodes.""" +from __future__ import absolute_import + + +def arg_from_keywords(node, arg): + """Retrieves a keyword argument from the node's keywords. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + arg (str): the name of the argument. + + Returns: + ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``. + """ + for kw in node.keywords: + if kw.arg == arg: + return kw + + return None + + +def arg_value(node, arg): + """Retrieves a keyword argument's value from the node's keywords. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + arg (str): the name of the argument. + + Returns: + obj: the keyword argument's value if it is present. Otherwise, this returns ``None``. + """ + keyword = arg_from_keywords(node, arg) + if keyword and keyword.value: + return getattr(keyword.value, keyword.value._fields[0], None) + + return None diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py index c63eb95d51..3247667a89 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py @@ -18,7 +18,7 @@ import ast from abc import abstractmethod -from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers import matching, parsing from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier @@ -54,11 +54,9 @@ def node_should_be_modified(self, node): bool: If the ``ast.Call`` matches the relevant function calls and contains the parameter to be renamed. """ - return matching.matches_any(node, self.calls_to_modify) and self._has_param_to_rename(node) - - def _has_param_to_rename(self, node): - """Checks if the call has the argument that needs to be renamed.""" - return _keyword_from_keywords(node, self.old_param_name) is not None + return matching.matches_any(node, self.calls_to_modify) and matching.has_arg( + node, self.old_param_name + ) def modify_node(self, node): """Modifies the ``ast.Call`` node to rename the attribute. @@ -66,28 +64,10 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents the relevant function call. """ - keyword = _keyword_from_keywords(node, self.old_param_name) + keyword = parsing.arg_from_keywords(node, self.old_param_name) keyword.arg = self.new_param_name -def _keyword_from_keywords(node, param_name): - """Retrieves a keyword argument from the node's keywords. - - Args: - node (ast.Call): a node that represents a function call. For more, - see https://docs.python.org/3/library/ast.html#abstract-grammar. - param_name (str): the name of the argument. - - Returns: - ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``. - """ - for kw in node.keywords: - if kw.arg == param_name: - return kw - - return None - - class DistributionParameterRenamer(ParamRenamer): """A class to rename the ``distributions`` attribute to ``distrbution`` in MXNet and TensorFlow estimators. diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py index a99aa54ecb..5e8eda53eb 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py @@ -17,52 +17,41 @@ from sagemaker.cli.compatibility.v2.modifiers import airflow from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call - -def test_node_should_be_modified_model_config_with_args(): - model_config_calls = ( - "model_config(instance_type, model)", - "airflow.model_config(instance_type, model)", - "workflow.airflow.model_config(instance_type, model)", - "sagemaker.workflow.airflow.model_config(instance_type, model)", - "model_config_from_estimator(instance_type, model)", - "airflow.model_config_from_estimator(instance_type, model)", - "workflow.airflow.model_config_from_estimator(instance_type, model)", - "sagemaker.workflow.airflow.model_config_from_estimator(instance_type, model)", - ) - +MODEL_CONFIG_CALL_TEMPLATES = ( + "model_config({})", + "airflow.model_config({})", + "workflow.airflow.model_config({})", + "sagemaker.workflow.airflow.model_config({})", + "model_config_from_estimator({})", + "airflow.model_config_from_estimator({})", + "workflow.airflow.model_config_from_estimator({})", + "sagemaker.workflow.airflow.model_config_from_estimator({})", +) + + +def test_arg_order_node_should_be_modified_model_config_with_args(): modifier = airflow.ModelConfigArgModifier() - for call in model_config_calls: - node = ast_call(call) + for template in MODEL_CONFIG_CALL_TEMPLATES: + node = ast_call(template.format("instance_type, model")) assert modifier.node_should_be_modified(node) is True -def test_node_should_be_modified_model_config_without_args(): - model_config_calls = ( - "model_config()", - "airflow.model_config()", - "workflow.airflow.model_config()", - "sagemaker.workflow.airflow.model_config()", - "model_config_from_estimator()", - "airflow.model_config_from_estimator()", - "workflow.airflow.model_config_from_estimator()", - "sagemaker.workflow.airflow.model_config_from_estimator()", - ) - +def test_arg_order_node_should_be_modified_model_config_without_args(): modifier = airflow.ModelConfigArgModifier() - for call in model_config_calls: - node = ast_call(call) + for template in MODEL_CONFIG_CALL_TEMPLATES: + node = ast_call(template.format("")) assert modifier.node_should_be_modified(node) is False -def test_node_should_be_modified_random_function_call(): +def test_arg_order_node_should_be_modified_random_function_call(): node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()") modifier = airflow.ModelConfigArgModifier() assert modifier.node_should_be_modified(node) is False -def test_modify_node(): +def test_arg_order_modify_node(): model_config_calls = ( ("model_config(instance_type, model)", "model_config(model, instance_type=instance_type)"), ( @@ -89,3 +78,42 @@ def test_modify_node(): node = ast_call(call) modifier.modify_node(node) assert expected == pasta.dump(node) + + +def test_image_arg_node_should_be_modified_model_config_with_arg(): + modifier = airflow.ModelConfigImageURIRenamer() + + for template in MODEL_CONFIG_CALL_TEMPLATES: + node = ast_call(template.format("image=my_image")) + assert modifier.node_should_be_modified(node) is True + + +def test_image_arg_node_should_be_modified_model_config_without_arg(): + modifier = airflow.ModelConfigImageURIRenamer() + + for template in MODEL_CONFIG_CALL_TEMPLATES: + node = ast_call(template.format("")) + assert modifier.node_should_be_modified(node) is False + + +def test_image_arg_node_should_be_modified_random_function_call(): + node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()") + modifier = airflow.ModelConfigImageURIRenamer() + assert modifier.node_should_be_modified(node) is False + + +def test_image_arg_modify_node(): + model_config_calls = ( + ("model_config(image='image:latest')", "model_config(image_uri='image:latest')"), + ( + "model_config_from_estimator(image=my_image)", + "model_config_from_estimator(image_uri=my_image)", + ), + ) + + modifier = airflow.ModelConfigImageURIRenamer() + + for call, expected in model_config_calls: + node = ast_call(call) + modifier.modify_node(node) + assert expected == pasta.dump(node) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py index 02350e209a..c02771e18c 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py @@ -66,3 +66,8 @@ def test_matches_attr(): def test_matches_namespace(): assert matching.matches_namespace(ast_call("sagemaker.mxnet.MXNet()"), "sagemaker.mxnet") assert not matching.matches_namespace(ast_call("sagemaker.KMeans()"), "sagemaker.mxnet") + + +def test_has_arg(): + assert matching.has_arg(ast_call("MXNet(framework_version=mxnet_version)"), "framework_version") + assert not matching.has_arg(ast_call("MXNet()"), "framework_version") diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py new file mode 100644 index 0000000000..2f54c1ca73 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py @@ -0,0 +1,51 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.cli.compatibility.v2.modifiers import parsing +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call + + +def test_arg_from_keywords(): + kw_name = "framework_version" + kw_value = "1.6.0" + + call = ast_call("MXNet({}='{}', py_version='py3', entry_point='run')".format(kw_name, kw_value)) + returned_kw = parsing.arg_from_keywords(call, kw_name) + + assert kw_name == returned_kw.arg + assert kw_value == returned_kw.value.s + + +def test_arg_from_keywords_absent_keyword(): + call = ast_call("MXNet(entry_point='run')") + assert parsing.arg_from_keywords(call, "framework_version") is None + + +def test_arg_value(): + call = ast_call("MXNet(framework_version='1.6.0')") + assert "1.6.0" == parsing.arg_value(call, "framework_version") + + call = ast_call("MXNet(framework_version=mxnet_version)") + assert "mxnet_version" == parsing.arg_value(call, "framework_version") + + call = ast_call("MXNet(instance_count=1)") + assert 1 == parsing.arg_value(call, "instance_count") + + call = ast_call("MXNet(enable_network_isolation=True)") + assert parsing.arg_value(call, "enable_network_isolation") is True + + +def test_arg_value_absent_keyword(): + call = ast_call("MXNet(entry_point='run')") + assert parsing.arg_value(call, "framework_version") is None From ab405bcf6cba454bec5d97d4a20fc028f2e4e0bd Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Tue, 7 Jul 2020 12:47:58 -0700 Subject: [PATCH 2/2] address PR comment --- .../v2/modifiers/framework_version.py | 16 ++++++++-------- .../cli/compatibility/v2/modifiers/matching.py | 5 ++++- .../cli/compatibility/v2/modifiers/parsing.py | 13 +++++++++---- .../compatibility/v2/modifiers/test_parsing.py | 13 +++++++++++-- 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py index 2eb0bb28ea..26fcf7be0f 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py @@ -98,14 +98,14 @@ def modify_node(self, node): framework, is_model = _framework_from_node(node) # if framework_version is not supplied, get default and append keyword - framework_version = parsing.arg_value(node, FRAMEWORK_ARG) - if framework_version is None: + if matching.has_arg(node, FRAMEWORK_ARG): + framework_version = parsing.arg_value(node, FRAMEWORK_ARG) + else: framework_version = FRAMEWORK_DEFAULTS[framework] node.keywords.append(ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version))) # if py_version is not supplied, get a conditional default, and if not None, append keyword - py_version = parsing.arg_value(node, PY_ARG) - if py_version is None: + if not matching.has_arg(node, PY_ARG): py_version = _py_version_defaults(framework, framework_version, is_model) if py_version: node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version))) @@ -175,13 +175,13 @@ def _version_args_needed(node, image_arg): Applies similar logic as ``validate_version_or_image_args`` """ # if image_arg is present, no need to supply version arguments - image_name = parsing.arg_value(node, image_arg) - if image_name: + if matching.has_arg(node, image_arg): return False # if framework_version is None, need args - framework_version = parsing.arg_value(node, FRAMEWORK_ARG) - if framework_version is None: + if matching.has_arg(node, FRAMEWORK_ARG): + framework_version = parsing.arg_value(node, FRAMEWORK_ARG) + else: return True # check if we expect py_version and we don't get it -- framework and model dependent diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/matching.py b/src/sagemaker/cli/compatibility/v2/modifiers/matching.py index 91eff05b55..a84a6b9ca9 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/matching.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/matching.py @@ -116,4 +116,7 @@ def has_arg(node, arg): Returns: bool: if the node has the given argument. """ - return parsing.arg_value(node, arg) is not None + try: + return parsing.arg_value(node, arg) is not None + except KeyError: + return False diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py b/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py index 932c2bea14..02c33117e6 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py @@ -13,6 +13,8 @@ """Functions for parsing AST nodes.""" from __future__ import absolute_import +import pasta + def arg_from_keywords(node, arg): """Retrieves a keyword argument from the node's keywords. @@ -41,10 +43,13 @@ def arg_value(node, arg): arg (str): the name of the argument. Returns: - obj: the keyword argument's value if it is present. Otherwise, this returns ``None``. + obj: the keyword argument's value. + + Raises: + KeyError: if the node's keywords do not contain the argument. """ keyword = arg_from_keywords(node, arg) - if keyword and keyword.value: - return getattr(keyword.value, keyword.value._fields[0], None) + if keyword is None: + raise KeyError("arg '{}' not found in call: {}".format(arg, pasta.dump(node))) - return None + return getattr(keyword.value, keyword.value._fields[0], None) if keyword.value else None diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py index 2f54c1ca73..355a7d4ffd 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py @@ -12,6 +12,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + from sagemaker.cli.compatibility.v2.modifiers import parsing from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call @@ -45,7 +47,14 @@ def test_arg_value(): call = ast_call("MXNet(enable_network_isolation=True)") assert parsing.arg_value(call, "enable_network_isolation") is True + call = ast_call("MXNet(source_dir=None)") + assert parsing.arg_value(call, "source_dir") is None + def test_arg_value_absent_keyword(): - call = ast_call("MXNet(entry_point='run')") - assert parsing.arg_value(call, "framework_version") is None + code = "MXNet(entry_point='run')" + + with pytest.raises(KeyError) as e: + parsing.arg_value(ast_call(code), "framework_version") + + assert "arg 'framework_version' not found in call: {}".format(code) in str(e.value)