Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Formula corrections + tests #34

Merged
merged 1 commit into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 12 additions & 24 deletions src/correctionlib_gradients/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

import correctionlib.schemav2 as schema
import jax
import jax.numpy as jnp
import numpy as np
from scipy.interpolate import CubicSpline # type: ignore[import-untyped]

import correctionlib_gradients._utils as utils
from correctionlib_gradients._formuladag import FormulaDAG
from correctionlib_gradients._typedefs import Value


Expand Down Expand Up @@ -41,7 +44,7 @@ def eval_spline_bwd(res, g): # type: ignore[no-untyped-def]
return cast(Callable[[Value], Value], eval_spline)


DAGNode: TypeAlias = float | schema.Binning
DAGNode: TypeAlias = float | schema.Binning | FormulaDAG


class CorrectionDAG:
Expand Down Expand Up @@ -75,19 +78,21 @@ def __init__(self, c: schema.Correction):
flow = cast(str, flow) # type: ignore[has-type]
msg = f"Correction '{c.name}' contains a Binning correction with `{flow=}`. Only 'clamp' is supported."
raise ValueError(msg)
case schema.Formula() as f:
self.node = FormulaDAG(f, c.inputs)
case _:
msg = f"Correction '{c.name}' contains the unsupported operation type '{type(c.data).__name__}'"
raise ValueError(msg)

def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
result_size = self._get_result_size(inputs)
result_size = utils.get_result_size(inputs)

match self.node:
case float(x):
if result_size == 0:
return jax.numpy.array(x)
return jnp.array(x)
else:
return jax.numpy.array([x] * result_size)
return jnp.repeat(x, result_size)
case schema.Binning(edges=_edges, content=[*_values], input=_var, flow="clamp"):
# to make mypy happy
var: str = _var # type: ignore[has-type]
Expand All @@ -100,29 +105,12 @@ def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
xs = np.array(edges)
s = make_differentiable_spline(xs, values)
return s(inputs[var])
case FormulaDAG() as f:
return f.evaluate(inputs)
case _: # pragma: no cover
msg = "Unsupported type of node in the computation graph. This should never happen."
raise RuntimeError(msg)

def _get_result_size(self, inputs: dict[str, jax.Array]) -> int:
"""Calculate what size the result of a DAG evaluation should have.
The size is equal to the one, common size (shape[0], or number or rows) of all
the non-scalar inputs we require, or 0 if all inputs are scalar.
An error is thrown in case the shapes of two non-scalar inputs differ.
"""
result_shape: tuple[int, ...] = ()
for value in inputs.values():
if result_shape == ():
result_shape = value.shape
elif value.shape != result_shape:
msg = "The shapes of all non-scalar inputs should match."
raise ValueError(msg)
if result_shape != ():
return result_shape[0]
else:
return 0


class CorrectionWithGradient:
def __init__(self, c: schema.Correction):
Expand All @@ -132,7 +120,7 @@ def __init__(self, c: schema.Correction):

def evaluate(self, *inputs: Value) -> jax.Array:
self._check_num_inputs(inputs)
inputs_as_jax = tuple(jax.numpy.array(i) for i in inputs)
inputs_as_jax = tuple(jnp.array(i) for i in inputs)
self._check_input_types(inputs_as_jax)
input_names = (v.name for v in self._input_vars)

Expand Down
224 changes: 224 additions & 0 deletions src/correctionlib_gradients/_formuladag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
# SPDX-FileCopyrightText: 2023-present Enrico Guiraud <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
from dataclasses import dataclass
from enum import Enum, auto
from typing import TypeAlias, Union

import correctionlib.schemav2 as schema
import jax
import jax.numpy as jnp # atan2
from correctionlib._core import Formula, FormulaAst
from correctionlib._core import Variable as CPPVariable

import correctionlib_gradients._utils as utils


@dataclass
class Literal:
value: float


@dataclass
class Variable:
name: str


@dataclass
class Parameter:
idx: int


class BinaryOp(Enum):
EQUAL = auto()
NOTEQUAL = auto()
GREATER = auto()
LESS = auto()
GREATEREQ = auto()
LESSEQ = auto()
MINUS = auto()
PLUS = auto()
DIV = auto()
TIMES = auto()
POW = auto()
ATAN2 = auto()
MAX = auto()
MIN = auto()


class UnaryOp(Enum):
NEGATIVE = auto()
LOG = auto()
LOG10 = auto()
EXP = auto()
ERF = auto()
SQRT = auto()
ABS = auto()
COS = auto()
SIN = auto()
TAN = auto()
ACOS = auto()
ASIN = auto()
ATAN = auto()
COSH = auto()
SINH = auto()
TANH = auto()
ACOSH = auto()
ASINH = auto()
ATANH = auto()


FormulaNode: TypeAlias = Union[Literal, Variable, Parameter, "Op"]


@dataclass
class Op:
op: BinaryOp | UnaryOp
children: tuple[FormulaNode, ...]


class FormulaDAG:
def __init__(self, f: schema.Formula, inputs: list[schema.Variable]):
cpp_formula = Formula.from_string(f.json(), [CPPVariable.from_string(v.json()) for v in inputs])
self.input_names = [v.name for v in inputs]
self.node: FormulaNode = self._make_node(cpp_formula.ast)

def evaluate(self, inputs: dict[str, jax.Array]) -> jax.Array:
res = self._eval_node(self.node, inputs)
return res

def _eval_node(self, node: FormulaNode, inputs: dict[str, jax.Array]) -> jax.Array:
match node:
case Literal(value):
res_size = utils.get_result_size(inputs)
if res_size == 0:
return jnp.array(value)
else:
return jnp.repeat(value, res_size)
case Variable(name):
return inputs[name]
case Op(op=BinaryOp(), children=children):
c1, c2 = children
ev = self._eval_node
i = inputs
match node.op:
case BinaryOp.EQUAL:
return (ev(c1, i) == ev(c2, i)) + 0.0
case BinaryOp.NOTEQUAL:
return (ev(c1, i) != ev(c2, i)) + 0.0
case BinaryOp.GREATER:
return (ev(c1, i) > ev(c2, i)) + 0.0
case BinaryOp.LESS:
return (ev(c1, i) < ev(c2, i)) + 0.0
case BinaryOp.GREATEREQ:
return (ev(c1, i) >= ev(c2, i)) + 0.0
case BinaryOp.LESSEQ:
return (ev(c1, i) <= ev(c2, i)) + 0.0
case BinaryOp.MINUS:
return ev(c1, i) - ev(c2, i)
case BinaryOp.PLUS:
return ev(c1, i) + ev(c2, i)
case BinaryOp.DIV:
return ev(c1, i) / ev(c2, i)
case BinaryOp.TIMES:
return ev(c1, i) * ev(c2, i)
case BinaryOp.POW:
return ev(c1, i) ** ev(c2, i)
case BinaryOp.ATAN2:
return jnp.arctan2(ev(c1, i), ev(c2, i))
case BinaryOp.MAX:
return jnp.max(jnp.stack([ev(c1, i), ev(c2, i)]))
case BinaryOp.MIN:
return jnp.min(jnp.stack([ev(c1, i), ev(c2, i)]))
case _: # pragma: no cover
msg = f"Type of formula node not recognized ({node}). This should never happen."
raise RuntimeError(msg)

# never reached, only here to make mypy happy
return jax.array() # pragma: no cover

def _make_node(self, ast: FormulaAst) -> FormulaNode:
match ast.nodetype:
case FormulaAst.NodeType.LITERAL:
return Literal(ast.data)
case FormulaAst.NodeType.VARIABLE:
return Variable(self.input_names[ast.data])
case FormulaAst.NodeType.BINARY:
match ast.data:
# TODO reduce code duplication (code generation?)
case FormulaAst.BinaryOp.EQUAL:
return Op(
op=BinaryOp.EQUAL,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.NOTEQUAL:
return Op(
op=BinaryOp.NOTEQUAL,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.GREATER:
return Op(
op=BinaryOp.GREATER,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.LESS:
return Op(
op=BinaryOp.LESS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.GREATEREQ:
return Op(
op=BinaryOp.GREATEREQ,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.LESSEQ:
return Op(
op=BinaryOp.LESSEQ,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MINUS:
return Op(
op=BinaryOp.MINUS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.PLUS:
return Op(
op=BinaryOp.PLUS,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.DIV:
return Op(
op=BinaryOp.DIV,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.TIMES:
return Op(
op=BinaryOp.TIMES,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.POW:
return Op(
op=BinaryOp.POW,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.ATAN2:
return Op(
op=BinaryOp.ATAN2,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MAX:
return Op(
op=BinaryOp.MAX,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case FormulaAst.BinaryOp.MIN:
return Op(
op=BinaryOp.MIN,
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])),
)
case _: # pragma: no cover
msg = f"Type of formula node not recognized ({ast.nodetype.name}). This should never happen."
raise ValueError(msg)

# never reached, just to make mypy happy
return Literal(0.0) # pragma: no cover
24 changes: 24 additions & 0 deletions src/correctionlib_gradients/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: 2023-present Enrico Guiraud <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
import jax


def get_result_size(inputs: dict[str, jax.Array]) -> int:
"""Calculate what size the result of a DAG evaluation should have.
The size is equal to the one, common size (shape[0], or number or rows) of all
the non-scalar inputs we require, or 0 if all inputs are scalar.
An error is thrown in case the shapes of two non-scalar inputs differ.
"""
result_shape: tuple[int, ...] = ()
for value in inputs.values():
if result_shape == ():
result_shape = value.shape
elif value.shape != result_shape:
msg = "The shapes of all non-scalar inputs should match."
raise ValueError(msg)
if result_shape != ():
return result_shape[0]
else:
return 0
Loading