diff --git a/m2cgen/interpreters/python/code_generator.py b/m2cgen/interpreters/python/code_generator.py index c81eb79c..28e65458 100644 --- a/m2cgen/interpreters/python/code_generator.py +++ b/m2cgen/interpreters/python/code_generator.py @@ -17,7 +17,7 @@ class PythonCodeGenerator(BaseCodeGenerator): tpl_var_declaration = CT("") tpl_block_termination = CT("") - def add_method_def(self, name, args): + def add_function_def(self, name, args): method_def = "def " + " " + name + "(" method_def += ", ".join(args) method_def += "):" @@ -25,6 +25,6 @@ def add_method_def(self, name, args): self.increase_indent() @contextlib.contextmanager - def method_definition(self, name, args): - self.add_method_def(name, args) + def function_definition(self, name, args): + self.add_function_def(name, args) yield diff --git a/m2cgen/interpreters/python/interpreter.py b/m2cgen/interpreters/python/interpreter.py index d5fb8c7f..56f025b3 100644 --- a/m2cgen/interpreters/python/interpreter.py +++ b/m2cgen/interpreters/python/interpreter.py @@ -11,7 +11,7 @@ def __init__(self, indent=4, *args, **kwargs): def interpret(self, expr): self._cg.reset_state() - with self._cg.method_definition( + with self._cg.function_definition( name="score", args=[self._feature_array_name]): last_result = self._do_interpret(expr) diff --git a/tests/e2e/executors/__init__.py b/tests/e2e/executors/__init__.py index c0ce0e07..de8d7982 100644 --- a/tests/e2e/executors/__init__.py +++ b/tests/e2e/executors/__init__.py @@ -1,5 +1,7 @@ from tests.e2e.executors.java import JavaExecutor +from tests.e2e.executors.python import PythonExecutor __all__ = [ JavaExecutor, + PythonExecutor, ] diff --git a/tests/e2e/executors/base.py b/tests/e2e/executors/base.py new file mode 100644 index 00000000..ce7f7404 --- /dev/null +++ b/tests/e2e/executors/base.py @@ -0,0 +1,22 @@ +import contextlib + +from tests import utils + + +class BaseExecutor: + + _resource_tmp_dir = None + + @contextlib.contextmanager + def prepare_then_cleanup(self): + with utils.tmp_dir() as tmp_dirpath: + self._resource_tmp_dir = tmp_dirpath + self.prepare() + + try: + yield + finally: + self._resource_tmp_dir = None + + def prepare(self): + raise NotImplementedError diff --git a/tests/e2e/executors/java.py b/tests/e2e/executors/java.py index 95bc2ae2..a3df4f5c 100644 --- a/tests/e2e/executors/java.py +++ b/tests/e2e/executors/java.py @@ -1,18 +1,16 @@ import os -import tempfile import subprocess import shutil from m2cgen import exporters +from tests.e2e.executors import base -class JavaExecutor: +class JavaExecutor(base.BaseExecutor): def __init__(self, model): self.model = model self.exporter = exporters.JavaExporter(model) - self._is_compiled = False - self._resource_tmp_dir = tempfile.mkdtemp() java_home = os.environ.get("JAVA_HOME") assert java_home, "JAVA_HOME is not specified" @@ -20,9 +18,6 @@ def __init__(self, model): self._javac_bin = os.path.join(java_home, "bin/javac") def predict(self, X): - if not self._is_compiled: - self._compile() - self._is_compiled = True exec_args = [ self._java_bin, "-cp", self._resource_tmp_dir, @@ -33,7 +28,7 @@ def predict(self, X): return float(result.stdout) - def _compile(self): + def prepare(self): # Create files generated by exporter in the temp dir. files_to_compile = [] for model_name, code in self.exporter.export(): diff --git a/tests/e2e/executors/python.py b/tests/e2e/executors/python.py new file mode 100644 index 00000000..9c1fc445 --- /dev/null +++ b/tests/e2e/executors/python.py @@ -0,0 +1,39 @@ +import importlib +import os +import sys + +from m2cgen import exporters +from tests.e2e.executors import base + + +class PythonExecutor(base.BaseExecutor): + + def __init__(self, model): + self.model = model + self.exporter = exporters.PythonExporter(model) + + def predict(self, X): + # Hacky way to dynamically import generated function + + parent_dir = os.path.dirname(self._resource_tmp_dir) + package = os.path.basename(self._resource_tmp_dir) + + sys.path.append(parent_dir) + + try: + score = importlib.import_module("{}.model".format(package)).score + finally: + sys.path.pop() + + return score(X) + + def prepare(self): + exported_models = self.exporter.export() + assert len(exported_models) == 1 + + _, code = exported_models[0] + + file_name = os.path.join(self._resource_tmp_dir, "model.py") + + with open(file_name, "w") as f: + f.write(code) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 5e1042af..b5119b3c 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -15,11 +15,12 @@ def exec_e2e_test(estimator, executor_cls): X_test, y_pred_true = utils.train_model(estimator) executor = executor_cls(estimator) - for idx in range(len(X_test)): - y_pred_executed = executor.predict(X_test[idx]) - print("expected={}, actual={}".format(y_pred_true[idx], - y_pred_executed)) - assert np.isclose(y_pred_true[idx], y_pred_executed) + with executor.prepare_then_cleanup(): + for idx in range(len(X_test)): + y_pred_executed = executor.predict(X_test[idx]) + print("expected={}, actual={}".format(y_pred_true[idx], + y_pred_executed)) + assert np.isclose(y_pred_true[idx], y_pred_executed) def test_java_linear(): @@ -37,3 +38,18 @@ def test_java_ensemble(): estimator = ensemble.RandomForestRegressor(n_estimators=10, random_state=RANDOM_SEED) exec_e2e_test(estimator, executors.JavaExecutor) + + +def test_python_linear(): + estimator = linear_model.LinearRegression() + exec_e2e_test(estimator, executors.PythonExecutor) + + +def test_python_tree(): + estimator = tree.DecisionTreeRegressor() + exec_e2e_test(estimator, executors.PythonExecutor) + + +def test_python_ensemble(): + estimator = ensemble.RandomForestRegressor(n_estimators=10) + exec_e2e_test(estimator, executors.PythonExecutor) diff --git a/tests/utils.py b/tests/utils.py index c62128e2..4f9da376 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,7 @@ +import contextlib +import shutil +import tempfile + import numpy as np from sklearn.datasets import load_boston from sklearn.utils import shuffle @@ -49,3 +53,13 @@ def train_model(estimator, test_fraction=0.1): y_pred = estimator.predict(X_test) return X_test, y_pred + + +@contextlib.contextmanager +def tmp_dir(): + dirpath = tempfile.mkdtemp() + + try: + yield dirpath + finally: + shutil.rmtree(dirpath)