diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py b/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py index 0b531a3922..554150b253 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py @@ -54,6 +54,7 @@ def node_should_be_modified(self, node): - ``TensorFlow`` - ``sagemaker.tensorflow.TensorFlow`` + - ``sagemaker.tensorflow.estimator.TensorFlow`` Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified, and (2) if ``py_version`` is ``py2`` or not specified. @@ -68,27 +69,35 @@ def node_should_be_modified(self, node): return self._is_tf_constructor(node) and self._is_legacy_mode(node) def _is_tf_constructor(self, node): - """Checks if the ``ast.Call`` node represents a call of the form - ``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``. + """Checks if the ``ast.Call`` node represents a call of the form ``TensorFlow``, + ``sagemaker.tensorflow.TensorFlow``, or ``sagemaker.tensorflow.estimator.TensorFlow``. """ # Check for TensorFlow() if isinstance(node.func, ast.Name): return node.func.id == "TensorFlow" + # Check for something.that.ends.with.TensorFlow() + if not (isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"): + return False + + # Check for sagemaker.tensorflow.estimator.TensorFlow() + if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == "estimator": + return self._is_in_tensorflow_module(node.func.value) + # Check for sagemaker.tensorflow.TensorFlow() - ends_with_tensorflow_constructor = ( - isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow" - ) + return self._is_in_tensorflow_module(node.func) - is_in_tensorflow_module = ( - isinstance(node.func.value, ast.Attribute) - and node.func.value.attr == "tensorflow" - and isinstance(node.func.value.value, ast.Name) - and node.func.value.value.id == "sagemaker" + def _is_in_tensorflow_module(self, node): + """Checks if the node is an ``ast.Attribute`` that represents the + ``sagemaker.tensorflow`` module. + """ + return ( + isinstance(node.value, ast.Attribute) + and node.value.attr == "tensorflow" + and isinstance(node.value.value, ast.Name) + and node.value.value.id == "sagemaker" ) - return ends_with_tensorflow_constructor and is_in_tensorflow_module - def _is_legacy_mode(self, node): """Checks if the ``ast.Call`` node's keywords signal using legacy mode.""" script_mode = False diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py index ac5f4015ea..affb2940c4 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py @@ -42,6 +42,10 @@ def test_node_should_be_modified_tf_constructor_legacy_mode(): "sagemaker.tensorflow.TensorFlow(script_mode=None)", "sagemaker.tensorflow.TensorFlow(py_version='py2')", "sagemaker.tensorflow.TensorFlow()", + "sagemaker.tensorflow.estimator.TensorFlow(script_mode=False)", + "sagemaker.tensorflow.estimator.TensorFlow(script_mode=None)", + "sagemaker.tensorflow.estimator.TensorFlow(py_version='py2')", + "sagemaker.tensorflow.estimator.TensorFlow()", ) modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader() @@ -61,6 +65,10 @@ def test_node_should_be_modified_tf_constructor_script_mode(): "sagemaker.tensorflow.TensorFlow(py_version='py3')", "sagemaker.tensorflow.TensorFlow(py_version='py37')", "sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)", + "sagemaker.tensorflow.estimator.TensorFlow(script_mode=True)", + "sagemaker.tensorflow.estimator.TensorFlow(py_version='py3')", + "sagemaker.tensorflow.estimator.TensorFlow(py_version='py37')", + "sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=False)", ) modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()