Skip to content

Commit

Permalink
Add tests for python interpreter/code_generator (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
krinart authored and izeigerman committed Jan 24, 2019
1 parent 80dfb48 commit fbdee6d
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 24 deletions.
6 changes: 2 additions & 4 deletions m2cgen/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def __init__(self, model, package_name=None, model_name="Model", indent=4):

class PythonExporter(BaseExporter):

def __init__(self, model, model_name="Model", indent=4):
self.interpreter = interpreters.PythonInterpreter(
model_name=model_name,
indent=indent)
def __init__(self, model, indent=4):
self.interpreter = interpreters.PythonInterpreter(indent=indent)
super(PythonExporter, self).__init__(model)
2 changes: 2 additions & 0 deletions m2cgen/interpreters/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def interpret_feature_ref(self, expr):
array_name=self._feature_array_name,
index=expr.index)

# Private methods implementing visitor pattern

def _do_interpret(self, expr, **kwargs):
try:
handler = self._select_handler(expr)
Expand Down
12 changes: 1 addition & 11 deletions m2cgen/interpreters/python/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,13 @@ class PythonCodeGenerator(BaseCodeGenerator):
tpl_var_declaration = CT("")
tpl_block_termination = CT("")

def add_class_def(self, class_name):
class_def = "class " + class_name + "(object):"
self.add_code_line(class_def)
self.increase_indent()

def add_method_def(self, name, args):
method_def = "def " + " " + name + "(self, "
method_def = "def " + " " + name + "("
method_def += ", ".join(args)
method_def += "):"
self.add_code_line(method_def)
self.increase_indent()

@contextlib.contextmanager
def class_definition(self, model_name):
self.add_class_def(model_name)
yield

@contextlib.contextmanager
def method_definition(self, name, args):
self.add_method_def(name, args)
Expand Down
16 changes: 7 additions & 9 deletions m2cgen/interpreters/python/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@

class PythonInterpreter(BaseInterpreter):

def __init__(self, model_name="Model", indent=4, *args, **kwargs):
self.model_name = model_name
def __init__(self, indent=4, *args, **kwargs):
cg = PythonCodeGenerator(indent=indent)
super(PythonInterpreter, self).__init__(cg, *args, **kwargs)

def interpret(self, expr):
self._cg.reset_state()

with self._cg.class_definition(self.model_name):
with self._cg.method_definition(
name="score",
args=[self._feature_array_name]):
last_result = self._do_interpret(expr)
self._cg.add_return_statement(last_result)
with self._cg.method_definition(
name="score",
args=[self._feature_array_name]):
last_result = self._do_interpret(expr)
self._cg.add_return_statement(last_result)

return [
(self.model_name, self._cg.code),
("", self._cg.code),
]
Empty file added tests/interpreters/__init__.py
Empty file.
115 changes: 115 additions & 0 deletions tests/interpreters/test_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from m2cgen import ast
from m2cgen import interpreters
from tests import utils


def test_if_expr():
expr = ast.IfExpr(
ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ),
ast.NumVal(2),
ast.NumVal(3))

interpreter = interpreters.PythonInterpreter()

expected_code = """
def score(input):
if (1) == (input[0]):
var0 = 2
else:
var0 = 3
return var0
"""

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)


def test_bin_num_expr():
expr = ast.BinNumExpr(
ast.BinNumExpr(
ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV),
ast.NumVal(2),
ast.BinNumOpType.MUL)

interpreter = interpreters.PythonInterpreter()

expected_code = """
def score(input):
return ((input[0]) / (-2)) * (2)
"""

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)


def test_dependable_condition():
left = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.NumVal(1),
ast.NumVal(2)),
ast.NumVal(2),
ast.BinNumOpType.ADD)

right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
bool_test = ast.CompExpr(left, right, ast.CompOpType.GTE)

expr = ast.IfExpr(bool_test, ast.NumVal(1), ast.FeatureRef(0))

expected_code = """
def score(input):
if (1) == (1):
var1 = 1
else:
var1 = 2
if ((var1) + (2)) >= ((1) / (2)):
var0 = 1
else:
var0 = input[0]
return var0
"""

interpreter = interpreters.PythonInterpreter()

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)


def test_nested_condition():
left = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.NumVal(1),
ast.NumVal(2)),
ast.NumVal(2),
ast.BinNumOpType.ADD)

bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ)

expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2))

expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2))

expected_code = """
def score(input):
if (1) == (1):
var1 = 1
else:
var1 = 2
if (1) == ((var1) + (2)):
if (1) == (1):
var2 = 1
else:
var2 = 2
if (1) == ((var2) + (2)):
var0 = input[2]
else:
var0 = 2
else:
var0 = 2
return var0
"""

interpreter = interpreters.PythonInterpreter()
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
4 changes: 4 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def cmp_exprs(left, right):
return False


def assert_code_equal(actual, expected):
assert actual.strip() == expected.strip()


def train_model(estimator, test_fraction=0.1):
boston = load_boston()

Expand Down

0 comments on commit fbdee6d

Please sign in to comment.