diff --git a/buildspec-unittests.yml b/buildspec-unittests.yml index 41a6f68621..2a61440344 100644 --- a/buildspec-unittests.yml +++ b/buildspec-unittests.yml @@ -18,5 +18,12 @@ phases: - start_time=`date +%s` - AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN= AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION= - tox -e py27,py36,py37 --parallel all -- tests/unit - - ./ci-scripts/displaytime.sh 'py27,py36,py37 unit' $start_time + tox -e py36,py37 --parallel all -- tests/unit + - ./ci-scripts/displaytime.sh 'py36,py37 unit' $start_time + + # Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed. + - start_time=`date +%s` + - AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN= + AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION= + IGNORE_COVERAGE=- tox -e py27 --parallel all -- tests/unit + - ./ci-scripts/displaytime.sh 'py27 unit' $start_time diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 17eb53f9d8..b50882d21a 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -15,9 +15,12 @@ import ast -from sagemaker.cli.compatibility.v2.modifiers import framework_version +from sagemaker.cli.compatibility.v2 import modifiers -FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()] +FUNCTION_CALL_MODIFIERS = [ + modifiers.framework_version.FrameworkVersionEnforcer(), + modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(), +] class ASTTransformer(ast.NodeTransformer): @@ -38,4 +41,6 @@ def visit_Call(self, node): """ for function_checker in FUNCTION_CALL_MODIFIERS: function_checker.check_and_modify_node(node) + + ast.fix_missing_locations(node) return node diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py index 9fca9c35da..0a20044b5a 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py @@ -12,3 +12,8 @@ # language governing permissions and limitations under the License. """Classes for modifying AST nodes""" from __future__ import absolute_import + +from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused) + framework_version, + tf_legacy_mode, +) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py index 2e081e885d..848625a58a 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py @@ -66,8 +66,7 @@ def _is_framework_constructor(self, node): """ # Check for call if isinstance(node.func, ast.Name): - if node.func.id in FRAMEWORK_CLASSES: - return True + return node.func.id in FRAMEWORK_CLASSES # Check for sagemaker.. call ends_with_framework_constructor = ( diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py b/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py new file mode 100644 index 0000000000..6bc5ec473a --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py @@ -0,0 +1,150 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Classes to modify TensorFlow legacy mode code to be compatible with SageMaker Python SDK v2.""" +# TODO: handle fit(run_tensorboard_locally=True) +from __future__ import absolute_import + +import ast + +import six + +from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier + + +class TensorFlowLegacyModeConstructorUpgrader(Modifier): + """A class to turn legacy mode parameters into hyperparameters when + instantiating a TensorFlow estimator. + """ + + LEGACY_MODE_PARAMETERS = ( + "checkpoint_path", + "evaluation_steps", + "requirements_file", + "training_steps", + ) + + def node_should_be_modified(self, node): + """Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with legacy mode. + + This looks for the following formats: + + - ``TensorFlow`` + - ``sagemaker.tensorflow.TensorFlow`` + + Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified, + and (2) if ``py_version`` is ``py2`` or not specified. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with legacy mode. + """ + return self._is_tf_constructor(node) and self._is_legacy_mode(node) + + def _is_tf_constructor(self, node): + """Checks if the ``ast.Call`` node represents a call of the form + ``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``. + """ + # Check for TensorFlow() + if isinstance(node.func, ast.Name): + return node.func.id == "TensorFlow" + + # Check for sagemaker.tensorflow.TensorFlow() + ends_with_tensorflow_constructor = ( + isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow" + ) + + is_in_tensorflow_module = ( + isinstance(node.func.value, ast.Attribute) + and node.func.value.attr == "tensorflow" + and isinstance(node.func.value.value, ast.Name) + and node.func.value.value.id == "sagemaker" + ) + + return ends_with_tensorflow_constructor and is_in_tensorflow_module + + def _is_legacy_mode(self, node): + """Checks if the ``ast.Call`` node's keywords signal using legacy mode.""" + script_mode = False + py_version = "py2" + + for kw in node.keywords: + if kw.arg == "script_mode": + script_mode = bool(kw.value.value) + if kw.arg == "py_version": + py_version = kw.value.s + + return not (py_version.startswith("py3") or script_mode) + + def modify_node(self, node): + """Modifies the ``ast.Call`` node's keywords to turn TensorFlow legacy mode parameters + into hyperparameters and set ``script_mode=False``. + + The parameters that are converted into hyperparameters: + + - ``training_steps`` + - ``evaluation_steps`` + - ``checkpoint_path`` + - ``requirements_file`` + + Args: + node (ast.Call): a node that represents a TensorFlow constructor. + """ + base_hps = {} + additional_hps = {} + kw_to_remove = [] # remove keyword args after so that none are skipped during iteration + + for kw in node.keywords: + if kw.arg == "script_mode": + # remove here because is set to False later regardless of current value + kw_to_remove.append(kw) + if kw.arg == "hyperparameters" and kw.value: + base_hps = dict(zip(kw.value.keys, kw.value.values)) + kw_to_remove.append(kw) + if kw.arg in self.LEGACY_MODE_PARAMETERS and kw.value: + hp_key = self._hyperparameter_key_for_param(kw.arg) + additional_hps[hp_key] = kw.value + kw_to_remove.append(kw) + + self._remove_keywords(node, kw_to_remove) + self._add_updated_hyperparameters(node, base_hps, additional_hps) + + node.keywords.append(ast.keyword(arg="script_mode", value=ast.NameConstant(value=False))) + + def _hyperparameter_key_for_param(self, arg): + """Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter.""" + name = "sagemaker_requirements" if arg == "requirements_file" else arg + return ast.Str(s=name) + + def _remove_keywords(self, node, keywords): + """Removes the keywords from the ``ast.Call`` node.""" + for kw in keywords: + node.keywords.remove(kw) + + def _add_updated_hyperparameters(self, node, base_hps, additional_hps): + """Combines and adds the hyperparameters to the ``ast.Call`` node's keywords.""" + base_hps.update(additional_hps) + updated_hp_keyword = self._to_ast_keyword(base_hps) + + if updated_hp_keyword: + node.keywords.append(updated_hp_keyword) + + def _to_ast_keyword(self, hps): + """Returns an ``ast.keyword`` for the ``hyperparameters`` kwarg if there are any.""" + if hps: + keys, values = zip(*six.iteritems(hps)) + return ast.keyword(arg="hyperparameters", value=ast.Dict(keys=keys, values=values)) + + return None diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py new file mode 100644 index 0000000000..35ec383758 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py @@ -0,0 +1,162 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import sys + +import pasta +import pytest + +from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode + + +@pytest.fixture(autouse=True) +def skip_if_py2(): + # Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed. + if sys.version_info.major < 3: + pytest.skip("v2 migration script doesn't support Python 2.") + + +def test_node_should_be_modified_tf_constructor_legacy_mode(): + tf_legacy_mode_constructors = ( + "TensorFlow(script_mode=False)", + "TensorFlow(script_mode=None)", + "TensorFlow(py_version='py2')", + "TensorFlow()", + "sagemaker.tensorflow.TensorFlow(script_mode=False)", + "sagemaker.tensorflow.TensorFlow(script_mode=None)", + "sagemaker.tensorflow.TensorFlow(py_version='py2')", + "sagemaker.tensorflow.TensorFlow()", + ) + + modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader() + + for constructor in tf_legacy_mode_constructors: + node = _ast_call(constructor) + assert modifier.node_should_be_modified(node) is True + + +def test_node_should_be_modified_tf_constructor_script_mode(): + tf_script_mode_constructors = ( + "TensorFlow(script_mode=True)", + "TensorFlow(py_version='py3')", + "TensorFlow(py_version='py37')", + "TensorFlow(py_version='py3', script_mode=False)", + "sagemaker.tensorflow.TensorFlow(script_mode=True)", + "sagemaker.tensorflow.TensorFlow(py_version='py3')", + "sagemaker.tensorflow.TensorFlow(py_version='py37')", + "sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)", + ) + + modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader() + + for constructor in tf_script_mode_constructors: + node = _ast_call(constructor) + assert modifier.node_should_be_modified(node) is False + + +def test_node_should_be_modified_random_function_call(): + node = _ast_call("MXNet(py_version='py3')") + modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader() + assert modifier.node_should_be_modified(node) is False + + +def test_modify_node_set_script_mode_false(): + tf_constructors = ( + "TensorFlow()", + "TensorFlow(script_mode=False)", + "TensorFlow(script_mode=None)", + ) + modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader() + + for constructor in tf_constructors: + node = _ast_call(constructor) + modifier.modify_node(node) + assert "TensorFlow(script_mode=False)" == pasta.dump(node) + + +def test_modify_node_set_hyperparameters(): + tf_constructor = """TensorFlow( + checkpoint_path='s3://foo/bar', + training_steps=100, + evaluation_steps=10, + requirements_file='source/requirements.txt', + )""" + + node = _ast_call(tf_constructor) + modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader() + modifier.modify_node(node) + + expected_hyperparameters = { + "checkpoint_path": "s3://foo/bar", + "evaluation_steps": 10, + "sagemaker_requirements": "source/requirements.txt", + "training_steps": 100, + } + + assert expected_hyperparameters == _hyperparameters_from_node(node) + + +def test_modify_node_preserve_other_hyperparameters(): + tf_constructor = """sagemaker.tensorflow.TensorFlow( + training_steps=100, + evaluation_steps=10, + requirements_file='source/requirements.txt', + hyperparameters={'optimizer': 'sgd', 'lr': 0.1, 'checkpoint_path': 's3://foo/bar'}, + )""" + + node = _ast_call(tf_constructor) + modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader() + modifier.modify_node(node) + + expected_hyperparameters = { + "optimizer": "sgd", + "lr": 0.1, + "checkpoint_path": "s3://foo/bar", + "evaluation_steps": 10, + "sagemaker_requirements": "source/requirements.txt", + "training_steps": 100, + } + + assert expected_hyperparameters == _hyperparameters_from_node(node) + + +def test_modify_node_prefer_param_over_hyperparameter(): + tf_constructor = """sagemaker.tensorflow.TensorFlow( + training_steps=100, + requirements_file='source/requirements.txt', + hyperparameters={'training_steps': 10, 'sagemaker_requirements': 'foo.txt'}, + )""" + + node = _ast_call(tf_constructor) + modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader() + modifier.modify_node(node) + + expected_hyperparameters = { + "sagemaker_requirements": "source/requirements.txt", + "training_steps": 100, + } + + assert expected_hyperparameters == _hyperparameters_from_node(node) + + +def _hyperparameters_from_node(node): + for kw in node.keywords: + if kw.arg == "hyperparameters": + keys = [k.s for k in kw.value.keys] + values = [getattr(v, v._fields[0]) for v in kw.value.values] + return dict(zip(keys, values)) + + +def _ast_call(code): + return pasta.parse(code).body[0].value