Skip to content

Commit

Permalink
[WIP] Add support for Formula corrections + tests
Browse files Browse the repository at this point in the history
To do:
  - add support for parameters, unary ops
  - reduce code duplication
  • Loading branch information
eguiraud committed Oct 23, 2023
1 parent 3d3cbc1 commit 1b33298
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/correctionlib_gradients/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@
import numpy as np
from scipy.interpolate import CubicSpline # type: ignore[import-not-found]

# TODO: switch to use numpy.array_api.Array as _the_ array type.
# Must wait for it to be out of experimental.
# See https://numpy.org/doc/stable/reference/array_api.html.
Array: TypeAlias = np.ndarray | jax.Array
Value: TypeAlias = float | Array
from correctionlib_gradients._formuladag import FormulaDAG
from correctionlib_gradients._typedefs import Array, Value


def midpoints(x: Array) -> Array:
Expand Down Expand Up @@ -46,7 +43,7 @@ def eval_spline_bwd(res, g): # type: ignore[no-untyped-def]
return cast(Callable[[Value], Value], eval_spline)


DAGNode: TypeAlias = float | schema.Binning
DAGNode: TypeAlias = float | schema.Binning | FormulaDAG


@dataclass
Expand Down Expand Up @@ -81,6 +78,8 @@ def __init__(self, c: schema.Correction):
flow = cast(str, flow) # type: ignore[has-type]
msg = f"Correction '{c.name}' contains a Binning correction with `{flow=}`. Only 'clamp' is supported."
raise ValueError(msg)
case schema.Formula() as f:
self.node = FormulaDAG(f, c.inputs)
case _:
msg = f"Correction '{c.name}' contains the unsupported operation type '{type(c.data).__name__}'"
raise ValueError(msg)
Expand All @@ -101,6 +100,8 @@ def evaluate(self, inputs: dict[str, Value]) -> Value:
xs = np.array(edges)
s = make_differentiable_spline(xs, values)
return s(inputs[var])
case FormulaDAG() as f:
return f.evaluate(inputs)
case _: # pragma: no cover
msg = "Unsupported type of node in the computation graph. This should never happen."
raise RuntimeError(msg)
Expand Down
209 changes: 209 additions & 0 deletions src/correctionlib_gradients/_formuladag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import TypeAlias

import correctionlib.schemav2 as schema
import jax.numpy as jnp # atan2
from correctionlib._core import Formula, FormulaAst
from correctionlib._core import Variable as CPPVariable

from correctionlib_gradients._typedefs import Value


@dataclass
class Literal:
value: float


@dataclass
class Variable:
name: str


@dataclass
class Parameter:
idx: int


class BinaryOp(Enum):
    """Two-operand operations that can appear in a formula tree.

    The member names deliberately match the names used by
    ``FormulaAst.BinaryOp`` in correctionlib.
    """

    # Comparisons.
    EQUAL = auto()
    NOTEQUAL = auto()
    GREATER = auto()
    LESS = auto()
    GREATEREQ = auto()
    LESSEQ = auto()

    # Arithmetic.
    MINUS = auto()
    PLUS = auto()
    DIV = auto()
    TIMES = auto()
    POW = auto()

    # Two-argument functions.
    ATAN2 = auto()
    MAX = auto()
    MIN = auto()


class UnaryOp(Enum):
    """Single-operand operations that can appear in a formula tree.

    NOTE(review): these are declared for completeness; nothing in this module
    builds or evaluates unary operations yet (support is listed as a to-do).
    """

    # Sign change.
    NEGATIVE = auto()

    # Logarithms, exponential and other elementary functions.
    LOG = auto()
    LOG10 = auto()
    EXP = auto()
    ERF = auto()
    SQRT = auto()
    ABS = auto()

    # Trigonometric functions and their inverses.
    COS = auto()
    SIN = auto()
    TAN = auto()
    ACOS = auto()
    ASIN = auto()
    ATAN = auto()

    # Hyperbolic functions and their inverses.
    COSH = auto()
    SINH = auto()
    TANH = auto()
    ACOSH = auto()
    ASINH = auto()
    ATANH = auto()


FormulaNode: TypeAlias = Literal | Variable | Parameter | BinaryOp | UnaryOp


@dataclass
class Op:
    """Inner node of a formula tree: an operator applied to its child nodes."""

    # Which operation this node performs.
    op: BinaryOp | UnaryOp
    # Operands, in evaluation order. Variable-length tuple: binary operations
    # store two children, so a fixed 1-tuple annotation would be wrong.
    children: tuple[FormulaNode, ...]


class FormulaDAG:
def __init__(self, f: schema.Formula, inputs: list[schema.Variable]):
cpp_formula = Formula.from_string(f.json(), [CPPVariable.from_string(v.json()) for v in inputs])
self.input_names = [v.name for v in inputs]
self.node: FormulaNode = self._make_node(cpp_formula.ast)

def evaluate(self, inputs: dict[str, Value]) -> Value:
self.inputs = inputs
res = self._eval_node(self.node)
del self.inputs
return res

def _eval_node(self, node: FormulaNode) -> Value:
match node:
case Literal(value):
return value
case Variable(name):
return self.inputs[name]
case Op(op=op, children=children):
match op:
case BinaryOp.EQUAL:
return self._eval_node(children[0]) == self._eval_node(children[1])
case BinaryOp.NOTEQUAL:
return self._eval_node(children[0]) != self._eval_node(children[1])
case BinaryOp.GREATER:
return self._eval_node(children[0]) > self._eval_node(children[1])
case BinaryOp.LESS:
return self._eval_node(children[0]) < self._eval_node(children[1])
case BinaryOp.GREATEREQ:
return self._eval_node(children[0]) >= self._eval_node(children[1])
case BinaryOp.LESSEQ:
return self._eval_node(children[0]) <= self._eval_node(children[1])
case BinaryOp.MINUS:
return self._eval_node(children[0]) - self._eval_node(children[1])
case BinaryOp.PLUS:
return self._eval_node(children[0]) + self._eval_node(children[1])
case BinaryOp.DIV:
return self._eval_node(children[0]) / self._eval_node(children[1])
case BinaryOp.TIMES:
return self._eval_node(children[0]) * self._eval_node(children[1])
case BinaryOp.POW:
return self._eval_node(children[0]) ** self._eval_node(children[1])
case BinaryOp.ATAN2:
return jnp.atan2(self._eval_node(children[0]), self._eval_node(children[1]))
case BinaryOp.MAX:
return max(self._eval_node(children[0]), self._eval_node(children[1]))
case BinaryOp.MIN:
return min(self._eval_node(children[0]), self._eval_node(children[1]))
case _:
msg = f"Type of formula node not recognized ({node}). This should never happen."
raise RuntimeError(msg)

def _make_node(self, ast: FormulaAst) -> FormulaNode:
match ast.nodetype:
case FormulaAst.NodeType.LITERAL:
return Literal(ast.data)
case FormulaAst.NodeType.VARIABLE:
return Variable(self.input_names[ast.data])
case FormulaAst.NodeType.BINARY:
match ast.data:
# TODO reduce code duplication (code generation?)
case FormulaAst.BinaryOp.EQUAL:
return Op(
op=BinaryOp.EQUAL,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.NOTEQUAL:
return Op(
op=BinaryOp.NOTEQUAL,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.GREATER:
return Op(
op=BinaryOp.GREATER,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.LESS:
return Op(
op=BinaryOp.LESS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.GREATEREQ:
return Op(
op=BinaryOp.GREATEREQ,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.LESSEQ:
return Op(
op=BinaryOp.LESSEQ,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MINUS:
return Op(
op=BinaryOp.MINUS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.PLUS:
return Op(
op=BinaryOp.PLUS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.DIV:
return Op(
op=BinaryOp.DIV,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.TIMES:
return Op(
op=BinaryOp.TIMES,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.POW:
return Op(
op=BinaryOp.POW,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.ATAN2:
return Op(
op=BinaryOp.ATAN2,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MAX:
return Op(
op=BinaryOp.MAX,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MIN:
return Op(
op=BinaryOp.MIN,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case _:
msg = f"Type of formula node not recognized ({ast.nodetype.name}). This should never happen."
raise ValueError(msg)
10 changes: 10 additions & 0 deletions src/correctionlib_gradients/_typedefs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import TypeAlias

import jax
import numpy as np

# Type aliases shared across the package.
#
# TODO: switch to use numpy.array_api.Array as _the_ array type.
# Must wait for it to be out of experimental.
# See https://numpy.org/doc/stable/reference/array_api.html.
#
# Array: either a NumPy array or a JAX array.
Array: TypeAlias = np.ndarray | jax.Array
# Value: a plain Python float or any Array.
Value: TypeAlias = float | Array
68 changes: 68 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,46 @@
flow=42.0,
),
),
"constant-formula": schemav2.Correction(
name="formula that returns a constant",
version=2,
inputs=[schemav2.Variable(name="x", type="real")],
output=schemav2.Variable(name="a scale", type="real"),
data=schemav2.Formula(
nodetype="formula",
expression="42.",
parser="TFormula",
variables=[],
),
),
"simple-formula": schemav2.Correction(
name="simple numerical expression",
version=2,
inputs=[schemav2.Variable(name="x", type="real")],
output=schemav2.Variable(name="a scale", type="real"),
data=schemav2.Formula(
nodetype="formula",
expression="x*x",
parser="TFormula",
variables=["x"],
),
),
"complex-formula": schemav2.Correction(
name="complex numerical expression that uses all available operations",
version=2,
inputs=[schemav2.Variable(name="x", type="real"), schemav2.Variable(name="y", type="real")],
output=schemav2.Variable(name="a scale", type="real"),
data=schemav2.Formula(
nodetype="formula",
# FIXME add unary ops, add parameters
expression=(
"(x == x) + (x != y) + (x < y) = (x > y) + (x <= y) + (x >= y)"
"- x/y + x*y + pow(x, 2) + atan2(x, y) + max(x, y) + min(x, y)"
),
parser="TFormula",
variables=["x"],
),
),
# this type of correction is unsupported
"categorical": schemav2.Correction(
name="categorical",
Expand Down Expand Up @@ -188,3 +228,31 @@ def test_vectorized_evaluate_simple_nonuniform_binning():
grads = np.vectorize(jax.grad(cg.evaluate))(x)
expected_grad = [0.794444444, 0.0, 0.0]
assert np.allclose(grads, expected_grad)


@pytest.mark.parametrize("jit", [False, True])
def test_constant_formula(jit):
    """A formula made of a single literal yields that literal and a zero gradient."""
    correction = CorrectionWithGradient(schemas["constant-formula"])
    fn = correction.evaluate
    if jit:
        fn = jax.jit(fn)
    value_and_grad = jax.value_and_grad(fn)(0.0)
    assert value_and_grad == (42.0, 0.0)


@pytest.mark.parametrize("jit", [False, True])
def test_simple_formula(jit):
    """For the formula x*x at x=2 both the value and d/dx are 4."""
    correction = CorrectionWithGradient(schemas["simple-formula"])
    fn = correction.eval_dict
    if jit:
        fn = jax.jit(fn)
    point = {"x": 2.0}
    value, grads = jax.value_and_grad(fn)(point)
    assert math.isclose(value, 4.0)
    assert math.isclose(grads["x"], 4.0)


@pytest.mark.parametrize("jit", [False, True])
def test_simple_formula_vectorized(jit):
    """The formula x*x and its gradient can be mapped over NumPy and JAX arrays."""
    correction = CorrectionWithGradient(schemas["simple-formula"])
    fn = correction.evaluate
    if jit:
        fn = jax.jit(fn)
    # pass in different kinds of arrays/collections
    # FIXME passing in a list does not work
    array_inputs = (
        np.arange(1, 4, dtype=np.float32),
        jax.numpy.arange(1, 4, dtype=np.float32),
    )
    for x in array_inputs:
        vectorized = np.vectorize(jax.value_and_grad(fn))
        values, grads = vectorized(x)
        assert np.allclose(values, [1.0, 4.0, 9.0])
        assert np.allclose(grads, [2.0, 4.0, 6.0])

0 comments on commit 1b33298

Please sign in to comment.