Skip to content

Commit

Permalink
Add support for unary operators
Browse files Browse the repository at this point in the history
  • Loading branch information
eguiraud committed Nov 3, 2023
1 parent 660781a commit ac10d37
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 11 deletions.
90 changes: 85 additions & 5 deletions src/correctionlib_gradients/_formuladag.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def _eval_node(self, node: FormulaNode, inputs: dict[str, jax.Array]) -> jax.Arr
return jnp.repeat(value, res_size)
case Variable(name):
return inputs[name]
case Op(op=BinaryOp(), children=children):
c1, c2 = children
case Op(op=BinaryOp(), children=(c1, c2)):
ev = self._eval_node
i = inputs
match node.op:
Expand Down Expand Up @@ -127,15 +126,56 @@ def _eval_node(self, node: FormulaNode, inputs: dict[str, jax.Array]) -> jax.Arr
case BinaryOp.ATAN2:
return jnp.arctan2(ev(c1, i), ev(c2, i))
case BinaryOp.MAX:
return jnp.max(jnp.stack([ev(c1, i), ev(c2, i)]))
return jnp.max(jnp.stack([ev(c1, i), ev(c2, i)]), axis=0)
case BinaryOp.MIN:
return jnp.min(jnp.stack([ev(c1, i), ev(c2, i)]))
return jnp.min(jnp.stack([ev(c1, i), ev(c2, i)]), axis=0)
case Op(op=UnaryOp(), children=(child,)):
ev = self._eval_node
match node.op:
case UnaryOp.NEGATIVE:
return -ev(child, inputs)
case UnaryOp.LOG:
return jnp.log(ev(child, inputs))
case UnaryOp.LOG10:
return jnp.log10(ev(child, inputs))
case UnaryOp.EXP:
return jnp.exp(ev(child, inputs))
case UnaryOp.ERF:
return jax.scipy.special.erf(ev(child, inputs))
case UnaryOp.SQRT:
return jnp.sqrt(ev(child, inputs))
case UnaryOp.ABS:
return jnp.abs(ev(child, inputs))
case UnaryOp.COS:
return jnp.cos(ev(child, inputs))
case UnaryOp.SIN:
return jnp.sin(ev(child, inputs))
case UnaryOp.TAN:
return jnp.tan(ev(child, inputs))
case UnaryOp.ACOS:
return jnp.arccos(ev(child, inputs))
case UnaryOp.ASIN:
return jnp.arcsin(ev(child, inputs))
case UnaryOp.ATAN:
return jnp.arctan(ev(child, inputs))
case UnaryOp.COSH:
return jnp.cosh(ev(child, inputs))
case UnaryOp.SINH:
return jnp.sinh(ev(child, inputs))
case UnaryOp.TANH:
return jnp.tanh(ev(child, inputs))
case UnaryOp.ACOSH:
return jnp.arccosh(ev(child, inputs))
case UnaryOp.ASINH:
return jnp.arcsinh(ev(child, inputs))
case UnaryOp.ATANH:
return jnp.arctanh(ev(child, inputs))
case _: # pragma: no cover
msg = f"Type of formula node not recognized ({node}). This should never happen."
raise RuntimeError(msg)

# never reached, only here to make mypy happy
return jax.array() # pragma: no cover
return jnp.array([]) # pragma: no cover

def _make_node(self, ast: FormulaAst) -> FormulaNode:
match ast.nodetype:
Expand Down Expand Up @@ -216,6 +256,46 @@ def _make_node(self, ast: FormulaAst) -> FormulaNode:
op=BinaryOp.MIN,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.NodeType.UNARY:
match ast.data:
case FormulaAst.UnaryOp.NEGATIVE:
return Op(op=UnaryOp.NEGATIVE, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.LOG:
return Op(op=UnaryOp.LOG, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.LOG10:
return Op(op=UnaryOp.LOG10, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.EXP:
return Op(op=UnaryOp.EXP, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.ERF:
return Op(op=UnaryOp.ERF, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.SQRT:
return Op(op=UnaryOp.SQRT, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.ABS:
return Op(op=UnaryOp.ABS, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.COS:
return Op(op=UnaryOp.COS, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.SIN:
return Op(op=UnaryOp.SIN, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.TAN:
return Op(op=UnaryOp.TAN, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.ACOS:
return Op(op=UnaryOp.ACOS, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.ASIN:
return Op(op=UnaryOp.ASIN, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.ATAN:
return Op(op=UnaryOp.ATAN, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.COSH:
return Op(op=UnaryOp.COSH, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.SINH:
return Op(op=UnaryOp.SINH, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.TANH:
return Op(op=UnaryOp.TANH, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.ACOSH:
return Op(op=UnaryOp.ACOSH, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.ASINH:
return Op(op=UnaryOp.ASINH, children=(self._make_node(ast.children[0]),))
case FormulaAst.UnaryOp.ATANH:
return Op(op=UnaryOp.ATANH, children=(self._make_node(ast.children[0]),))
case _: # pragma: no cover
msg = f"Type of formula node not recognized ({ast.nodetype.name}). This should never happen."
raise ValueError(msg)
Expand Down
14 changes: 8 additions & 6 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,12 @@
output=schemav2.Variable(name="a scale", type="real"),
data=schemav2.Formula(
nodetype="formula",
# FIXME add unary ops, add parameters
expression=(
"(x == x) + (x != y) + (x > y) + (x < y) + (x >= y) + (x <= y)"
"- x/y + x*y + pow(x, 2) + atan2(x, y) + max(x, y) + min(x, y)"
"+ (-x) + log(x) + log10(x) + exp(x) + erf(x) + sqrt(x) + abs(x)"
"+ cos(x) + sin(x) + tan(x) + acos(x / y) + asin(x / y) + atan(x) + cosh(x)"
"+ sinh(x) + tanh(x) + acosh(x * y) + asinh(x) + atanh(x / y)"
),
parser="TFormula",
variables=["x", "y"],
Expand Down Expand Up @@ -364,20 +366,20 @@ def test_complex_formula(jit):
evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
value, grads = jax.value_and_grad(evaluate, argnums=[0, 1])(1.0, 2.0)
assert value.shape == ()
assert math.isclose(value, 9.963647609000805)
assert math.isclose(value, 26.047519582032493, abs_tol=1e-6)
assert len(grads) == 2
assert np.allclose(grads, [4.9, 2.05])
assert np.allclose(grads, [19.25876411, 2.29401694])


def test_complex_formula_nojax():
cg = CorrectionWithGradient(schemas["complex-formula"])
value = cg.evaluate(1.0, 2.0)
assert value.shape == ()
assert math.isclose(value, 9.963647609000805)
assert math.isclose(value, 26.047519582032493, abs_tol=1e-6)

values = cg.evaluate([1.0, 2.0], [2.0, 1.0])
values = cg.evaluate([1.0, 2.0], [2.0, 3.0])
assert len(values) == 2
assert np.allclose(values, [9.963647609000805, 12.107149])
assert np.allclose(values, [26.047519582032493, 43.77948741392216])


# TODO this does not work, seemingly because of np.vectorize
Expand Down

0 comments on commit ac10d37

Please sign in to comment.