diff --git a/src/correctionlib_gradients/_base.py b/src/correctionlib_gradients/_base.py
index 1d2227f..b6fad94 100644
--- a/src/correctionlib_gradients/_base.py
+++ b/src/correctionlib_gradients/_base.py
@@ -9,11 +9,8 @@
 import numpy as np
 from scipy.interpolate import CubicSpline  # type: ignore[import-not-found]
 
-# TODO: switch to use numpy.array_api.Array as _the_ array type.
-# Must wait for it to be out of experimental.
-# See https://numpy.org/doc/stable/reference/array_api.html.
-Array: TypeAlias = np.ndarray | jax.Array
-Value: TypeAlias = float | Array
+from correctionlib_gradients._formuladag import FormulaDAG
+from correctionlib_gradients._typedefs import Array, Value
 
 
 def midpoints(x: Array) -> Array:
@@ -46,7 +43,7 @@ def eval_spline_bwd(res, g):  # type: ignore[no-untyped-def]
     return cast(Callable[[Value], Value], eval_spline)
 
 
-DAGNode: TypeAlias = float | schema.Binning
+DAGNode: TypeAlias = float | schema.Binning | FormulaDAG
 
 
 @dataclass
@@ -81,6 +78,8 @@ def __init__(self, c: schema.Correction):
                 flow = cast(str, flow)  # type: ignore[has-type]
                 msg = f"Correction '{c.name}' contains a Binning correction with `{flow=}`. Only 'clamp' is supported."
                 raise ValueError(msg)
+            case schema.Formula() as f:
+                self.node = FormulaDAG(f, c.inputs)
             case _:
                 msg = f"Correction '{c.name}' contains the unsupported operation type '{type(c.data).__name__}'"
                 raise ValueError(msg)
@@ -101,6 +100,8 @@ def evaluate(self, inputs: dict[str, Value]) -> Value:
                 xs = np.array(edges)
                 s = make_differentiable_spline(xs, values)
                 return s(inputs[var])
+            case FormulaDAG() as f:
+                return f.evaluate(inputs)
             case _:  # pragma: no cover
                 msg = "Unsupported type of node in the computation graph. This should never happen."
                 raise RuntimeError(msg)
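
For context, a minimal end-to-end sketch (illustration only, not part of the patch) of what the new schema.Formula branch in _base.py enables. It mirrors the "simple-formula" schema added to the tests further below, and it assumes CorrectionWithGradient is importable from the package root, as the tests' usage suggests.

import correctionlib.schemav2 as schemav2
import jax

from correctionlib_gradients import CorrectionWithGradient  # assumed package-level re-export

c = schemav2.Correction(
    name="simple numerical expression",
    version=2,
    inputs=[schemav2.Variable(name="x", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(nodetype="formula", expression="x*x", parser="TFormula", variables=["x"]),
)

cg = CorrectionWithGradient(c)  # __init__ now routes schema.Formula to a FormulaDAG node
value, grad = jax.value_and_grad(cg.evaluate)(3.0)  # value == 9.0, grad == 6.0
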
diff --git a/src/correctionlib_gradients/_formuladag.py b/src/correctionlib_gradients/_formuladag.py
new file mode 100644
index 0000000..ef07231
--- /dev/null
+++ b/src/correctionlib_gradients/_formuladag.py
@@ -0,0 +1,209 @@
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import TypeAlias
+
+import correctionlib.schemav2 as schema
+import jax.numpy as jnp  # atan2
+from correctionlib._core import Formula, FormulaAst
+from correctionlib._core import Variable as CPPVariable
+
+from correctionlib_gradients._typedefs import Value
+
+
+@dataclass
+class Literal:
+    value: float
+
+
+@dataclass
+class Variable:
+    name: str
+
+
+@dataclass
+class Parameter:
+    idx: int
+
+
+class BinaryOp(Enum):
+    EQUAL = auto()
+    NOTEQUAL = auto()
+    GREATER = auto()
+    LESS = auto()
+    GREATEREQ = auto()
+    LESSEQ = auto()
+    MINUS = auto()
+    PLUS = auto()
+    DIV = auto()
+    TIMES = auto()
+    POW = auto()
+    ATAN2 = auto()
+    MAX = auto()
+    MIN = auto()
+
+
+class UnaryOp(Enum):
+    NEGATIVE = auto()
+    LOG = auto()
+    LOG10 = auto()
+    EXP = auto()
+    ERF = auto()
+    SQRT = auto()
+    ABS = auto()
+    COS = auto()
+    SIN = auto()
+    TAN = auto()
+    ACOS = auto()
+    ASIN = auto()
+    ATAN = auto()
+    COSH = auto()
+    SINH = auto()
+    TANH = auto()
+    ACOSH = auto()
+    ASINH = auto()
+    ATANH = auto()
+
+
+FormulaNode: TypeAlias = Literal | Variable | Parameter | BinaryOp | UnaryOp
+
+
+@dataclass
+class Op:
+    op: BinaryOp | UnaryOp
+    children: tuple[FormulaNode, ...]
+
+
+class FormulaDAG:
+    def __init__(self, f: schema.Formula, inputs: list[schema.Variable]):
+        cpp_formula = Formula.from_string(f.json(), [CPPVariable.from_string(v.json()) for v in inputs])
+        self.input_names = [v.name for v in inputs]
+        self.node: FormulaNode = self._make_node(cpp_formula.ast)
+
+    def evaluate(self, inputs: dict[str, Value]) -> Value:
+        self.inputs = inputs
+        res = self._eval_node(self.node)
+        del self.inputs
+        return res
+
+    def _eval_node(self, node: FormulaNode) -> Value:
+        match node:
+            case Literal(value):
+                return value
+            case Variable(name):
+                return self.inputs[name]
+            case Op(op=op, children=children):
+                match op:
+                    case BinaryOp.EQUAL:
+                        return self._eval_node(children[0]) == self._eval_node(children[1])
+                    case BinaryOp.NOTEQUAL:
+                        return self._eval_node(children[0]) != self._eval_node(children[1])
+                    case BinaryOp.GREATER:
+                        return self._eval_node(children[0]) > self._eval_node(children[1])
+                    case BinaryOp.LESS:
+                        return self._eval_node(children[0]) < self._eval_node(children[1])
+                    case BinaryOp.GREATEREQ:
+                        return self._eval_node(children[0]) >= self._eval_node(children[1])
+                    case BinaryOp.LESSEQ:
+                        return self._eval_node(children[0]) <= self._eval_node(children[1])
+                    case BinaryOp.MINUS:
+                        return self._eval_node(children[0]) - self._eval_node(children[1])
+                    case BinaryOp.PLUS:
+                        return self._eval_node(children[0]) + self._eval_node(children[1])
+                    case BinaryOp.DIV:
+                        return self._eval_node(children[0]) / self._eval_node(children[1])
+                    case BinaryOp.TIMES:
+                        return self._eval_node(children[0]) * self._eval_node(children[1])
+                    case BinaryOp.POW:
+                        return self._eval_node(children[0]) ** self._eval_node(children[1])
+                    case BinaryOp.ATAN2:
+                        return jnp.atan2(self._eval_node(children[0]), self._eval_node(children[1]))
+                    case BinaryOp.MAX:
+                        return max(self._eval_node(children[0]), self._eval_node(children[1]))
+                    case BinaryOp.MIN:
+                        return min(self._eval_node(children[0]), self._eval_node(children[1]))
+            case _:
+                msg = f"Type of formula node not recognized ({node}). This should never happen."
+                raise RuntimeError(msg)
+
+    def _make_node(self, ast: FormulaAst) -> FormulaNode:
+        match ast.nodetype:
+            case FormulaAst.NodeType.LITERAL:
+                return Literal(ast.data)
+            case FormulaAst.NodeType.VARIABLE:
+                return Variable(self.input_names[ast.data])
+            case FormulaAst.NodeType.BINARY:
+                match ast.data:
+                    # TODO reduce code duplication (code generation?)
+                    case FormulaAst.BinaryOp.EQUAL:
+                        return Op(
+                            op=BinaryOp.EQUAL,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.NOTEQUAL:
+                        return Op(
+                            op=BinaryOp.NOTEQUAL,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.GREATER:
+                        return Op(
+                            op=BinaryOp.GREATER,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.LESS:
+                        return Op(
+                            op=BinaryOp.LESS,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.GREATEREQ:
+                        return Op(
+                            op=BinaryOp.GREATEREQ,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.LESSEQ:
+                        return Op(
+                            op=BinaryOp.LESSEQ,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.MINUS:
+                        return Op(
+                            op=BinaryOp.MINUS,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.PLUS:
+                        return Op(
+                            op=BinaryOp.PLUS,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.DIV:
+                        return Op(
+                            op=BinaryOp.DIV,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.TIMES:
+                        return Op(
+                            op=BinaryOp.TIMES,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.POW:
+                        return Op(
+                            op=BinaryOp.POW,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.ATAN2:
+                        return Op(
+                            op=BinaryOp.ATAN2,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.MAX:
+                        return Op(
+                            op=BinaryOp.MAX,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+                    case FormulaAst.BinaryOp.MIN:
+                        return Op(
+                            op=BinaryOp.MIN,
+                            children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
+                        )
+            case _:
+                msg = f"Type of formula node not recognized ({ast.nodetype.name}). This should never happen."
+                raise ValueError(msg)
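
To make the recursive structure above concrete, here is a short sketch (illustration only, not part of the patch) that builds a FormulaDAG directly for the TFormula expression "x*x" and differentiates through it with JAX; this is the same code path that CorrectionWithGradient triggers for a schema.Formula correction.

import correctionlib.schemav2 as schemav2
import jax

from correctionlib_gradients._formuladag import FormulaDAG

f = schemav2.Formula(nodetype="formula", expression="x*x", parser="TFormula", variables=["x"])
dag = FormulaDAG(f, inputs=[schemav2.Variable(name="x", type="real")])
# _make_node turns the parsed AST into roughly Op(op=BinaryOp.TIMES, children=(Variable("x"), Variable("x"))),
# and _eval_node maps BinaryOp.TIMES onto the Python "*" operator, so JAX tracers flow straight through it.
print(jax.grad(lambda x: dag.evaluate({"x": x}))(3.0))  # 6.0
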
diff --git a/src/correctionlib_gradients/_typedefs.py b/src/correctionlib_gradients/_typedefs.py
new file mode 100644
index 0000000..ec5b5ca
--- /dev/null
+++ b/src/correctionlib_gradients/_typedefs.py
@@ -0,0 +1,10 @@
+from typing import TypeAlias
+
+import jax
+import numpy as np
+
+# TODO: switch to use numpy.array_api.Array as _the_ array type.
+# Must wait for it to be out of experimental.
+# See https://numpy.org/doc/stable/reference/array_api.html.
+Array: TypeAlias = np.ndarray | jax.Array
+Value: TypeAlias = float | Array
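
A small illustration (not part of the patch) of why these aliases now live in their own module: both _base.py and _formuladag.py annotate against them, and a Value may be a plain float, a NumPy array, or a JAX array.

import jax.numpy as jnp
import numpy as np

from correctionlib_gradients._typedefs import Value


def scale(v: Value) -> Value:
    return 2.0 * v


scale(1.5)              # float
scale(np.arange(3.0))   # np.ndarray
scale(jnp.arange(3.0))  # jax.Array
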
diff --git a/tests/test_base.py b/tests/test_base.py
index cd67b6e..d0f61e9 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -74,6 +74,46 @@
             flow=42.0,
         ),
     ),
+    "constant-formula": schemav2.Correction(
+        name="formula that returns a constant",
+        version=2,
+        inputs=[schemav2.Variable(name="x", type="real")],
+        output=schemav2.Variable(name="a scale", type="real"),
+        data=schemav2.Formula(
+            nodetype="formula",
+            expression="42.",
+            parser="TFormula",
+            variables=[],
+        ),
+    ),
+    "simple-formula": schemav2.Correction(
+        name="simple numerical expression",
+        version=2,
+        inputs=[schemav2.Variable(name="x", type="real")],
+        output=schemav2.Variable(name="a scale", type="real"),
+        data=schemav2.Formula(
+            nodetype="formula",
+            expression="x*x",
+            parser="TFormula",
+            variables=["x"],
+        ),
+    ),
+    "complex-formula": schemav2.Correction(
+        name="complex numerical expression that uses all available operations",
+        version=2,
+        inputs=[schemav2.Variable(name="x", type="real"), schemav2.Variable(name="y", type="real")],
+        output=schemav2.Variable(name="a scale", type="real"),
+        data=schemav2.Formula(
+            nodetype="formula",
+            # FIXME add unary ops, add parameters
+            expression=(
+                "(x == x) + (x != y) + (x < y) + (x > y) + (x <= y) + (x >= y)"
+                "- x/y + x*y + pow(x, 2) + atan2(x, y) + max(x, y) + min(x, y)"
+            ),
+            parser="TFormula",
+            variables=["x", "y"],
+        ),
+    ),
     # this type of correction is unsupported
     "categorical": schemav2.Correction(
         name="categorical",
@@ -188,3 +228,31 @@ def test_vectorized_evaluate_simple_nonuniform_binning():
     grads = np.vectorize(jax.grad(cg.evaluate))(x)
     expected_grad = [0.794444444, 0.0, 0.0]
     assert np.allclose(grads, expected_grad)
+
+
+@pytest.mark.parametrize("jit", [False, True])
+def test_constant_formula(jit):
+    cg = CorrectionWithGradient(schemas["constant-formula"])
+    evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
+    assert jax.value_and_grad(evaluate)(0.0) == (42.0, 0.0)
+
+
+@pytest.mark.parametrize("jit", [False, True])
+def test_simple_formula(jit):
+    cg = CorrectionWithGradient(schemas["simple-formula"])
+    eval_dict = jax.jit(cg.eval_dict) if jit else cg.eval_dict
+    value, grads = jax.value_and_grad(eval_dict)({"x": 2.0})
+    assert math.isclose(value, 4.0)
+    assert math.isclose(grads["x"], 4.0)
+
+
+@pytest.mark.parametrize("jit", [False, True])
+def test_simple_formula_vectorized(jit):
+    cg = CorrectionWithGradient(schemas["simple-formula"])
+    evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate
+    # pass in different kinds of arrays/collections
+    # FIXME passing in a list does not work
+    for x in np.arange(1, 4, dtype=np.float32), jax.numpy.arange(1, 4, dtype=np.float32):
+        values, grads = np.vectorize(jax.value_and_grad(evaluate))(x)
+        assert np.allclose(values, [1.0, 4.0, 9.0])
+        assert np.allclose(grads, [2.0, 4.0, 6.0])
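
As a quick sanity check (not part of the patch) of the numbers asserted in test_simple_formula_vectorized: for f(x) = x*x the gradient is 2*x, so inputs [1, 2, 3] give values [1, 4, 9] and gradients [2, 4, 6]. The same np.vectorize pattern used in the test reproduces them on a plain Python function:

import jax
import numpy as np


def f(x):
    return x * x


values, grads = np.vectorize(jax.value_and_grad(f))(np.arange(1, 4, dtype=np.float32))
assert np.allclose(values, [1.0, 4.0, 9.0])
assert np.allclose(grads, [2.0, 4.0, 6.0])
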