From 91ef2ef6abd140bae7921767a7ebef1d52cd2b8f Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Fri, 11 Sep 2020 21:21:07 +0300 Subject: [PATCH] Add cauchy link function support to the linear model assembler (#304) * added atan expression * added atan for F# * fix lint * added cauchy link function * hotfix --- MANIFEST.in | 1 + m2cgen/assemblers/fallback_expressions.py | 122 +++++++++++++++++- m2cgen/assemblers/linear.py | 12 +- m2cgen/ast.py | 20 ++- m2cgen/interpreters/c/interpreter.py | 1 + m2cgen/interpreters/c_sharp/interpreter.py | 1 + m2cgen/interpreters/dart/interpreter.py | 1 + m2cgen/interpreters/f_sharp/interpreter.py | 1 + m2cgen/interpreters/go/interpreter.py | 1 + m2cgen/interpreters/haskell/interpreter.py | 1 + m2cgen/interpreters/interpreter.py | 14 +- m2cgen/interpreters/java/interpreter.py | 1 + m2cgen/interpreters/javascript/interpreter.py | 1 + m2cgen/interpreters/php/interpreter.py | 1 + m2cgen/interpreters/powershell/interpreter.py | 6 + m2cgen/interpreters/python/interpreter.py | 1 + m2cgen/interpreters/r/interpreter.py | 1 + m2cgen/interpreters/ruby/interpreter.py | 1 + m2cgen/interpreters/visual_basic/atan.bas | 44 +++++++ .../interpreters/visual_basic/interpreter.py | 11 ++ tests/assemblers/test_linear.py | 30 +++++ tests/e2e/test_e2e.py | 6 + tests/interpreters/test_c.py | 14 ++ tests/interpreters/test_c_sharp.py | 18 +++ tests/interpreters/test_dart.py | 13 ++ tests/interpreters/test_f_sharp.py | 12 ++ tests/interpreters/test_go.py | 14 ++ tests/interpreters/test_haskell.py | 14 ++ tests/interpreters/test_java.py | 15 +++ tests/interpreters/test_javascript.py | 14 ++ tests/interpreters/test_php.py | 14 ++ tests/interpreters/test_powershell.py | 13 ++ tests/interpreters/test_python.py | 14 ++ tests/interpreters/test_r.py | 13 ++ tests/interpreters/test_ruby.py | 13 ++ tests/interpreters/test_visual_basic.py | 59 +++++++++ tests/test_ast.py | 7 +- tests/test_fallback_expressions.py | 47 +++++++ 38 files changed, 560 insertions(+), 12 deletions(-) create mode 100644 m2cgen/interpreters/visual_basic/atan.bas diff --git a/MANIFEST.in b/MANIFEST.in index 90f00e63..6a9fa8a4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,4 +3,5 @@ recursive-include m2cgen VERSION.txt recursive-include m2cgen linear_algebra.* recursive-include m2cgen log1p.* recursive-include m2cgen tanh.* +recursive-include m2cgen atan.* global-exclude *.py[cod] diff --git a/m2cgen/assemblers/fallback_expressions.py b/m2cgen/assemblers/fallback_expressions.py index cd357a02..e42329ed 100644 --- a/m2cgen/assemblers/fallback_expressions.py +++ b/m2cgen/assemblers/fallback_expressions.py @@ -40,18 +40,16 @@ def tanh(expr): tanh_expr)) -def sqrt(expr, to_reuse=False): +def sqrt(expr): return ast.PowExpr( base_expr=expr, - exp_expr=ast.NumVal(0.5), - to_reuse=to_reuse) + exp_expr=ast.NumVal(0.5)) -def exp(expr, to_reuse=False): +def exp(expr): return ast.PowExpr( base_expr=ast.NumVal(math.e), - exp_expr=expr, - to_reuse=to_reuse) + exp_expr=expr) def log1p(expr): @@ -66,6 +64,118 @@ def log1p(expr): utils.div(utils.mul(expr, ast.LogExpr(expr1p)), expr1pm1)) +def atan(expr): + expr = ast.IdExpr(expr, to_reuse=True) + expr_abs = ast.AbsExpr(expr, to_reuse=True) + + expr_reduced = ast.IdExpr( + ast.IfExpr( + utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)), + utils.div(ast.NumVal(1.0), expr_abs), + ast.IfExpr( + utils.gt(expr_abs, ast.NumVal(0.66)), + utils.div( + utils.sub(expr_abs, ast.NumVal(1.0)), + utils.add(expr_abs, ast.NumVal(1.0))), + expr_abs)), + to_reuse=True) + + P0 = ast.NumVal(-8.750608600031904122785e-01) + P1 = ast.NumVal(1.615753718733365076637e+01) + P2 = ast.NumVal(7.500855792314704667340e+01) + P3 = ast.NumVal(1.228866684490136173410e+02) + P4 = ast.NumVal(6.485021904942025371773e+01) + Q0 = ast.NumVal(2.485846490142306297962e+01) + Q1 = ast.NumVal(1.650270098316988542046e+02) + Q2 = ast.NumVal(4.328810604912902668951e+02) + Q3 = ast.NumVal(4.853903996359136964868e+02) + Q4 = ast.NumVal(1.945506571482613964425e+02) + expr2 = utils.mul(expr_reduced, expr_reduced, to_reuse=True) + z = utils.mul( + expr2, + utils.div( + utils.sub( + utils.mul( + expr2, + utils.sub( + utils.mul( + expr2, + utils.sub( + utils.mul( + expr2, + utils.sub( + utils.mul( + expr2, + P0 + ), + P1 + ) + ), + P2 + ) + ), + P3 + ) + ), + P4 + ), + utils.add( + Q4, + utils.mul( + expr2, + utils.add( + Q3, + utils.mul( + expr2, + utils.add( + Q2, + utils.mul( + expr2, + utils.add( + Q1, + utils.mul( + expr2, + utils.add( + Q0, + expr2 + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + z = utils.add(utils.mul(expr_reduced, z), expr_reduced) + + ret = utils.mul( + z, + ast.IfExpr( + utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)), + ast.NumVal(-1.0), + ast.NumVal(1.0))) + ret = utils.add( + ret, + ast.IfExpr( + utils.lte(expr_abs, ast.NumVal(0.66)), + ast.NumVal(0.0), + ast.IfExpr( + utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)), + ast.NumVal(1.570796326794896680463661649), + ast.NumVal(0.7853981633974483402318308245)))) + ret = utils.mul( + ret, + ast.IfExpr( + utils.lt(expr, ast.NumVal(0.0)), + ast.NumVal(-1.0), + ast.NumVal(1.0))) + + return ret + + def sigmoid(expr, to_reuse=False): neg_expr = ast.BinNumExpr(ast.NumVal(0.0), expr, ast.BinNumOpType.SUB) exp_expr = ast.ExpExpr(neg_expr) diff --git a/m2cgen/assemblers/linear.py b/m2cgen/assemblers/linear.py index 90b244a5..67c5caf7 100644 --- a/m2cgen/assemblers/linear.py +++ b/m2cgen/assemblers/linear.py @@ -1,3 +1,5 @@ +import math + import numpy as np from m2cgen import ast @@ -149,6 +151,13 @@ def _negativebinomial_inversed(self, ast_to_transform): ast.NumVal(-1.0), utils.mul(ast.NumVal(alpha), res) if alpha != 1.0 else res) + def _cauchy_inversed(self, ast_to_transform): + return utils.add( + ast.NumVal(0.5), + utils.div( + ast.AtanExpr(ast_to_transform), + ast.NumVal(math.pi))) + def _get_power(self): raise NotImplementedError @@ -172,7 +181,8 @@ def _get_supported_inversed_funs(self): "log": self._log_inversed, "cloglog": self._cloglog_inversed, "negativebinomial": self._negativebinomial_inversed, - "nbinom": self._negativebinomial_inversed + "nbinom": self._negativebinomial_inversed, + "cauchy": self._cauchy_inversed } def _get_power(self): diff --git a/m2cgen/ast.py b/m2cgen/ast.py index e467c343..4c167677 100644 --- a/m2cgen/ast.py +++ b/m2cgen/ast.py @@ -86,6 +86,23 @@ def __hash__(self): return hash(self.expr) +class AtanExpr(NumExpr): + def __init__(self, expr, to_reuse=False): + assert expr.output_size == 1, "Only scalars are supported" + + self.expr = expr + self.to_reuse = to_reuse + + def __str__(self): + return f"AtanExpr({self.expr},to_reuse={self.to_reuse})" + + def __eq__(self, other): + return type(other) is AtanExpr and self.expr == other.expr + + def __hash__(self): + return hash(self.expr) + + class ExpExpr(NumExpr): def __init__(self, expr, to_reuse=False): assert expr.output_size == 1, "Only scalars are supported" @@ -370,7 +387,8 @@ def __hash__(self): (PowExpr, lambda e: [e.base_expr, e.exp_expr]), (VectorVal, lambda e: e.exprs), (IfExpr, lambda e: [e.test, e.body, e.orelse]), - ((AbsExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr, SqrtExpr, TanhExpr), + ((AbsExpr, AtanExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr, + SqrtExpr, TanhExpr), lambda e: [e.expr]), ] diff --git a/m2cgen/interpreters/c/interpreter.py b/m2cgen/interpreters/c/interpreter.py index 2af076e2..a8e55c8a 100644 --- a/m2cgen/interpreters/c/interpreter.py +++ b/m2cgen/interpreters/c/interpreter.py @@ -18,6 +18,7 @@ class CInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "fabs" + atan_function_name = "atan" exponent_function_name = "exp" logarithm_function_name = "log" log1p_function_name = "log1p" diff --git a/m2cgen/interpreters/c_sharp/interpreter.py b/m2cgen/interpreters/c_sharp/interpreter.py index 4845c0a1..7a8b721c 100644 --- a/m2cgen/interpreters/c_sharp/interpreter.py +++ b/m2cgen/interpreters/c_sharp/interpreter.py @@ -19,6 +19,7 @@ class CSharpInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "Abs" + atan_function_name = "Atan" exponent_function_name = "Exp" logarithm_function_name = "Log" log1p_function_name = "Log1p" diff --git a/m2cgen/interpreters/dart/interpreter.py b/m2cgen/interpreters/dart/interpreter.py index 2553571c..1ed22067 100644 --- a/m2cgen/interpreters/dart/interpreter.py +++ b/m2cgen/interpreters/dart/interpreter.py @@ -22,6 +22,7 @@ class DartInterpreter(ImperativeToCodeInterpreter, bin_depth_threshold = 465 abs_function_name = "abs" + atan_function_name = "atan" exponent_function_name = "exp" logarithm_function_name = "log" log1p_function_name = "log1p" diff --git a/m2cgen/interpreters/f_sharp/interpreter.py b/m2cgen/interpreters/f_sharp/interpreter.py index e35ae1b5..89300078 100644 --- a/m2cgen/interpreters/f_sharp/interpreter.py +++ b/m2cgen/interpreters/f_sharp/interpreter.py @@ -26,6 +26,7 @@ class FSharpInterpreter(FunctionalToCodeInterpreter, } abs_function_name = "abs" + atan_function_name = "atan" exponent_function_name = "exp" logarithm_function_name = "log" log1p_function_name = "log1p" diff --git a/m2cgen/interpreters/go/interpreter.py b/m2cgen/interpreters/go/interpreter.py index 01620634..a325d4db 100644 --- a/m2cgen/interpreters/go/interpreter.py +++ b/m2cgen/interpreters/go/interpreter.py @@ -17,6 +17,7 @@ class GoInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "math.Abs" + atan_function_name = "math.Atan" exponent_function_name = "math.Exp" logarithm_function_name = "math.Log" log1p_function_name = "math.Log1p" diff --git a/m2cgen/interpreters/haskell/interpreter.py b/m2cgen/interpreters/haskell/interpreter.py index 1df162bf..6a435418 100644 --- a/m2cgen/interpreters/haskell/interpreter.py +++ b/m2cgen/interpreters/haskell/interpreter.py @@ -17,6 +17,7 @@ class HaskellInterpreter(FunctionalToCodeInterpreter, } abs_function_name = "abs" + atan_function_name = "atan" exponent_function_name = "exp" logarithm_function_name = "log" log1p_function_name = "log1p" diff --git a/m2cgen/interpreters/interpreter.py b/m2cgen/interpreters/interpreter.py index 51c5bc65..fd684673 100644 --- a/m2cgen/interpreters/interpreter.py +++ b/m2cgen/interpreters/interpreter.py @@ -82,6 +82,7 @@ class ToCodeInterpreter(BaseToCodeInterpreter): """ abs_function_name = NotImplemented + atan_function_name = NotImplemented exponent_function_name = NotImplemented logarithm_function_name = NotImplemented log1p_function_name = NotImplemented @@ -132,10 +133,19 @@ def interpret_abs_expr(self, expr, **kwargs): return self._cg.function_invocation( self.abs_function_name, nested_result) + def interpret_atan_expr(self, expr, **kwargs): + if self.atan_function_name is NotImplemented: + return self._do_interpret( + fallback_expressions.atan(expr.expr), **kwargs) + self.with_math_module = True + nested_result = self._do_interpret(expr.expr, **kwargs) + return self._cg.function_invocation( + self.atan_function_name, nested_result) + def interpret_exp_expr(self, expr, **kwargs): if self.exponent_function_name is NotImplemented: return self._do_interpret( - fallback_expressions.exp(expr.expr, to_reuse=expr.to_reuse), + fallback_expressions.exp(expr.expr), **kwargs) self.with_math_module = True nested_result = self._do_interpret(expr.expr, **kwargs) @@ -162,7 +172,7 @@ def interpret_log1p_expr(self, expr, **kwargs): def interpret_sqrt_expr(self, expr, **kwargs): if self.sqrt_function_name is NotImplemented: return self._do_interpret( - fallback_expressions.sqrt(expr.expr, to_reuse=expr.to_reuse), + fallback_expressions.sqrt(expr.expr), **kwargs) self.with_math_module = True nested_result = self._do_interpret(expr.expr, **kwargs) diff --git a/m2cgen/interpreters/java/interpreter.py b/m2cgen/interpreters/java/interpreter.py index 1951a42d..c67532dd 100644 --- a/m2cgen/interpreters/java/interpreter.py +++ b/m2cgen/interpreters/java/interpreter.py @@ -25,6 +25,7 @@ class JavaInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "Math.abs" + atan_function_name = "Math.atan" exponent_function_name = "Math.exp" logarithm_function_name = "Math.log" log1p_function_name = "Math.log1p" diff --git a/m2cgen/interpreters/javascript/interpreter.py b/m2cgen/interpreters/javascript/interpreter.py index 53d97937..4ac4afcd 100644 --- a/m2cgen/interpreters/javascript/interpreter.py +++ b/m2cgen/interpreters/javascript/interpreter.py @@ -20,6 +20,7 @@ class JavascriptInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "Math.abs" + atan_function_name = "Math.atan" exponent_function_name = "Math.exp" logarithm_function_name = "Math.log" log1p_function_name = "Math.log1p" diff --git a/m2cgen/interpreters/php/interpreter.py b/m2cgen/interpreters/php/interpreter.py index f4065d85..d25fa5a6 100644 --- a/m2cgen/interpreters/php/interpreter.py +++ b/m2cgen/interpreters/php/interpreter.py @@ -18,6 +18,7 @@ class PhpInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "abs" + atan_function_name = "atan" exponent_function_name = "exp" logarithm_function_name = "log" log1p_function_name = "log1p" diff --git a/m2cgen/interpreters/powershell/interpreter.py b/m2cgen/interpreters/powershell/interpreter.py index d2c9298b..4fb63dbc 100644 --- a/m2cgen/interpreters/powershell/interpreter.py +++ b/m2cgen/interpreters/powershell/interpreter.py @@ -19,6 +19,7 @@ class PowershellInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "[math]::Abs" + atan_function_name = "[math]::Atan" exponent_function_name = "[math]::Exp" logarithm_function_name = "[math]::Log" log1p_function_name = "Log1p" @@ -63,6 +64,11 @@ def interpret_abs_expr(self, expr, **kwargs): return self._cg.math_function_invocation( self.abs_function_name, nested_result) + def interpret_atan_expr(self, expr, **kwargs): + nested_result = self._do_interpret(expr.expr, **kwargs) + return self._cg.math_function_invocation( + self.atan_function_name, nested_result) + def interpret_exp_expr(self, expr, **kwargs): nested_result = self._do_interpret(expr.expr, **kwargs) return self._cg.math_function_invocation( diff --git a/m2cgen/interpreters/python/interpreter.py b/m2cgen/interpreters/python/interpreter.py index 859ab8d6..72c283e7 100644 --- a/m2cgen/interpreters/python/interpreter.py +++ b/m2cgen/interpreters/python/interpreter.py @@ -22,6 +22,7 @@ class PythonInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "abs" + atan_function_name = "math.atan" exponent_function_name = "math.exp" logarithm_function_name = "math.log" log1p_function_name = "math.log1p" diff --git a/m2cgen/interpreters/r/interpreter.py b/m2cgen/interpreters/r/interpreter.py index c82fdf2b..5c70626a 100644 --- a/m2cgen/interpreters/r/interpreter.py +++ b/m2cgen/interpreters/r/interpreter.py @@ -23,6 +23,7 @@ class RInterpreter(ImperativeToCodeInterpreter, ast_size_per_subroutine_threshold = 200 abs_function_name = "abs" + atan_function_name = "atan" exponent_function_name = "exp" logarithm_function_name = "log" log1p_function_name = "log1p" diff --git a/m2cgen/interpreters/ruby/interpreter.py b/m2cgen/interpreters/ruby/interpreter.py index e0dd3964..e5e7e786 100644 --- a/m2cgen/interpreters/ruby/interpreter.py +++ b/m2cgen/interpreters/ruby/interpreter.py @@ -18,6 +18,7 @@ class RubyInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "abs" + atan_function_name = "Math.atan" exponent_function_name = "Math.exp" logarithm_function_name = "Math.log" log1p_function_name = "log1p" diff --git a/m2cgen/interpreters/visual_basic/atan.bas b/m2cgen/interpreters/visual_basic/atan.bas new file mode 100644 index 00000000..ee708a62 --- /dev/null +++ b/m2cgen/interpreters/visual_basic/atan.bas @@ -0,0 +1,44 @@ +Function Xatan(ByVal x As Double) As Double + Dim z As Double + z = x * x + z = z * ((((-8.750608600031904122785e-01 * z _ + - 1.615753718733365076637e+01) * z _ + - 7.500855792314704667340e+01) * z _ + - 1.228866684490136173410e+02) * z _ + - 6.485021904942025371773e+01) _ + / (((((z + 2.485846490142306297962e+01) * z _ + + 1.650270098316988542046e+02) * z _ + + 4.328810604912902668951e+02) * z _ + + 4.853903996359136964868e+02) * z _ + + 1.945506571482613964425e+02) + Xatan = x * z + x +End Function +Function Satan(ByVal x As Double) As Double + Dim morebits as Double + Dim tan3pio8 as Double + morebits = 6.123233995736765886130e-17 + tan3pio8 = 2.41421356237309504880 + If x <= 0.66 Then + Satan = Xatan(x) + Exit Function + End If + If x > tan3pio8 Then + Satan = 1.57079632679489661923132169163 - Xatan(1.0 / x) + morebits + Exit Function + End If + Satan = 0.78539816339744830961566084581 + Xatan((x - 1) / (x + 1)) _ + + 3.061616997868382943065e-17 +End Function +Function Atan(ByVal number As Double) As Double + ' Implementation is taken from + ' https://github.com/golang/go/blob/master/src/math/atan.go + If number = 0.0 Then + Atan = 0.0 + Exit Function + End If + If number > 0.0 Then + Atan = Satan(number) + Exit Function + End If + Atan = -Satan(-number) +End Function diff --git a/m2cgen/interpreters/visual_basic/interpreter.py b/m2cgen/interpreters/visual_basic/interpreter.py index bf91c7da..6e32552f 100644 --- a/m2cgen/interpreters/visual_basic/interpreter.py +++ b/m2cgen/interpreters/visual_basic/interpreter.py @@ -18,11 +18,13 @@ class VisualBasicInterpreter(ImperativeToCodeInterpreter, } abs_function_name = "Math.Abs" + atan_function_name = "Atan" exponent_function_name = "Math.Exp" logarithm_function_name = "Math.Log" log1p_function_name = "Log1p" tanh_function_name = "Tanh" + with_atan_expr = False with_log1p_expr = False with_tanh_expr = False @@ -64,6 +66,11 @@ def interpret(self, expr): os.path.dirname(__file__), "log1p.bas") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_atan_expr: + filename = os.path.join( + os.path.dirname(__file__), "atan.bas") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + self._cg.prepend_code_line(self._cg.tpl_module_definition( module_name=self.module_name)) self._cg.add_code_line(self._cg.tpl_block_termination( @@ -84,3 +91,7 @@ def interpret_log1p_expr(self, expr, **kwargs): def interpret_tanh_expr(self, expr, **kwargs): self.with_tanh_expr = True return super().interpret_tanh_expr(expr, **kwargs) + + def interpret_atan_expr(self, expr, **kwargs): + self.with_atan_expr = True + return super().interpret_atan_expr(expr, **kwargs) diff --git a/tests/assemblers/test_linear.py b/tests/assemblers/test_linear.py index fbe3a290..23337c2d 100644 --- a/tests/assemblers/test_linear.py +++ b/tests/assemblers/test_linear.py @@ -667,6 +667,36 @@ def test_statsmodels_glm_negativebinomial_link_func(): assert utils.cmp_exprs(actual, expected) +def test_statsmodels_glm_cauchy_link_func(): + estimator = utils.StatsmodelsSklearnLikeWrapper( + sm.GLM, + dict(init=dict( + family=sm.families.Binomial( + sm.families.links.cauchy())), + fit=dict(maxiter=1))) + estimator = estimator.fit([[1], [2]], [0.1, 0.2]) + + assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator) + actual = assembler.assemble() + + expected = ast.BinNumExpr( + ast.NumVal(0.5), + ast.BinNumExpr( + ast.AtanExpr( + ast.BinNumExpr( + ast.NumVal(0.0), + ast.BinNumExpr( + ast.FeatureRef(0), + ast.NumVal(-0.7279996905393095), + ast.BinNumOpType.MUL), + ast.BinNumOpType.ADD)), + ast.NumVal(3.141592653589793), + ast.BinNumOpType.DIV), + ast.BinNumOpType.ADD) + + assert utils.cmp_exprs(actual, expected) + + @pytest.mark.xfail(raises=ValueError, strict=True) def test_statsmodels_glm_unknown_link_func(): diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 35aa6d2e..64762ae6 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -383,6 +383,12 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02): classification_binary(utils.StatsmodelsSklearnLikeWrapper( sm.GLM, dict(fit_regularized=STATSMODELS_LINEAR_REGULARIZED_PARAMS))), + classification_binary(utils.StatsmodelsSklearnLikeWrapper( + sm.GLM, + dict(init=dict( + family=sm.families.Binomial( + sm.families.links.cauchy())), + fit=dict(maxiter=2)))), classification_binary(utils.StatsmodelsSklearnLikeWrapper( sm.GLM, dict(init=dict( diff --git a/tests/interpreters/test_c.py b/tests/interpreters/test_c.py index 909f1288..39d283d6 100644 --- a/tests/interpreters/test_c.py +++ b/tests/interpreters/test_c.py @@ -310,6 +310,20 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + interpreter = interpreters.CInterpreter() + + expected_code = """ +#include +double score(double * input) { + return atan(2.0); +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_c_sharp.py b/tests/interpreters/test_c_sharp.py index 357f3e42..77873dac 100644 --- a/tests/interpreters/test_c_sharp.py +++ b/tests/interpreters/test_c_sharp.py @@ -471,6 +471,24 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + expected_code = """ +using static System.Math; +namespace ML { + public static class Model { + public static double Score(double[] input) { + return Atan(2.0); + } + } +} +""" + + interpreter = CSharpInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_dart.py b/tests/interpreters/test_dart.py index d776f7c1..f5643fbe 100644 --- a/tests/interpreters/test_dart.py +++ b/tests/interpreters/test_dart.py @@ -553,6 +553,19 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + expected_code = """ +import 'dart:math'; +double score(List input) { + return atan(2.0); +} +""" + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_f_sharp.py b/tests/interpreters/test_f_sharp.py index 7858abdf..b8f2c595 100644 --- a/tests/interpreters/test_f_sharp.py +++ b/tests/interpreters/test_f_sharp.py @@ -443,6 +443,18 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + expected_code = """ +let score (input : double list) = + atan (2.0) +""" + + interpreter = FSharpInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_go.py b/tests/interpreters/test_go.py index 4ba05f12..c5ad342e 100644 --- a/tests/interpreters/test_go.py +++ b/tests/interpreters/test_go.py @@ -313,6 +313,20 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + interpreter = interpreters.GoInterpreter() + + expected_code = """ +import "math" +func score(input []float64) float64 { + return math.Atan(2.0) +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_haskell.py b/tests/interpreters/test_haskell.py index cb71f703..490bb5d6 100644 --- a/tests/interpreters/test_haskell.py +++ b/tests/interpreters/test_haskell.py @@ -347,6 +347,20 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + expected_code = """ +module Model where +score :: [Double] -> Double +score input = + atan (2.0) +""" + + interpreter = HaskellInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_java.py b/tests/interpreters/test_java.py index 556aacd5..2ebf0afe 100644 --- a/tests/interpreters/test_java.py +++ b/tests/interpreters/test_java.py @@ -542,6 +542,21 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + interpreter = interpreters.JavaInterpreter() + + expected_code = """ +public class Model { + public static double score(double[] input) { + return Math.atan(2.0); + } +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_javascript.py b/tests/interpreters/test_javascript.py index d3adee89..5dfc3861 100644 --- a/tests/interpreters/test_javascript.py +++ b/tests/interpreters/test_javascript.py @@ -328,6 +328,20 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + interpreter = interpreters.JavascriptInterpreter() + + expected_code = """ +function score(input) { + return Math.atan(2.0); +} +""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_php.py b/tests/interpreters/test_php.py index 737162fc..d9b00ed1 100644 --- a/tests/interpreters/test_php.py +++ b/tests/interpreters/test_php.py @@ -333,6 +333,20 @@ def test_log1p_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_atan_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + expected_code = """ + tan3pio8 Then + Satan = 1.57079632679489661923132169163 - Xatan(1.0 / x) + morebits + Exit Function + End If + Satan = 0.78539816339744830961566084581 + Xatan((x - 1) / (x + 1)) _ + + 3.061616997868382943065e-17 +End Function +Function Atan(ByVal number As Double) As Double + ' Implementation is taken from + ' https://github.com/golang/go/blob/master/src/math/atan.go + If number = 0.0 Then + Atan = 0.0 + Exit Function + End If + If number > 0.0 Then + Atan = Satan(number) + Exit Function + End If + Atan = -Satan(-number) +End Function +Function Score(ByRef inputVector() As Double) As Double + Score = Atan(2.0) +End Function +End Module +""" + + interpreter = VisualBasicInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/test_ast.py b/tests/test_ast.py index d2cb72b9..cf37911a 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -47,6 +47,7 @@ def test_count_exprs_exclude_list(): ast.BinVectorExpr( ast.VectorVal([ ast.AbsExpr(ast.NumVal(-2)), + ast.AtanExpr(ast.NumVal(2)), ast.ExpExpr(ast.NumVal(2)), ast.LogExpr(ast.NumVal(2)), ast.Log1pExpr(ast.NumVal(2)), @@ -67,6 +68,7 @@ def test_count_exprs_exclude_list(): ast.NumVal(5), ast.NumVal(6), ast.NumVal(7), + ast.NumVal(8), ast.FeatureRef(1) ])), ast.BinNumOpType.SUB), @@ -79,7 +81,7 @@ def test_count_exprs_exclude_list(): def test_count_all_exprs_types(): - assert ast.count_exprs(EXPR_WITH_ALL_EXPRS) == 37 + assert ast.count_exprs(EXPR_WITH_ALL_EXPRS) == 40 def test_exprs_equality(): @@ -96,6 +98,7 @@ def test_exprs_str(): assert str(EXPR_WITH_ALL_EXPRS) == """ BinVectorNumExpr(BinVectorExpr(VectorVal([ AbsExpr(NumVal(-2.0),to_reuse=False), +AtanExpr(NumVal(2.0),to_reuse=False), ExpExpr(NumVal(2.0),to_reuse=False), LogExpr(NumVal(2.0),to_reuse=False), Log1pExpr(NumVal(2.0),to_reuse=False), @@ -105,7 +108,7 @@ def test_exprs_str(): BinNumExpr(NumVal(0.0),FeatureRef(0),ADD,to_reuse=False)]), IdExpr(VectorVal([ NumVal(1.0),NumVal(2.0),NumVal(3.0),NumVal(4.0),NumVal(5.0), -NumVal(6.0),NumVal(7.0),FeatureRef(1)]),to_reuse=False),SUB), +NumVal(6.0),NumVal(7.0),NumVal(8.0),FeatureRef(1)]),to_reuse=False),SUB), IfExpr(CompExpr(NumVal(2.0),NumVal(0.0),GT),NumVal(3.0),NumVal(4.0)),MUL) """.strip().replace("\n", "") diff --git a/tests/test_fallback_expressions.py b/tests/test_fallback_expressions.py index 1163fdfb..8b58c85b 100644 --- a/tests/test_fallback_expressions.py +++ b/tests/test_fallback_expressions.py @@ -100,3 +100,50 @@ def score(input): """ assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_atan_fallback_expr(): + expr = ast.AtanExpr(ast.NumVal(2.0)) + + interpreter = PythonInterpreter() + interpreter.atan_function_name = NotImplemented + + expected_code = ( + """ +def score(input): + var1 = 2.0 + var2 = abs(var1) + if (var2) > (2.414213562373095): + var0 = (1.0) / (var2) + else: + if (var2) > (0.66): + var0 = ((var2) - (1.0)) / ((var2) + (1.0)) + else: + var0 = var2 + var3 = var0 + var4 = (var3) * (var3) + if (var2) > (2.414213562373095): + var5 = -1.0 + else: + var5 = 1.0 + if (var2) <= (0.66): + var6 = 0.0 + else: + if (var2) > (2.414213562373095): + var6 = 1.5707963267948968 + else: + var6 = 0.7853981633974484 + if (var1) < (0.0): + var7 = -1.0 + else: + var7 = 1.0 + return (((((var3) * ((var4) * ((((var4) * (((var4) * (((var4) * """ + """(((var4) * (-0.8750608600031904)) - (16.157537187333652))) - """ + """(75.00855792314705))) - (122.88666844901361))) - """ + """(64.85021904942025)) / ((194.5506571482614) + ((var4) * """ + """((485.3903996359137) + ((var4) * ((432.88106049129027) + """ + """((var4) * ((165.02700983169885) + ((var4) * """ + """((24.858464901423062) + (var4))))))))))))) + (var3)) * """ + """(var5)) + (var6)) * (var7)""") + + assert_code_equal(interpreter.interpret(expr), expected_code)