From aab67792fe065375da2005b4918c7b0b3bd4bfc1 Mon Sep 17 00:00:00 2001 From: Eric Johnson <65414824+metrizable@users.noreply.github.com> Date: Tue, 30 Jun 2020 11:14:41 -0700 Subject: [PATCH] infra: add cli modifier for RealTimePredictor and derived classes --- .../cli/compatibility/v2/ast_transformer.py | 6 +- .../compatibility/v2/modifiers/__init__.py | 1 + .../compatibility/v2/modifiers/predictors.py | 146 ++++++++++++++++++ .../v2/modifiers/test_predictors.py | 128 +++++++++++++++ 4 files changed, 280 insertions(+), 1 deletion(-) create mode 100644 src/sagemaker/cli/compatibility/v2/modifiers/predictors.py create mode 100644 tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_predictors.py diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 2fb35f8c56..18fbe4aae0 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -18,6 +18,7 @@ from sagemaker.cli.compatibility.v2 import modifiers FUNCTION_CALL_MODIFIERS = [ + modifiers.predictors.PredictorConstructorRefactor(), modifiers.framework_version.FrameworkVersionEnforcer(), modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(), modifiers.tf_legacy_mode.TensorBoardParameterRemover(), @@ -28,7 +29,10 @@ IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()] -IMPORT_FROM_MODIFIERS = [modifiers.tfs.TensorFlowServingImportFromRenamer()] +IMPORT_FROM_MODIFIERS = [ + modifiers.predictors.PredictorImportFromRenamer(), + modifiers.tfs.TensorFlowServingImportFromRenamer(), +] class ASTTransformer(ast.NodeTransformer): diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py index 0b97d45868..d9d8600d7f 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py @@ -17,6 +17,7 @@ airflow, 
deprecated_params, framework_version, + predictors, tf_legacy_mode, tfs, ) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/predictors.py b/src/sagemaker/cli/compatibility/v2/modifiers/predictors.py new file mode 100644 index 0000000000..43b60bedef --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/predictors.py @@ -0,0 +1,146 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Classes to modify Predictor code to be compatible +with version 2.0 and later of the SageMaker Python SDK. 
+""" +from __future__ import absolute_import + +import ast + +from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier + +BASE_PREDICTOR = "RealTimePredictor" +PREDICTORS = { + "FactorizationMachinesPredictor": ("sagemaker", "sagemaker.amazon.factorization_machines"), + "IPInsightsPredictor": ("sagemaker", "sagemaker.amazon.ipinsights"), + "KMeansPredictor": ("sagemaker", "sagemaker.amazon.kmeans"), + "KNNPredictor": ("sagemaker", "sagemaker.amazon.knn"), + "LDAPredictor": ("sagemaker", "sagemaker.amazon.lda"), + "LinearLearnerPredictor": ("sagemaker", "sagemaker.amazon.linear_learner"), + "NTMPredictor": ("sagemaker", "sagemaker.amazon.ntm"), + "PCAPredictor": ("sagemaker", "sagemaker.amazon.pca"), + "RandomCutForestPredictor": ("sagemaker", "sagemaker.amazon.randomcutforest"), + "RealTimePredictor": ("sagemaker", "sagemaker.predictor"), + "SparkMLPredictor": ("sagemaker.sparkml", "sagemaker.sparkml.model"), +} + + +class PredictorConstructorRefactor(Modifier): + """A class to refactor *Predictor constructor calls and rename their ``endpoint`` argument.""" + + def node_should_be_modified(self, node): + """Checks if the ``ast.Call`` node instantiates a class of interest. + + This looks for the following calls: + + - ``sagemaker.<my>.<namespace>.<MyPredictor>`` + - ``sagemaker.<namespace>.<MyPredictor>`` + - ``<MyPredictor>`` + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` instantiates a class of interest. + """ + return any(_matching(node, name, namespaces) for name, namespaces in PREDICTORS.items()) + + def modify_node(self, node): + """Modifies the ``ast.Call`` node to call ``Predictor`` instead. + + Also renames the ``endpoint`` keyword argument to ``endpoint_name``. + + Args: + node (ast.Call): a node that represents a *Predictor constructor. 
+ """ + _rename_class(node) + _rename_endpoint(node) + + +def _matching(node, name, namespaces): + """Determines if the node matches the constructor name in the right namespace""" + if _matching_name(node, name): + return True + + if not _matching_attr(node, name): + return False + + return any(_matching_namespace(node, namespace) for namespace in namespaces) + + +def _matching_name(node, name): + """Determines if the node is an ast.Name node with a matching name""" + return isinstance(node.func, ast.Name) and node.func.id == name + + +def _matching_attr(node, name): + """Determines if the node is an ast.Attribute node with a matching name""" + return isinstance(node.func, ast.Attribute) and node.func.attr == name + + +def _matching_namespace(node, namespace): + """Determines if the node corresponds to a matching namespace""" + names = namespace.split(".") + name, value = names.pop(), node.func.value + while isinstance(value, ast.Attribute) and len(names) > 0: + if value.attr != name: + return False + name, value = names.pop(), value.value + + return isinstance(value, ast.Name) and value.id == name + + +def _rename_class(node): + """Renames the RealTimePredictor base class to Predictor""" + if _matching_name(node, BASE_PREDICTOR): + node.func.id = "Predictor" + elif _matching_attr(node, BASE_PREDICTOR): + node.func.attr = "Predictor" + + +def _rename_endpoint(node): + """Renames keyword endpoint argument to endpoint_name""" + for keyword in node.keywords: + if keyword.arg == "endpoint": + keyword.arg = "endpoint_name" + break + + +class PredictorImportFromRenamer(Modifier): + """A class to update import statements of ``RealTimePredictor``.""" + + def node_should_be_modified(self, node): + """Checks if the import statement imports ``RealTimePredictor`` from the correct module. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ...`` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. 
+ + Returns: + bool: If the import statement imports ``RealTimePredictor`` from the correct module. + """ + return node.module in PREDICTORS[BASE_PREDICTOR] and any( + name.name == BASE_PREDICTOR for name in node.names + ) + + def modify_node(self, node): + """Changes the ``ast.ImportFrom`` node's name from ``RealTimePredictor`` to ``Predictor``. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ...`` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + """ + for name in node.names: + if name.name == BASE_PREDICTOR: + name.name = "Predictor" diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_predictors.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_predictors.py new file mode 100644 index 0000000000..6a7d518ad6 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_predictors.py @@ -0,0 +1,128 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pasta +import pytest + +from sagemaker.cli.compatibility.v2.modifiers import predictors +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import + + +@pytest.fixture +def base_constructors(): + return ( + "sagemaker.predictor.RealTimePredictor(endpoint='a')", + "sagemaker.RealTimePredictor(endpoint='b')", + "RealTimePredictor(endpoint='c')", + ) + + +@pytest.fixture +def sparkml_constructors(): + return ( + "sagemaker.sparkml.model.SparkMLPredictor(endpoint='a')", + "sagemaker.sparkml.SparkMLPredictor(endpoint='b')", + "SparkMLPredictor(endpoint='c')", + ) + + +@pytest.fixture +def other_constructors(): + return ( + "sagemaker.amazon.knn.KNNPredictor(endpoint='a')", + "sagemaker.KNNPredictor(endpoint='b')", + "KNNPredictor(endpoint='c')", + ) + + +@pytest.fixture +def import_statements(): + return ( + "from sagemaker.predictor import RealTimePredictor", + "from sagemaker import RealTimePredictor", + ) + + +def test_constructor_node_should_be_modified_base(base_constructors): + modifier = predictors.PredictorConstructorRefactor() + for constructor in base_constructors: + node = ast_call(constructor) + assert modifier.node_should_be_modified(node) + + +def test_constructor_node_should_be_modified_sparkml(sparkml_constructors): + modifier = predictors.PredictorConstructorRefactor() + for constructor in sparkml_constructors: + node = ast_call(constructor) + assert modifier.node_should_be_modified(node) + + +def test_constructor_node_should_be_modified_other(other_constructors): + modifier = predictors.PredictorConstructorRefactor() + for constructor in other_constructors: + node = ast_call(constructor) + assert modifier.node_should_be_modified(node) + + +def test_constructor_node_should_be_modified_random_call(): + modifier = predictors.PredictorConstructorRefactor() + node = ast_call("Model()") + assert not modifier.node_should_be_modified(node) + + +def 
test_constructor_modify_node(): + modifier = predictors.PredictorConstructorRefactor() + + node = ast_call("sagemaker.RealTimePredictor(endpoint='a')") + modifier.modify_node(node) + assert "sagemaker.Predictor(endpoint_name='a')" == pasta.dump(node) + + node = ast_call("RealTimePredictor(endpoint='a')") + modifier.modify_node(node) + assert "Predictor(endpoint_name='a')" == pasta.dump(node) + + node = ast_call("sagemaker.amazon.kmeans.KMeansPredictor(endpoint='a')") + modifier.modify_node(node) + assert "sagemaker.amazon.kmeans.KMeansPredictor(endpoint_name='a')" == pasta.dump(node) + + node = ast_call("KMeansPredictor(endpoint='a')") + modifier.modify_node(node) + assert "KMeansPredictor(endpoint_name='a')" == pasta.dump(node) + + +def test_import_from_node_should_be_modified_predictor_module(import_statements): + modifier = predictors.PredictorImportFromRenamer() + for statement in import_statements: + node = ast_import(statement) + assert modifier.node_should_be_modified(node) + + +def test_import_from_node_should_be_modified_random_import(): + modifier = predictors.PredictorImportFromRenamer() + node = ast_import("from sagemaker import Session") + assert not modifier.node_should_be_modified(node) + + +def test_import_from_modify_node(): + modifier = predictors.PredictorImportFromRenamer() + + node = ast_import("from sagemaker.predictor import BytesDeserializer, RealTimePredictor") + modifier.modify_node(node) + expected_result = "from sagemaker.predictor import BytesDeserializer, Predictor" + assert expected_result == pasta.dump(node) + + node = ast_import("from sagemaker.predictor import RealTimePredictor as RTP") + modifier.modify_node(node) + expected_result = "from sagemaker.predictor import Predictor as RTP" + assert expected_result == pasta.dump(node)