diff --git a/src/correctionlib_gradients/_base.py b/src/correctionlib_gradients/_base.py index 00b8a1b..57b805d 100644 --- a/src/correctionlib_gradients/_base.py +++ b/src/correctionlib_gradients/_base.py @@ -5,9 +5,12 @@ import correctionlib.schemav2 as schema import jax +import jax.numpy as jnp import numpy as np from scipy.interpolate import CubicSpline # type: ignore[import-untyped] +import correctionlib_gradients._utils as utils +from correctionlib_gradients._formuladag import FormulaDAG from correctionlib_gradients._typedefs import Value @@ -41,7 +44,7 @@ def eval_spline_bwd(res, g): # type: ignore[no-untyped-def] return cast(Callable[[Value], Value], eval_spline) -DAGNode: TypeAlias = float | schema.Binning +DAGNode: TypeAlias = float | schema.Binning | FormulaDAG class CorrectionDAG: @@ -75,19 +78,21 @@ def __init__(self, c: schema.Correction): flow = cast(str, flow) # type: ignore[has-type] msg = f"Correction '{c.name}' contains a Binning correction with `{flow=}`. Only 'clamp' is supported." 
raise ValueError(msg) + case schema.Formula() as f: + self.node = FormulaDAG(f, c.inputs) case _: msg = f"Correction '{c.name}' contains the unsupported operation type '{type(c.data).__name__}'" raise ValueError(msg) def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array: - result_size = self._get_result_size(inputs) + result_size = utils.get_result_size(inputs) match self.node: case float(x): if result_size == 0: - return jax.numpy.array(x) + return jnp.array(x) else: - return jax.numpy.array([x] * result_size) + return jnp.repeat(x, result_size) case schema.Binning(edges=_edges, content=[*_values], input=_var, flow="clamp"): # to make mypy happy var: str = _var # type: ignore[has-type] @@ -100,29 +105,12 @@ def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array: xs = np.array(edges) s = make_differentiable_spline(xs, values) return s(inputs[var]) + case FormulaDAG() as f: + return f.evaluate(inputs) case _: # pragma: no cover msg = "Unsupported type of node in the computation graph. This should never happen." raise RuntimeError(msg) - def _get_result_size(self, inputs: dict[str, jax.Array]) -> int: - """Calculate what size the result of a DAG evaluation should have. - - The size is equal to the one, common size (shape[0], or number or rows) of all - the non-scalar inputs we require, or 0 if all inputs are scalar. - An error is thrown in case the shapes of two non-scalar inputs differ. - """ - result_shape: tuple[int, ...] = () - for value in inputs.values(): - if result_shape == (): - result_shape = value.shape - elif value.shape != result_shape: - msg = "The shapes of all non-scalar inputs should match." 
- raise ValueError(msg) - if result_shape != (): - return result_shape[0] - else: - return 0 - class CorrectionWithGradient: def __init__(self, c: schema.Correction): @@ -132,7 +120,7 @@ def __init__(self, c: schema.Correction): def evaluate(self, *inputs: Value) -> jax.Array: self._check_num_inputs(inputs) - inputs_as_jax = tuple(jax.numpy.array(i) for i in inputs) + inputs_as_jax = tuple(jnp.array(i) for i in inputs) self._check_input_types(inputs_as_jax) input_names = (v.name for v in self._input_vars) diff --git a/src/correctionlib_gradients/_formuladag.py b/src/correctionlib_gradients/_formuladag.py new file mode 100644 index 0000000..1d0c8a4 --- /dev/null +++ b/src/correctionlib_gradients/_formuladag.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: 2023-present Enrico Guiraud +# +# SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass +from enum import Enum, auto +from typing import TypeAlias, Union + +import correctionlib.schemav2 as schema +import jax +import jax.numpy as jnp # atan2 +from correctionlib._core import Formula, FormulaAst +from correctionlib._core import Variable as CPPVariable + +import correctionlib_gradients._utils as utils + + +@dataclass +class Literal: + value: float + + +@dataclass +class Variable: + name: str + + +@dataclass +class Parameter: + idx: int + + +class BinaryOp(Enum): + EQUAL = auto() + NOTEQUAL = auto() + GREATER = auto() + LESS = auto() + GREATEREQ = auto() + LESSEQ = auto() + MINUS = auto() + PLUS = auto() + DIV = auto() + TIMES = auto() + POW = auto() + ATAN2 = auto() + MAX = auto() + MIN = auto() + + +class UnaryOp(Enum): + NEGATIVE = auto() + LOG = auto() + LOG10 = auto() + EXP = auto() + ERF = auto() + SQRT = auto() + ABS = auto() + COS = auto() + SIN = auto() + TAN = auto() + ACOS = auto() + ASIN = auto() + ATAN = auto() + COSH = auto() + SINH = auto() + TANH = auto() + ACOSH = auto() + ASINH = auto() + ATANH = auto() + + +FormulaNode: TypeAlias = Union[Literal, Variable, Parameter, "Op"] + + 
+@dataclass +class Op: + op: BinaryOp | UnaryOp + children: tuple[FormulaNode, ...] + + +class FormulaDAG: + def __init__(self, f: schema.Formula, inputs: list[schema.Variable]): + cpp_formula = Formula.from_string(f.json(), [CPPVariable.from_string(v.json()) for v in inputs]) + self.input_names = [v.name for v in inputs] + self.node: FormulaNode = self._make_node(cpp_formula.ast) + + def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array: + res = self._eval_node(self.node, inputs) + return res + + def _eval_node(self, node: FormulaNode, inputs: dict[str, jax.Array]) -> jax.Array: + match node: + case Literal(value): + res_size = utils.get_result_size(inputs) + if res_size == 0: + return jnp.array(value) + else: + return jnp.repeat(value, res_size) + case Variable(name): + return inputs[name] + case Op(op=BinaryOp(), children=children): + c1, c2 = children + ev = self._eval_node + i = inputs + match node.op: + case BinaryOp.EQUAL: + return (ev(c1, i) == ev(c2, i)) + 0.0 + case BinaryOp.NOTEQUAL: + return (ev(c1, i) != ev(c2, i)) + 0.0 + case BinaryOp.GREATER: + return (ev(c1, i) > ev(c2, i)) + 0.0 + case BinaryOp.LESS: + return (ev(c1, i) < ev(c2, i)) + 0.0 + case BinaryOp.GREATEREQ: + return (ev(c1, i) >= ev(c2, i)) + 0.0 + case BinaryOp.LESSEQ: + return (ev(c1, i) <= ev(c2, i)) + 0.0 + case BinaryOp.MINUS: + return ev(c1, i) - ev(c2, i) + case BinaryOp.PLUS: + return ev(c1, i) + ev(c2, i) + case BinaryOp.DIV: + return ev(c1, i) / ev(c2, i) + case BinaryOp.TIMES: + return ev(c1, i) * ev(c2, i) + case BinaryOp.POW: + return ev(c1, i) ** ev(c2, i) + case BinaryOp.ATAN2: + return jnp.arctan2(ev(c1, i), ev(c2, i)) + case BinaryOp.MAX: + return jnp.max(jnp.stack([ev(c1, i), ev(c2, i)])) + case BinaryOp.MIN: + return jnp.min(jnp.stack([ev(c1, i), ev(c2, i)])) + case _: + msg = f"Type of formula node not recognized ({node}). This should never happen." 
+ raise RuntimeError(msg) + + return jax.array() # never reached + + def _make_node(self, ast: FormulaAst) -> FormulaNode: + match ast.nodetype: + case FormulaAst.NodeType.LITERAL: + return Literal(ast.data) + case FormulaAst.NodeType.VARIABLE: + return Variable(self.input_names[ast.data]) + case FormulaAst.NodeType.BINARY: + match ast.data: + # TODO reduce code duplication (code generation?) + case FormulaAst.BinaryOp.EQUAL: + return Op( + op=BinaryOp.EQUAL, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.NOTEQUAL: + return Op( + op=BinaryOp.NOTEQUAL, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.GREATER: + return Op( + op=BinaryOp.GREATER, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.LESS: + return Op( + op=BinaryOp.LESS, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.GREATEREQ: + return Op( + op=BinaryOp.GREATEREQ, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.LESSEQ: + return Op( + op=BinaryOp.LESSEQ, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.MINUS: + return Op( + op=BinaryOp.MINUS, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.PLUS: + return Op( + op=BinaryOp.PLUS, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.DIV: + return Op( + op=BinaryOp.DIV, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.TIMES: + return Op( + op=BinaryOp.TIMES, + children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), + ) + case FormulaAst.BinaryOp.POW: + return Op( + 
def get_result_size(inputs: dict[str, jax.Array]) -> int:
    """Calculate what size the result of a DAG evaluation should have.

    The size is equal to the one, common size (shape[0], or number of rows) of all
    the non-scalar inputs we require, or 0 if all inputs are scalar (or there
    are no inputs at all).

    Scalar inputs never take part in the comparison, so the outcome does not
    depend on the order in which scalar and non-scalar inputs are listed.

    Raises:
        ValueError: if the shapes of two non-scalar inputs differ.
    """
    result_shape: tuple[int, ...] = ()
    for value in inputs.values():
        shape = value.shape
        if shape == ():
            # a scalar is compatible with any result size
            continue
        if result_shape == ():
            # first non-scalar input: every other non-scalar must agree with it
            result_shape = shape
        elif shape != result_shape:
            msg = "The shapes of all non-scalar inputs should match."
            raise ValueError(msg)

    return result_shape[0] if result_shape != () else 0
def _formula_evaluate(schema_name, *, jit=False):
    """Return the `evaluate` of the correction `schema_name`, wrapped in jax.jit if `jit`."""
    correction = CorrectionWithGradient(schemas[schema_name])
    if jit:
        return jax.jit(correction.evaluate)
    return correction.evaluate


@pytest.mark.parametrize("jit", [False, True])
def test_constant_formula(jit):
    evaluate = _formula_evaluate("constant-formula", jit=jit)
    assert jax.value_and_grad(evaluate)(0.0) == (42.0, 0.0)


def test_constant_formula_nojax():
    evaluate = _formula_evaluate("constant-formula")
    res = evaluate([0.0, 1.0])
    assert len(res) == 2
    assert jnp.array_equal(res, (42.0, 42.0))


@pytest.mark.parametrize("jit", [False, True])
def test_simple_formula(jit):
    evaluate = _formula_evaluate("simple-formula", jit=jit)
    value, grads = jax.value_and_grad(evaluate)(2.0)
    assert value.shape == ()
    assert math.isclose(value, 4.0)
    assert grads.shape == ()
    assert math.isclose(grads, 4.0)


def test_simple_formula_nojax():
    evaluate = _formula_evaluate("simple-formula")

    # scalar input
    value = evaluate(2.0)
    assert value.shape == ()
    assert math.isclose(value, 4.0)

    # array-like input
    values = evaluate([2.0, 4.0])
    assert len(values) == 2
    assert np.allclose(values, [4.0, 16.0])


@pytest.mark.parametrize("jit", [False, True])
def test_simple_formula_vectorized(jit):
    evaluate = _formula_evaluate("simple-formula", jit=jit)
    # pass in different kinds of arrays/collections
    xs = (
        [1.0, 2.0, 3.0],
        np.arange(1, 4, dtype=np.float32),
        jnp.arange(1, 4, dtype=np.float32),
    )
    for x in xs:
        values, grads = np.vectorize(jax.value_and_grad(evaluate))(x)
        assert len(values) == 3
        assert np.allclose(values, [1.0, 4.0, 9.0])
        assert len(grads) == 3
        assert np.allclose(grads, [2.0, 4.0, 6.0])


@pytest.mark.parametrize("jit", [False, True])
def test_complex_formula(jit):
    evaluate = _formula_evaluate("complex-formula", jit=jit)
    value, grads = jax.value_and_grad(evaluate, argnums=[0, 1])(1.0, 2.0)
    assert value.shape == ()
    assert math.isclose(value, 9.963647609000805)
    assert len(grads) == 2
    assert np.allclose(grads, [4.9, 2.05])


def test_complex_formula_nojax():
    evaluate = _formula_evaluate("complex-formula")

    # scalar inputs
    value = evaluate(1.0, 2.0)
    assert value.shape == ()
    assert math.isclose(value, 9.963647609000805)

    # array-like inputs
    values = evaluate([1.0, 2.0], [2.0, 1.0])
    assert len(values) == 2
    assert np.allclose(values, [9.963647609000805, 12.107149])


# TODO this does not work, seemingly because of np.vectorize
# choking on the gradients being a tuple.
+# @pytest.mark.parametrize("jit", [False, True]) +# def test_complex_formula_vectorized(jit): +# cg = CorrectionWithGradient(schemas["complex-formula"]) +# evaluate = jax.jit(cg.evaluate) if jit else cg.evaluate +# # pass in different kinds of arrays/collections +# y = jnp.array(2.) +# for x in [1.0, 2.0], np.array([1., 2.]), jnp.array([1., 2.]): +# values, grads = np.vectorize(jax.value_and_grad(evaluate, argnums=[0,1]))(x, y) +# assert len(values) == 2 +# assert np.allclose(values, [9.963647609000805, 14.7853985]) +# assert len(grads) == 88 +# assert np.allclose(grads, [2.0, 4.0, 6.0])