Skip to content

Commit

Permalink
Add cauchy link function support to the linear model assembler (#304)
Browse files Browse the repository at this point in the history
* added atan expression

* added atan for F#

* fix lint

* added cauchy link function

* hotfix
  • Loading branch information
StrikerRUS authored Sep 11, 2020
1 parent 0c1ea76 commit 91ef2ef
Show file tree
Hide file tree
Showing 38 changed files with 560 additions and 12 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ recursive-include m2cgen VERSION.txt
recursive-include m2cgen linear_algebra.*
recursive-include m2cgen log1p.*
recursive-include m2cgen tanh.*
recursive-include m2cgen atan.*
global-exclude *.py[cod]
122 changes: 116 additions & 6 deletions m2cgen/assemblers/fallback_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,16 @@ def tanh(expr):
tanh_expr))


def sqrt(expr, to_reuse=False):
def sqrt(expr):
return ast.PowExpr(
base_expr=expr,
exp_expr=ast.NumVal(0.5),
to_reuse=to_reuse)
exp_expr=ast.NumVal(0.5))


def exp(expr, to_reuse=False):
def exp(expr):
return ast.PowExpr(
base_expr=ast.NumVal(math.e),
exp_expr=expr,
to_reuse=to_reuse)
exp_expr=expr)


def log1p(expr):
Expand All @@ -66,6 +64,118 @@ def log1p(expr):
utils.div(utils.mul(expr, ast.LogExpr(expr1p)), expr1pm1))


def atan(expr):
expr = ast.IdExpr(expr, to_reuse=True)
expr_abs = ast.AbsExpr(expr, to_reuse=True)

expr_reduced = ast.IdExpr(
ast.IfExpr(
utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)),
utils.div(ast.NumVal(1.0), expr_abs),
ast.IfExpr(
utils.gt(expr_abs, ast.NumVal(0.66)),
utils.div(
utils.sub(expr_abs, ast.NumVal(1.0)),
utils.add(expr_abs, ast.NumVal(1.0))),
expr_abs)),
to_reuse=True)

P0 = ast.NumVal(-8.750608600031904122785e-01)
P1 = ast.NumVal(1.615753718733365076637e+01)
P2 = ast.NumVal(7.500855792314704667340e+01)
P3 = ast.NumVal(1.228866684490136173410e+02)
P4 = ast.NumVal(6.485021904942025371773e+01)
Q0 = ast.NumVal(2.485846490142306297962e+01)
Q1 = ast.NumVal(1.650270098316988542046e+02)
Q2 = ast.NumVal(4.328810604912902668951e+02)
Q3 = ast.NumVal(4.853903996359136964868e+02)
Q4 = ast.NumVal(1.945506571482613964425e+02)
expr2 = utils.mul(expr_reduced, expr_reduced, to_reuse=True)
z = utils.mul(
expr2,
utils.div(
utils.sub(
utils.mul(
expr2,
utils.sub(
utils.mul(
expr2,
utils.sub(
utils.mul(
expr2,
utils.sub(
utils.mul(
expr2,
P0
),
P1
)
),
P2
)
),
P3
)
),
P4
),
utils.add(
Q4,
utils.mul(
expr2,
utils.add(
Q3,
utils.mul(
expr2,
utils.add(
Q2,
utils.mul(
expr2,
utils.add(
Q1,
utils.mul(
expr2,
utils.add(
Q0,
expr2
)
)
)
)
)
)
)
)
)
)
)
z = utils.add(utils.mul(expr_reduced, z), expr_reduced)

ret = utils.mul(
z,
ast.IfExpr(
utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)),
ast.NumVal(-1.0),
ast.NumVal(1.0)))
ret = utils.add(
ret,
ast.IfExpr(
utils.lte(expr_abs, ast.NumVal(0.66)),
ast.NumVal(0.0),
ast.IfExpr(
utils.gt(expr_abs, ast.NumVal(2.4142135623730950488)),
ast.NumVal(1.570796326794896680463661649),
ast.NumVal(0.7853981633974483402318308245))))
ret = utils.mul(
ret,
ast.IfExpr(
utils.lt(expr, ast.NumVal(0.0)),
ast.NumVal(-1.0),
ast.NumVal(1.0)))

return ret


def sigmoid(expr, to_reuse=False):
neg_expr = ast.BinNumExpr(ast.NumVal(0.0), expr, ast.BinNumOpType.SUB)
exp_expr = ast.ExpExpr(neg_expr)
Expand Down
12 changes: 11 additions & 1 deletion m2cgen/assemblers/linear.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import numpy as np

from m2cgen import ast
Expand Down Expand Up @@ -149,6 +151,13 @@ def _negativebinomial_inversed(self, ast_to_transform):
ast.NumVal(-1.0),
utils.mul(ast.NumVal(alpha), res) if alpha != 1.0 else res)

def _cauchy_inversed(self, ast_to_transform):
return utils.add(
ast.NumVal(0.5),
utils.div(
ast.AtanExpr(ast_to_transform),
ast.NumVal(math.pi)))

def _get_power(self):
raise NotImplementedError

Expand All @@ -172,7 +181,8 @@ def _get_supported_inversed_funs(self):
"log": self._log_inversed,
"cloglog": self._cloglog_inversed,
"negativebinomial": self._negativebinomial_inversed,
"nbinom": self._negativebinomial_inversed
"nbinom": self._negativebinomial_inversed,
"cauchy": self._cauchy_inversed
}

def _get_power(self):
Expand Down
20 changes: 19 additions & 1 deletion m2cgen/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,23 @@ def __hash__(self):
return hash(self.expr)


class AtanExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"

self.expr = expr
self.to_reuse = to_reuse

def __str__(self):
return f"AtanExpr({self.expr},to_reuse={self.to_reuse})"

def __eq__(self, other):
return type(other) is AtanExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class ExpExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"
Expand Down Expand Up @@ -370,7 +387,8 @@ def __hash__(self):
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((AbsExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr, SqrtExpr, TanhExpr),
((AbsExpr, AtanExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr,
SqrtExpr, TanhExpr),
lambda e: [e.expr]),
]

Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/c/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class CInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "fabs"
atan_function_name = "atan"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/c_sharp/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class CSharpInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "Abs"
atan_function_name = "Atan"
exponent_function_name = "Exp"
logarithm_function_name = "Log"
log1p_function_name = "Log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/dart/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class DartInterpreter(ImperativeToCodeInterpreter,
bin_depth_threshold = 465

abs_function_name = "abs"
atan_function_name = "atan"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/f_sharp/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class FSharpInterpreter(FunctionalToCodeInterpreter,
}

abs_function_name = "abs"
atan_function_name = "atan"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/go/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class GoInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "math.Abs"
atan_function_name = "math.Atan"
exponent_function_name = "math.Exp"
logarithm_function_name = "math.Log"
log1p_function_name = "math.Log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/haskell/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class HaskellInterpreter(FunctionalToCodeInterpreter,
}

abs_function_name = "abs"
atan_function_name = "atan"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
Expand Down
14 changes: 12 additions & 2 deletions m2cgen/interpreters/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class ToCodeInterpreter(BaseToCodeInterpreter):
"""

abs_function_name = NotImplemented
atan_function_name = NotImplemented
exponent_function_name = NotImplemented
logarithm_function_name = NotImplemented
log1p_function_name = NotImplemented
Expand Down Expand Up @@ -132,10 +133,19 @@ def interpret_abs_expr(self, expr, **kwargs):
return self._cg.function_invocation(
self.abs_function_name, nested_result)

def interpret_atan_expr(self, expr, **kwargs):
if self.atan_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.atan(expr.expr), **kwargs)
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.atan_function_name, nested_result)

def interpret_exp_expr(self, expr, **kwargs):
if self.exponent_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.exp(expr.expr, to_reuse=expr.to_reuse),
fallback_expressions.exp(expr.expr),
**kwargs)
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
Expand All @@ -162,7 +172,7 @@ def interpret_log1p_expr(self, expr, **kwargs):
def interpret_sqrt_expr(self, expr, **kwargs):
if self.sqrt_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.sqrt(expr.expr, to_reuse=expr.to_reuse),
fallback_expressions.sqrt(expr.expr),
**kwargs)
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/java/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class JavaInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "Math.abs"
atan_function_name = "Math.atan"
exponent_function_name = "Math.exp"
logarithm_function_name = "Math.log"
log1p_function_name = "Math.log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/javascript/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class JavascriptInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "Math.abs"
atan_function_name = "Math.atan"
exponent_function_name = "Math.exp"
logarithm_function_name = "Math.log"
log1p_function_name = "Math.log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/php/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PhpInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "abs"
atan_function_name = "atan"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/powershell/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class PowershellInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "[math]::Abs"
atan_function_name = "[math]::Atan"
exponent_function_name = "[math]::Exp"
logarithm_function_name = "[math]::Log"
log1p_function_name = "Log1p"
Expand Down Expand Up @@ -63,6 +64,11 @@ def interpret_abs_expr(self, expr, **kwargs):
return self._cg.math_function_invocation(
self.abs_function_name, nested_result)

def interpret_atan_expr(self, expr, **kwargs):
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.math_function_invocation(
self.atan_function_name, nested_result)

def interpret_exp_expr(self, expr, **kwargs):
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.math_function_invocation(
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/python/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class PythonInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "abs"
atan_function_name = "math.atan"
exponent_function_name = "math.exp"
logarithm_function_name = "math.log"
log1p_function_name = "math.log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/r/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class RInterpreter(ImperativeToCodeInterpreter,
ast_size_per_subroutine_threshold = 200

abs_function_name = "abs"
atan_function_name = "atan"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/ruby/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class RubyInterpreter(ImperativeToCodeInterpreter,
}

abs_function_name = "abs"
atan_function_name = "Math.atan"
exponent_function_name = "Math.exp"
logarithm_function_name = "Math.log"
log1p_function_name = "log1p"
Expand Down
Loading

0 comments on commit 91ef2ef

Please sign in to comment.