# Post-change contents of
# src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py,
# rebuilt from the added and context lines of the flattened patch.
# NOTE(review): the patch does not show the top of the file; ``import ast`` is
# restored because every helper below uses it -- confirm the remaining header.
import ast

from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

FRAMEWORK_ARG = "framework_version"
PY_ARG = "py_version"

# NOTE(review): the PyTorch and SKLearn entries sit between two hunks of the patch
# and are not shown; the values are the ones the unit tests in the same patch
# expect -- confirm against the file.
FRAMEWORK_DEFAULTS = {
    "Chainer": "4.1.0",
    "MXNet": "1.2.0",
    "PyTorch": "0.4.0",
    "SKLearn": "0.20.0",
    "TensorFlow": "1.11.0",
}

FRAMEWORK_CLASSES = list(FRAMEWORK_DEFAULTS.keys())
MODEL_CLASSES = ["{}Model".format(fw) for fw in FRAMEWORK_CLASSES]

# TODO: check for sagemaker.tensorflow.serving.Model
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORK_CLASSES]
FRAMEWORK_SUBMODULES = ("model", "estimator")


class FrameworkVersionEnforcer(Modifier):
    """Adds explicit version arguments to framework estimator and model constructors.

    A constructor call is rewritten when it supplies neither an image nor the
    ``framework_version`` (and, where a default exists, ``py_version``) arguments.
    """

    def node_should_be_modified(self, node):
        """Checks if the ``ast.Call`` node instantiates a framework estimator or model,
        but doesn't specify the ``framework_version`` and ``py_version`` parameters,
        as appropriate.

        This looks for the following formats:

        - ``<Framework>``
        - ``sagemaker.<framework>.<Framework>``
        - ``sagemaker.<framework>.<submodule>.<Framework>``

        where ``<Framework>`` is one of ``FRAMEWORK_CLASSES`` or ``MODEL_CLASSES``,
        ``<framework>`` is one of ``FRAMEWORK_MODULES`` and ``<submodule>`` is one of
        ``FRAMEWORK_SUBMODULES``.

        Args:
            node (ast.Call): a node that represents a function call.

        Returns:
            bool: If the ``ast.Call`` is instantiating a framework class that should
                specify ``framework_version`` (or ``py_version``), but doesn't.
        """
        # Estimators and models name their image argument differently.
        if _is_named_constructor(node, FRAMEWORK_CLASSES):
            return _version_args_needed(node, "image_name")

        if _is_named_constructor(node, MODEL_CLASSES):
            return _version_args_needed(node, "image")

        return False

    def modify_node(self, node):
        """Modifies the ``ast.Call`` node's keywords to include ``framework_version``
        and, where a default exists, ``py_version``.

        The ``framework_version`` value is the entry of ``FRAMEWORK_DEFAULTS`` for the
        framework being constructed.

        The ``py_version`` value is determined by the framework, framework_version, and
        if it is a model, whether the model accepts a py_version.

        A keyword that is already present is never added a second time, even when its
        value is an expression rather than a string literal.

        Args:
            node (ast.Call): a node that represents the constructor of a framework class.

        Raises:
            KeyError: if the called name is not a known framework class.
        """
        framework, is_model = _framework_from_node(node)

        # NOTE(review): ast.Str is a deprecated alias since Python 3.8; it is kept
        # here because the rest of the tool builds its nodes the same way.
        fw_node = _keyword_node(node, FRAMEWORK_ARG)
        if fw_node is None:
            framework_version = FRAMEWORK_DEFAULTS[framework]
            node.keywords.append(
                ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version))
            )
        else:
            # None when the caller passed an expression instead of a string literal.
            framework_version = _string_literal(fw_node)

        if _keyword_node(node, PY_ARG) is None:
            py_version = _py_version_defaults(framework, framework_version, is_model)
            if py_version:
                node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version)))


def _py_version_defaults(framework, framework_version, is_model=False):
    """Gets the py_version required for the framework_version and if it's a model

    Args:
        framework (str): name of the framework
        framework_version (str): version of the framework, or None when it is not
            known statically
        is_model (bool): whether it is a constructor for a model or not

    Returns:
        str: the default py version, as appropriate. None if no default py_version
    """
    if framework in ("Chainer", "PyTorch"):
        return "py3"
    if framework == "SKLearn" and not is_model:
        return "py3"
    if framework == "MXNet":
        return "py2"
    if framework == "TensorFlow" and not is_model:
        return _tf_py_version_default(framework_version)
    return None


def _tf_py_version_default(framework_version):
    """Gets the py_version default based on framework_version for TensorFlow.

    Args:
        framework_version (str): TensorFlow version such as ``"1.11.0"``. Suffixes
            such as ``rc1`` are ignored.

    Returns:
        str: ``"py2"`` below 1.12 (and for an empty version), ``"py3"`` below 2.2,
            ``"py37"`` otherwise. None when the version is None or does not start
            with a number, because no default can be derived then.
    """
    if framework_version is None:
        return None
    if not framework_version:
        return "py2"

    version = _numeric_version(framework_version)
    if not version:
        return None
    if version < [1, 12]:
        return "py2"
    if version < [2, 2]:
        return "py3"
    return "py37"


def _numeric_version(version):
    """Returns the leading numeric components of a dotted version as a list of ints.

    Parsing stops at the first component that is not purely numeric, keeping that
    component's numeric prefix: ``"1.15.0rc1"`` gives ``[1, 15, 0]`` and ``"latest"``
    gives ``[]``.
    """
    numbers = []
    for piece in str(version).split("."):
        piece = piece.strip()
        end = 0
        while end < len(piece) and piece[end] in "0123456789":
            end += 1
        if end:
            numbers.append(int(piece[:end]))
        if end == 0 or end < len(piece):
            break
    return numbers


def _framework_from_node(node):
    """Retrieves the framework class name based on the function call, and if it was a model

    Args:
        node (ast.Call): a node that represents the constructor of a framework class.
            This can represent either ``<Framework>`` or
            ``sagemaker.<framework>.<Framework>``.

    Returns:
        str, bool: the (capitalized) framework class name, and if it is a model class
    """
    if isinstance(node.func, ast.Name):
        framework = node.func.id
    elif isinstance(node.func, ast.Attribute):
        framework = node.func.attr
    else:
        framework = ""

    is_model = framework.endswith("Model")
    if is_model:
        # Strip the suffix itself, not the first occurrence of "Model".
        framework = framework[: -len("Model")]

    return framework, is_model


def _is_named_constructor(node, names):
    """Checks if the ``ast.Call`` node represents a call to particular named constructors.

    Forms that qualify are either ``<Framework>`` or ``sagemaker.<framework>.<Framework>``
    (optionally with a ``model``/``estimator`` submodule in between), where
    ``<Framework>`` belongs to the list of names passed in.
    """
    # Check for <Framework> call from particular names of constructors
    if isinstance(node.func, ast.Name):
        return node.func.id in names

    # Check for something.that.ends.with.<Framework> call for Framework in names
    if not (isinstance(node.func, ast.Attribute) and node.func.attr in names):
        return False

    # Check for sagemaker.<framework>.<submodule>.<Framework> call
    if isinstance(node.func.value, ast.Attribute) and node.func.value.attr in FRAMEWORK_SUBMODULES:
        return _is_in_framework_module(node.func.value)

    # Check for sagemaker.<framework>.<Framework> call
    return _is_in_framework_module(node.func)


def _is_in_framework_module(node):
    """Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module."""
    return (
        isinstance(node.value, ast.Attribute)
        and node.value.attr in FRAMEWORK_MODULES
        and isinstance(node.value.value, ast.Name)
        and node.value.value.id == "sagemaker"
    )


def _version_args_needed(node, image_arg):
    """Determines whether version arguments still have to be added to the call.

    Applies similar logic as ``validate_version_or_image_args``

    Args:
        node (ast.Call): constructor call of a framework estimator or model.
        image_arg (str): name of the keyword that carries the image for this class.

    Returns:
        bool: True if ``framework_version`` is missing, or if a ``py_version`` default
            exists for the call and ``py_version`` is missing. False when an image is
            supplied or nothing needs to be added.
    """
    # if image_arg is present, no need to supply version arguments; an empty string
    # literal does not count as an image
    image_node = _keyword_node(node, image_arg)
    if image_node is not None and _string_literal(image_node) != "":
        return False

    # if framework_version is absent, need args
    fw_node = _keyword_node(node, FRAMEWORK_ARG)
    if fw_node is None:
        return True

    # py_version already given (literal or expression): nothing left to add
    if _keyword_node(node, PY_ARG) is not None:
        return False

    # check if we expect py_version -- framework and model dependent
    framework, is_model = _framework_from_node(node)
    return bool(_py_version_defaults(framework, _string_literal(fw_node), is_model))


def _keyword_node(node, arg):
    """Returns the AST node passed for keyword ``arg``, or None if the keyword is absent."""
    for kw in node.keywords:
        if kw.arg == arg and kw.value:
            return kw.value
    return None


def _string_literal(value_node):
    """Returns the ``str`` held by a string-literal AST node, or None for anything else.

    Names, calls, attributes, non-string constants and None all give None, so callers
    never dereference attributes that only string literals carry.
    """
    if isinstance(value_node, ast.Constant):
        # Python 3.8+ parses every literal to ast.Constant
        value = value_node.value
    elif type(value_node).__name__ == "Str":
        # parse trees produced before Python 3.8
        value = value_node.s
    else:
        return None
    return value if isinstance(value, str) else None


def _arg_value(node, arg):
    """Gets the string literal associated with the arg keyword, if present

    Returns:
        str: the literal value, or None when the keyword is absent or its value is not
            a string literal (for example a variable).
    """
    return _string_literal(_keyword_node(node, arg))
# Post-change contents of
# tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py,
# rebuilt from the added and context lines of the flattened patch.
# NOTE(review): only the ``ast_call`` import is shown by the patch; the other imports
# are restored from the names the tests use -- confirm against the file header.
import sys

import pasta
import pytest

from sagemaker.cli.compatibility.v2.modifiers import framework_version
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call


class Template:
    """Essentially a data class with a parametrized format method

    Helper to interpolate various combinations of framework_version, py_version, and image
    as expected by framework and model classes

    TODO: use attrs package and eliminate the boilerplate...
    """

    def __init__(self, framework, framework_version, py_version, py_version_for_model=True):
        """Stores the expectations for one framework.

        Args:
            framework (str): framework class name, e.g. ``"TensorFlow"``.
            framework_version (str): version expected in generated constructors.
            py_version (str): Python version expected in generated constructors.
            py_version_for_model (bool): whether the model class also gets ``py_version``.
        """
        self.framework = framework
        self.framework_version = framework_version
        self.py_version = py_version
        self.py_version_for_model = py_version_for_model

    def constructors(self, versions=False, image=False):
        """Returns estimator constructor strings followed by model constructor strings."""
        return self._frameworks(versions, image) + self._models(versions, image)

    def _templates(self, model=False):
        """Returns the three call spellings (bare, module, submodule) for the class."""
        module = self.framework.lower()
        submodule = "model" if model else "estimator"
        suffix = "Model" if model else ""
        classname = "{framework}{suffix}".format(framework=self.framework, suffix=suffix)
        templates = (
            "{classname}({{}})",
            "sagemaker.{module}.{classname}({{}})",
            "sagemaker.{module}.{submodule}.{classname}({{}})",
        )
        return tuple(
            template.format(classname=classname, module=module, submodule=submodule)
            for template in templates
        )

    def _keywords(self, image_arg, versions, image, with_py_version):
        """Builds the keyword mapping in the order image, framework_version, py_version."""
        keywords = {}
        if image:
            keywords[image_arg] = "my:image"
        if versions:
            keywords["framework_version"] = self.framework_version
            if with_py_version:
                keywords["py_version"] = self.py_version
        return keywords

    def _frameworks(self, versions=False, image=False):
        """Returns estimator constructors; estimators take the image as ``image_name``."""
        keywords = self._keywords("image_name", versions, image, with_py_version=True)
        return _format_templates(keywords, self._templates())

    def _models(self, versions=False, image=False):
        """Returns model constructors; models take the image as ``image``."""
        keywords = self._keywords(
            "image", versions, image, with_py_version=self.py_version_for_model
        )
        return _format_templates(keywords, self._templates(model=True))


def _format_templates(keywords, templates):
    """Fills each template with ``key='value'`` pairs, in the mapping's iteration order."""
    args = ", ".join(
        "{key}='{value}'".format(key=key, value=value) for key, value in keywords.items()
    )
    return [template.format(args) for template in templates]


TEMPLATES = [
    Template(
        framework="TensorFlow",
        framework_version="1.11.0",
        py_version="py2",
        py_version_for_model=False,
    ),
    Template(framework="MXNet", framework_version="1.2.0", py_version="py2"),
    Template(framework="Chainer", framework_version="4.1.0", py_version="py3"),
    Template(framework="PyTorch", framework_version="0.4.0", py_version="py3"),
    Template(
        framework="SKLearn",
        framework_version="0.20.0",
        py_version="py3",
        py_version_for_model=False,
    ),
]


def constructors(versions=False, image=False):
    """Returns the constructor strings of every template, in ``TEMPLATES`` order."""
    return [ctr for template in TEMPLATES for ctr in template.constructors(versions, image)]


@pytest.fixture(autouse=True)
def skip_if_py2():
    # Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
    # NOTE(review): the guard line falls between two hunks of the patch and is not
    # shown; this condition is inferred from the skip message -- confirm.
    if sys.version_info.major < 3:
        pytest.skip("v2 migration script doesn't support Python 2.")


@pytest.fixture
def constructors_empty():
    return constructors()


@pytest.fixture
def constructors_with_versions():
    return constructors(versions=True)


@pytest.fixture
def constructors_with_image():
    return constructors(image=True)


@pytest.fixture
def constructors_with_both():
    return constructors(versions=True, image=True)


def _test_node_should_be_modified(ctrs, should_modify=True):
    """Asserts that every constructor string is (or is not) flagged for modification."""
    modifier = framework_version.FrameworkVersionEnforcer()
    for ctr in ctrs:
        node = ast_call(ctr)
        if should_modify:
            assert modifier.node_should_be_modified(node), "{} wasn't modified.".format(ctr)
        else:
            assert not modifier.node_should_be_modified(node), "{} was modified.".format(ctr)


def test_node_should_be_modified_empty(constructors_empty):
    _test_node_should_be_modified(constructors_empty, should_modify=True)


def test_node_should_be_modified_with_versions(constructors_with_versions):
    _test_node_should_be_modified(constructors_with_versions, should_modify=False)


def test_node_should_be_modified_with_image(constructors_with_image):
    _test_node_should_be_modified(constructors_with_image, should_modify=False)


def test_node_should_be_modified_with_both(constructors_with_both):
    _test_node_should_be_modified(constructors_with_both, should_modify=False)


def test_node_should_be_modified_random_function_call():
    _test_node_should_be_modified(["sagemaker.session.Session()"], should_modify=False)


def _test_modify_node(ctrs_before, ctrs_expected):
    """Asserts that modifying each ``before`` constructor yields its ``expected`` pair."""
    # zip() stops at the shorter input, so unequal lists would skip cases silently.
    assert len(ctrs_before) == len(ctrs_expected)

    modifier = framework_version.FrameworkVersionEnforcer()
    for before, expected in zip(ctrs_before, ctrs_expected):
        node = ast_call(before)
        modifier.modify_node(node)
        # NOTE: this type of equality with pasta depends on ordering of args...
        assert expected == pasta.dump(node)


def test_modify_node_empty(constructors_empty, constructors_with_versions):
    _test_modify_node(constructors_empty, constructors_with_versions)


def test_modify_node_with_versions(constructors_with_versions):
    _test_modify_node(constructors_with_versions, constructors_with_versions)


def test_modify_node_with_image(constructors_with_image, constructors_with_both):
    _test_modify_node(constructors_with_image, constructors_with_both)