From b2bc3a392b3866e7b6afaa870c95b562f6f8a59b Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Fri, 5 Jun 2020 14:17:09 -0700 Subject: [PATCH] fix: look for 'sagemaker.. module in v2 migration tool --- .../v2/modifiers/framework_version.py | 32 ++++++--- .../v2/modifiers/test_framework_version.py | 65 ++++++++++--------- 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py index 848625a58a..ba014d085f 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py @@ -29,6 +29,7 @@ # TODO: check for sagemaker.tensorflow.serving.Model FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model".format(fw) for fw in FRAMEWORKS] FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORKS] +FRAMEWORK_SUBMODULES = ("model", "estimator") class FrameworkVersionEnforcer(Modifier): @@ -68,19 +69,30 @@ def _is_framework_constructor(self, node): if isinstance(node.func, ast.Name): return node.func.id in FRAMEWORK_CLASSES - # Check for sagemaker.. call - ends_with_framework_constructor = ( - isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES - ) + # Check for something.that.ends.with.. call + if not (isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES): + return False - is_in_framework_module = ( + # Check for sagemaker... call + if ( isinstance(node.func.value, ast.Attribute) - and node.func.value.attr in FRAMEWORK_MODULES - and isinstance(node.func.value.value, ast.Name) - and node.func.value.value.id == "sagemaker" - ) + and node.func.value.attr in FRAMEWORK_SUBMODULES + ): + return self._is_in_framework_module(node.func.value) - return ends_with_framework_constructor and is_in_framework_module + # Check for sagemaker.. call + return self._is_in_framework_module(node.func) + + def _is_in_framework_module(self, node): + """Checks if the node is an ``ast.Attribute`` that represents a + ``sagemaker.`` module. + """ + return ( + isinstance(node.value, ast.Attribute) + and node.value.attr in FRAMEWORK_MODULES + and isinstance(node.value.value, ast.Name) + and node.value.value.id == "sagemaker" + ) def _fw_version_in_keywords(self, node): """Checks if the ``ast.Call`` node's keywords contain ``framework_version``.""" diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py index 796472a14a..ebddde2a7b 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py @@ -32,24 +32,34 @@ def test_node_should_be_modified_fw_constructor_no_fw_version(): fw_constructors = ( "TensorFlow()", "sagemaker.tensorflow.TensorFlow()", + "sagemaker.tensorflow.estimator.TensorFlow()", "TensorFlowModel()", "sagemaker.tensorflow.TensorFlowModel()", + "sagemaker.tensorflow.model.TensorFlowModel()", "MXNet()", "sagemaker.mxnet.MXNet()", + "sagemaker.mxnet.estimator.MXNet()", "MXNetModel()", "sagemaker.mxnet.MXNetModel()", + "sagemaker.mxnet.model.MXNetModel()", "Chainer()", "sagemaker.chainer.Chainer()", + "sagemaker.chainer.estimator.Chainer()", "ChainerModel()", "sagemaker.chainer.ChainerModel()", + "sagemaker.chainer.model.ChainerModel()", "PyTorch()", "sagemaker.pytorch.PyTorch()", + "sagemaker.pytorch.estimator.PyTorch()", "PyTorchModel()", "sagemaker.pytorch.PyTorchModel()", + "sagemaker.pytorch.model.PyTorchModel()", "SKLearn()", "sagemaker.sklearn.SKLearn()", + "sagemaker.sklearn.estimator.SKLearn()", "SKLearnModel()", "sagemaker.sklearn.SKLearnModel()", + "sagemaker.sklearn.model.SKLearnModel()", ) modifier = framework_version.FrameworkVersionEnforcer() @@ -63,24 +73,34 @@ def test_node_should_be_modified_fw_constructor_with_fw_version(): fw_constructors = ( "TensorFlow(framework_version='2.2')", "sagemaker.tensorflow.TensorFlow(framework_version='2.2')", + "sagemaker.tensorflow.estimator.TensorFlow(framework_version='2.2')", "TensorFlowModel(framework_version='1.10')", "sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')", + "sagemaker.tensorflow.model.TensorFlowModel(framework_version='1.10')", "MXNet(framework_version='1.6')", "sagemaker.mxnet.MXNet(framework_version='1.6')", + "sagemaker.mxnet.estimator.MXNet(framework_version='1.6')", "MXNetModel(framework_version='1.6')", "sagemaker.mxnet.MXNetModel(framework_version='1.6')", + "sagemaker.mxnet.model.MXNetModel(framework_version='1.6')", "PyTorch(framework_version='1.4')", "sagemaker.pytorch.PyTorch(framework_version='1.4')", + "sagemaker.pytorch.estimator.PyTorch(framework_version='1.4')", "PyTorchModel(framework_version='1.4')", "sagemaker.pytorch.PyTorchModel(framework_version='1.4')", + "sagemaker.pytorch.model.PyTorchModel(framework_version='1.4')", "Chainer(framework_version='5.0')", "sagemaker.chainer.Chainer(framework_version='5.0')", + "sagemaker.chainer.estimator.Chainer(framework_version='5.0')", "ChainerModel(framework_version='5.0')", "sagemaker.chainer.ChainerModel(framework_version='5.0')", + "sagemaker.chainer.model.ChainerModel(framework_version='5.0')", "SKLearn(framework_version='0.20.0')", "sagemaker.sklearn.SKLearn(framework_version='0.20.0')", + "sagemaker.sklearn.estimator.SKLearn(framework_version='0.20.0')", "SKLearnModel(framework_version='0.20.0')", "sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')", + "sagemaker.sklearn.model.SKLearnModel(framework_version='0.20.0')", ) modifier = framework_version.FrameworkVersionEnforcer() @@ -97,51 +117,36 @@ def test_node_should_be_modified_random_function_call(): def test_modify_node_tf(): - classes = ( - "TensorFlow" "sagemaker.tensorflow.TensorFlow", - "TensorFlowModel", - "sagemaker.tensorflow.TensorFlowModel", - ) - _test_modify_node(classes, "1.11.0") + _test_modify_node("TensorFlow", "1.11.0") def test_modify_node_mx(): - classes = ("MXNet", "sagemaker.mxnet.MXNet", "MXNetModel", "sagemaker.mxnet.MXNetModel") - _test_modify_node(classes, "1.2.0") + _test_modify_node("MXNet", "1.2.0") def test_modify_node_chainer(): - classes = ( - "Chainer", - "sagemaker.chainer.Chainer", - "ChainerModel", - "sagemaker.chainer.ChainerModel", - ) - _test_modify_node(classes, "4.1.0") + _test_modify_node("Chainer", "4.1.0") def test_modify_node_pt(): - classes = ( - "PyTorch", - "sagemaker.pytorch.PyTorch", - "PyTorchModel", - "sagemaker.pytorch.PyTorchModel", - ) - _test_modify_node(classes, "0.4.0") + _test_modify_node("PyTorch", "0.4.0") def test_modify_node_sklearn(): - classes = ( - "SKLearn", - "sagemaker.sklearn.SKLearn", - "SKLearnModel", - "sagemaker.sklearn.SKLearnModel", - ) - _test_modify_node(classes, "0.20.0") + _test_modify_node("SKLearn", "0.20.0") -def _test_modify_node(classes, default_version): +def _test_modify_node(framework, default_version): modifier = framework_version.FrameworkVersionEnforcer() + + classes = ( + "{}".format(framework), + "sagemaker.{}.{}".format(framework.lower(), framework), + "sagemaker.{}.estimator.{}".format(framework.lower(), framework), + "{}Model".format(framework), + "sagemaker.{}.{}Model".format(framework.lower(), framework), + "sagemaker.{}.model.{}Model".format(framework.lower(), framework), + ) for cls in classes: node = ast_call("{}()".format(cls)) modifier.modify_node(node)