diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 09bd63d792..f45821858f 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -25,6 +25,7 @@ modifiers.tfs.TensorFlowServingConstructorRenamer(), modifiers.predictors.PredictorConstructorRefactor(), modifiers.airflow.ModelConfigArgModifier(), + modifiers.airflow.ModelConfigImageURIRenamer(), modifiers.renamed_params.DistributionParameterRenamer(), modifiers.renamed_params.S3SessionRenamer(), ] diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py b/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py index 0c4768342a..f69f519468 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py @@ -15,7 +15,7 @@ import ast -from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers import matching, renamed_params from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier FUNCTION_NAMES = ("model_config", "model_config_from_estimator") @@ -61,3 +61,32 @@ def modify_node(self, node): """ instance_type = node.args.pop(0) node.keywords.append(ast.keyword(arg="instance_type", value=instance_type)) + + +class ModelConfigImageURIRenamer(renamed_params.ParamRenamer): + """A class to rename the ``image`` attribute to ``image_uri`` in Airflow model config functions. + + This looks for the following formats: + + - ``model_config`` + - ``airflow.model_config`` + - ``workflow.airflow.model_config`` + - ``sagemaker.workflow.airflow.model_config`` + + where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``. + """ + + @property + def calls_to_modify(self): + """A dictionary mapping Airflow model config functions to their respective namespaces.""" + return FUNCTIONS + + @property + def old_param_name(self): + """The previous name for the image URI argument.""" + return "image" + + @property + def new_param_name(self): + """The new name for the image URI argument.""" + return "image_uri" diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py index 9480a4af3d..26fcf7be0f 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py @@ -15,7 +15,7 @@ import ast -from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers import matching, parsing from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier FRAMEWORK_ARG = "framework_version" @@ -98,14 +98,14 @@ def modify_node(self, node): framework, is_model = _framework_from_node(node) # if framework_version is not supplied, get default and append keyword - framework_version = _arg_value(node, FRAMEWORK_ARG) - if framework_version is None: + if matching.has_arg(node, FRAMEWORK_ARG): + framework_version = parsing.arg_value(node, FRAMEWORK_ARG) + else: framework_version = FRAMEWORK_DEFAULTS[framework] node.keywords.append(ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version))) # if py_version is not supplied, get a conditional default, and if not None, append keyword - py_version = _arg_value(node, PY_ARG) - if py_version is None: + if not matching.has_arg(node, PY_ARG): py_version = _py_version_defaults(framework, framework_version, is_model) if py_version: node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version))) @@ -175,28 +175,20 @@ def _version_args_needed(node, image_arg): Applies similar logic as ``validate_version_or_image_args`` """ # if image_arg is present, no need to supply version arguments - image_name = _arg_value(node, image_arg) - if image_name: + if matching.has_arg(node, image_arg): return False # if framework_version is None, need args - framework_version = _arg_value(node, FRAMEWORK_ARG) - if framework_version is None: + if matching.has_arg(node, FRAMEWORK_ARG): + framework_version = parsing.arg_value(node, FRAMEWORK_ARG) + else: return True # check if we expect py_version and we don't get it -- framework and model dependent framework, is_model = _framework_from_node(node) expecting_py_version = _py_version_defaults(framework, framework_version, is_model) if expecting_py_version: - py_version = _arg_value(node, PY_ARG) + py_version = parsing.arg_value(node, PY_ARG) return py_version is None return False - - -def _arg_value(node, arg): - """Gets the value associated with the arg keyword, if present""" - for kw in node.keywords: - if kw.arg == arg and kw.value: - return kw.value.s - return None diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/matching.py b/src/sagemaker/cli/compatibility/v2/modifiers/matching.py index feb4060dfc..a84a6b9ca9 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/matching.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/matching.py @@ -15,6 +15,8 @@ import ast +from sagemaker.cli.compatibility.v2.modifiers import parsing + def matches_any(node, name_to_namespaces_dict): """Determines if the ``ast.Call`` node matches any of the provided names and namespaces. @@ -101,3 +103,20 @@ def matches_namespace(node, namespace): name, value = names.pop(), value.value return isinstance(value, ast.Name) and value.id == name + + +def has_arg(node, arg): + """Checks if the call has the given argument. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + arg (str): the name of the argument. + + Returns: + bool: if the node has the given argument. + """ + try: + return parsing.arg_value(node, arg) is not None + except KeyError: + return False diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py b/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py new file mode 100644 index 0000000000..02c33117e6 --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/parsing.py @@ -0,0 +1,55 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Functions for parsing AST nodes.""" +from __future__ import absolute_import + +import pasta + + +def arg_from_keywords(node, arg): + """Retrieves a keyword argument from the node's keywords. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + arg (str): the name of the argument. + + Returns: + ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``. + """ + for kw in node.keywords: + if kw.arg == arg: + return kw + + return None + + +def arg_value(node, arg): + """Retrieves a keyword argument's value from the node's keywords. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + arg (str): the name of the argument. + + Returns: + obj: the keyword argument's value. + + Raises: + KeyError: if the node's keywords do not contain the argument. + """ + keyword = arg_from_keywords(node, arg) + if keyword is None: + raise KeyError("arg '{}' not found in call: {}".format(arg, pasta.dump(node))) + + return getattr(keyword.value, keyword.value._fields[0], None) if keyword.value else None diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py index c63eb95d51..3247667a89 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py @@ -18,7 +18,7 @@ import ast from abc import abstractmethod -from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers import matching, parsing from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier @@ -54,11 +54,9 @@ def node_should_be_modified(self, node): bool: If the ``ast.Call`` matches the relevant function calls and contains the parameter to be renamed. """ - return matching.matches_any(node, self.calls_to_modify) and self._has_param_to_rename(node) - - def _has_param_to_rename(self, node): - """Checks if the call has the argument that needs to be renamed.""" - return _keyword_from_keywords(node, self.old_param_name) is not None + return matching.matches_any(node, self.calls_to_modify) and matching.has_arg( + node, self.old_param_name + ) def modify_node(self, node): """Modifies the ``ast.Call`` node to rename the attribute. @@ -66,28 +64,10 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents the relevant function call. """ - keyword = _keyword_from_keywords(node, self.old_param_name) + keyword = parsing.arg_from_keywords(node, self.old_param_name) keyword.arg = self.new_param_name -def _keyword_from_keywords(node, param_name): - """Retrieves a keyword argument from the node's keywords. - - Args: - node (ast.Call): a node that represents a function call. For more, - see https://docs.python.org/3/library/ast.html#abstract-grammar. - param_name (str): the name of the argument. - - Returns: - ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``. - """ - for kw in node.keywords: - if kw.arg == param_name: - return kw - - return None - - class DistributionParameterRenamer(ParamRenamer): """A class to rename the ``distributions`` attribute to ``distrbution`` in MXNet and TensorFlow estimators. diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py index a99aa54ecb..5e8eda53eb 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py @@ -17,52 +17,41 @@ from sagemaker.cli.compatibility.v2.modifiers import airflow from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call - -def test_node_should_be_modified_model_config_with_args(): - model_config_calls = ( - "model_config(instance_type, model)", - "airflow.model_config(instance_type, model)", - "workflow.airflow.model_config(instance_type, model)", - "sagemaker.workflow.airflow.model_config(instance_type, model)", - "model_config_from_estimator(instance_type, model)", - "airflow.model_config_from_estimator(instance_type, model)", - "workflow.airflow.model_config_from_estimator(instance_type, model)", - "sagemaker.workflow.airflow.model_config_from_estimator(instance_type, model)", - ) - +MODEL_CONFIG_CALL_TEMPLATES = ( + "model_config({})", + "airflow.model_config({})", + "workflow.airflow.model_config({})", + "sagemaker.workflow.airflow.model_config({})", + "model_config_from_estimator({})", + "airflow.model_config_from_estimator({})", + "workflow.airflow.model_config_from_estimator({})", + "sagemaker.workflow.airflow.model_config_from_estimator({})", +) + + +def test_arg_order_node_should_be_modified_model_config_with_args(): modifier = airflow.ModelConfigArgModifier() - for call in model_config_calls: - node = ast_call(call) + for template in MODEL_CONFIG_CALL_TEMPLATES: + node = ast_call(template.format("instance_type, model")) assert modifier.node_should_be_modified(node) is True -def test_node_should_be_modified_model_config_without_args(): - model_config_calls = ( - "model_config()", - "airflow.model_config()", - "workflow.airflow.model_config()", - "sagemaker.workflow.airflow.model_config()", - "model_config_from_estimator()", - "airflow.model_config_from_estimator()", - "workflow.airflow.model_config_from_estimator()", - "sagemaker.workflow.airflow.model_config_from_estimator()", - ) - +def test_arg_order_node_should_be_modified_model_config_without_args(): modifier = airflow.ModelConfigArgModifier() - for call in model_config_calls: - node = ast_call(call) + for template in MODEL_CONFIG_CALL_TEMPLATES: + node = ast_call(template.format("")) assert modifier.node_should_be_modified(node) is False -def test_node_should_be_modified_random_function_call(): +def test_arg_order_node_should_be_modified_random_function_call(): node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()") modifier = airflow.ModelConfigArgModifier() assert modifier.node_should_be_modified(node) is False -def test_modify_node(): +def test_arg_order_modify_node(): model_config_calls = ( ("model_config(instance_type, model)", "model_config(model, instance_type=instance_type)"), ( @@ -89,3 +78,42 @@ def test_modify_node(): node = ast_call(call) modifier.modify_node(node) assert expected == pasta.dump(node) + + +def test_image_arg_node_should_be_modified_model_config_with_arg(): + modifier = airflow.ModelConfigImageURIRenamer() + + for template in MODEL_CONFIG_CALL_TEMPLATES: + node = ast_call(template.format("image=my_image")) + assert modifier.node_should_be_modified(node) is True + + +def test_image_arg_node_should_be_modified_model_config_without_arg(): + modifier = airflow.ModelConfigImageURIRenamer() + + for template in MODEL_CONFIG_CALL_TEMPLATES: + node = ast_call(template.format("")) + assert modifier.node_should_be_modified(node) is False + + +def test_image_arg_node_should_be_modified_random_function_call(): + node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()") + modifier = airflow.ModelConfigImageURIRenamer() + assert modifier.node_should_be_modified(node) is False + + +def test_image_arg_modify_node(): + model_config_calls = ( + ("model_config(image='image:latest')", "model_config(image_uri='image:latest')"), + ( + "model_config_from_estimator(image=my_image)", + "model_config_from_estimator(image_uri=my_image)", + ), + ) + + modifier = airflow.ModelConfigImageURIRenamer() + + for call, expected in model_config_calls: + node = ast_call(call) + modifier.modify_node(node) + assert expected == pasta.dump(node) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py index 02350e209a..c02771e18c 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py @@ -66,3 +66,8 @@ def test_matches_attr(): def test_matches_namespace(): assert matching.matches_namespace(ast_call("sagemaker.mxnet.MXNet()"), "sagemaker.mxnet") assert not matching.matches_namespace(ast_call("sagemaker.KMeans()"), "sagemaker.mxnet") + + +def test_has_arg(): + assert matching.has_arg(ast_call("MXNet(framework_version=mxnet_version)"), "framework_version") + assert not matching.has_arg(ast_call("MXNet()"), "framework_version") diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py new file mode 100644 index 0000000000..355a7d4ffd --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py @@ -0,0 +1,60 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +from sagemaker.cli.compatibility.v2.modifiers import parsing +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call + + +def test_arg_from_keywords(): + kw_name = "framework_version" + kw_value = "1.6.0" + + call = ast_call("MXNet({}='{}', py_version='py3', entry_point='run')".format(kw_name, kw_value)) + returned_kw = parsing.arg_from_keywords(call, kw_name) + + assert kw_name == returned_kw.arg + assert kw_value == returned_kw.value.s + + +def test_arg_from_keywords_absent_keyword(): + call = ast_call("MXNet(entry_point='run')") + assert parsing.arg_from_keywords(call, "framework_version") is None + + +def test_arg_value(): + call = ast_call("MXNet(framework_version='1.6.0')") + assert "1.6.0" == parsing.arg_value(call, "framework_version") + + call = ast_call("MXNet(framework_version=mxnet_version)") + assert "mxnet_version" == parsing.arg_value(call, "framework_version") + + call = ast_call("MXNet(instance_count=1)") + assert 1 == parsing.arg_value(call, "instance_count") + + call = ast_call("MXNet(enable_network_isolation=True)") + assert parsing.arg_value(call, "enable_network_isolation") is True + + call = ast_call("MXNet(source_dir=None)") + assert parsing.arg_value(call, "source_dir") is None + + +def test_arg_value_absent_keyword(): + code = "MXNet(entry_point='run')" + + with pytest.raises(KeyError) as e: + parsing.arg_value(ast_call(code), "framework_version") + + assert "arg 'framework_version' not found in call: {}".format(code) in str(e.value)