@@ -54,6 +54,7 @@ def node_should_be_modified(self, node):
5454
5555 - ``TensorFlow``
5656 - ``sagemaker.tensorflow.TensorFlow``
57+ - ``sagemaker.tensorflow.estimator.TensorFlow``
5758
5859 Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified,
5960 and (2) if ``py_version`` is ``py2`` or not specified.
@@ -68,27 +69,35 @@ def node_should_be_modified(self, node):
6869 return self ._is_tf_constructor (node ) and self ._is_legacy_mode (node )
6970
7071 def _is_tf_constructor (self , node ):
71- """Checks if the ``ast.Call`` node represents a call of the form
72- ``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
72+ """Checks if the ``ast.Call`` node represents a call of the form ``TensorFlow``,
73+ ``sagemaker.tensorflow. TensorFlow``, or ``sagemaker.tensorflow.estimator .TensorFlow``.
7374 """
7475 # Check for TensorFlow()
7576 if isinstance (node .func , ast .Name ):
7677 return node .func .id == "TensorFlow"
7778
79+ # Check for something.that.ends.with.TensorFlow()
80+ if not (isinstance (node .func , ast .Attribute ) and node .func .attr == "TensorFlow" ):
81+ return False
82+
83+ # Check for sagemaker.tensorflow.estimator.TensorFlow()
84+ if isinstance (node .func .value , ast .Attribute ) and node .func .value .attr == "estimator" :
85+ return self ._is_in_tensorflow_module (node .func .value )
86+
7887 # Check for sagemaker.tensorflow.TensorFlow()
79- ends_with_tensorflow_constructor = (
80- isinstance (node .func , ast .Attribute ) and node .func .attr == "TensorFlow"
81- )
88+ return self ._is_in_tensorflow_module (node .func )
8289
83- is_in_tensorflow_module = (
84- isinstance (node .func .value , ast .Attribute )
85- and node .func .value .attr == "tensorflow"
86- and isinstance (node .func .value .value , ast .Name )
87- and node .func .value .value .id == "sagemaker"
90+ def _is_in_tensorflow_module (self , node ):
91+ """Checks if the node is an ``ast.Attribute`` that represents the
92+ ``sagemaker.tensorflow`` module.
93+ """
94+ return (
95+ isinstance (node .value , ast .Attribute )
96+ and node .value .attr == "tensorflow"
97+ and isinstance (node .value .value , ast .Name )
98+ and node .value .value .id == "sagemaker"
8899 )
89100
90- return ends_with_tensorflow_constructor and is_in_tensorflow_module
91-
92101 def _is_legacy_mode (self , node ):
93102 """Checks if the ``ast.Call`` node's keywords signal using legacy mode."""
94103 script_mode = False
0 commit comments