Skip to content

Commit

Permalink
Add python e2e tests (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
krinart authored and izeigerman committed Jan 25, 2019
1 parent fbdee6d commit 00abbff
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 17 deletions.
6 changes: 3 additions & 3 deletions m2cgen/interpreters/python/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ class PythonCodeGenerator(BaseCodeGenerator):
tpl_var_declaration = CT("")
tpl_block_termination = CT("")

def add_method_def(self, name, args):
def add_function_def(self, name, args):
method_def = "def " + " " + name + "("
method_def += ", ".join(args)
method_def += "):"
self.add_code_line(method_def)
self.increase_indent()

@contextlib.contextmanager
def method_definition(self, name, args):
self.add_method_def(name, args)
def function_definition(self, name, args):
self.add_function_def(name, args)
yield
2 changes: 1 addition & 1 deletion m2cgen/interpreters/python/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, indent=4, *args, **kwargs):
def interpret(self, expr):
self._cg.reset_state()

with self._cg.method_definition(
with self._cg.function_definition(
name="score",
args=[self._feature_array_name]):
last_result = self._do_interpret(expr)
Expand Down
2 changes: 2 additions & 0 deletions tests/e2e/executors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from tests.e2e.executors.java import JavaExecutor
from tests.e2e.executors.python import PythonExecutor

__all__ = [
JavaExecutor,
PythonExecutor,
]
22 changes: 22 additions & 0 deletions tests/e2e/executors/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import contextlib

from tests import utils


class BaseExecutor:

_resource_tmp_dir = None

@contextlib.contextmanager
def prepare_then_cleanup(self):
with utils.tmp_dir() as tmp_dirpath:
self._resource_tmp_dir = tmp_dirpath
self.prepare()

try:
yield
finally:
self._resource_tmp_dir = None

def prepare(self):
raise NotImplementedError
11 changes: 3 additions & 8 deletions tests/e2e/executors/java.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,23 @@
import os
import tempfile
import subprocess
import shutil

from m2cgen import exporters
from tests.e2e.executors import base


class JavaExecutor:
class JavaExecutor(base.BaseExecutor):

def __init__(self, model):
self.model = model
self.exporter = exporters.JavaExporter(model)
self._is_compiled = False
self._resource_tmp_dir = tempfile.mkdtemp()

java_home = os.environ.get("JAVA_HOME")
assert java_home, "JAVA_HOME is not specified"
self._java_bin = os.path.join(java_home, "bin/java")
self._javac_bin = os.path.join(java_home, "bin/javac")

def predict(self, X):
if not self._is_compiled:
self._compile()
self._is_compiled = True

exec_args = [
self._java_bin, "-cp", self._resource_tmp_dir,
Expand All @@ -33,7 +28,7 @@ def predict(self, X):

return float(result.stdout)

def _compile(self):
def prepare(self):
# Create files generated by exporter in the temp dir.
files_to_compile = []
for model_name, code in self.exporter.export():
Expand Down
39 changes: 39 additions & 0 deletions tests/e2e/executors/python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import importlib
import os
import sys

from m2cgen import exporters
from tests.e2e.executors import base


class PythonExecutor(base.BaseExecutor):

def __init__(self, model):
self.model = model
self.exporter = exporters.PythonExporter(model)

def predict(self, X):
# Hacky way to dynamically import generated function

parent_dir = os.path.dirname(self._resource_tmp_dir)
package = os.path.basename(self._resource_tmp_dir)

sys.path.append(parent_dir)

try:
score = importlib.import_module("{}.model".format(package)).score
finally:
sys.path.pop()

return score(X)

def prepare(self):
exported_models = self.exporter.export()
assert len(exported_models) == 1

_, code = exported_models[0]

file_name = os.path.join(self._resource_tmp_dir, "model.py")

with open(file_name, "w") as f:
f.write(code)
26 changes: 21 additions & 5 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ def exec_e2e_test(estimator, executor_cls):
X_test, y_pred_true = utils.train_model(estimator)
executor = executor_cls(estimator)

for idx in range(len(X_test)):
y_pred_executed = executor.predict(X_test[idx])
print("expected={}, actual={}".format(y_pred_true[idx],
y_pred_executed))
assert np.isclose(y_pred_true[idx], y_pred_executed)
with executor.prepare_then_cleanup():
for idx in range(len(X_test)):
y_pred_executed = executor.predict(X_test[idx])
print("expected={}, actual={}".format(y_pred_true[idx],
y_pred_executed))
assert np.isclose(y_pred_true[idx], y_pred_executed)


def test_java_linear():
Expand All @@ -37,3 +38,18 @@ def test_java_ensemble():
estimator = ensemble.RandomForestRegressor(n_estimators=10,
random_state=RANDOM_SEED)
exec_e2e_test(estimator, executors.JavaExecutor)


def test_python_linear():
estimator = linear_model.LinearRegression()
exec_e2e_test(estimator, executors.PythonExecutor)


def test_python_tree():
estimator = tree.DecisionTreeRegressor()
exec_e2e_test(estimator, executors.PythonExecutor)


def test_python_ensemble():
estimator = ensemble.RandomForestRegressor(n_estimators=10)
exec_e2e_test(estimator, executors.PythonExecutor)
14 changes: 14 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import contextlib
import shutil
import tempfile

import numpy as np
from sklearn.datasets import load_boston
from sklearn.utils import shuffle
Expand Down Expand Up @@ -49,3 +53,13 @@ def train_model(estimator, test_fraction=0.1):
y_pred = estimator.predict(X_test)

return X_test, y_pred


@contextlib.contextmanager
def tmp_dir():
dirpath = tempfile.mkdtemp()

try:
yield dirpath
finally:
shutil.rmtree(dirpath)

0 comments on commit 00abbff

Please sign in to comment.