-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WIP] Add support for Formula corrections + tests
To do: - add support for parameters, unary ops - reduce code duplication
- Loading branch information
Showing
4 changed files
with
294 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum, auto | ||
from typing import TypeAlias | ||
|
||
import correctionlib.schemav2 as schema | ||
import jax.numpy as jnp # atan2 | ||
from correctionlib._core import Formula, FormulaAst | ||
from correctionlib._core import Variable as CPPVariable | ||
|
||
from correctionlib_gradients._typedefs import Value | ||
|
||
|
||
@dataclass | ||
class Literal: | ||
value: float | ||
|
||
|
||
@dataclass | ||
class Variable: | ||
name: str | ||
|
||
|
||
@dataclass | ||
class Parameter: | ||
idx: int | ||
|
||
|
||
class BinaryOp(Enum): | ||
EQUAL = auto() | ||
NOTEQUAL = auto() | ||
GREATER = auto() | ||
LESS = auto() | ||
GREATEREQ = auto() | ||
LESSEQ = auto() | ||
MINUS = auto() | ||
PLUS = auto() | ||
DIV = auto() | ||
TIMES = auto() | ||
POW = auto() | ||
ATAN2 = auto() | ||
MAX = auto() | ||
MIN = auto() | ||
|
||
|
||
class UnaryOp(Enum): | ||
NEGATIVE = auto() | ||
LOG = auto() | ||
LOG10 = auto() | ||
EXP = auto() | ||
ERF = auto() | ||
SQRT = auto() | ||
ABS = auto() | ||
COS = auto() | ||
SIN = auto() | ||
TAN = auto() | ||
ACOS = auto() | ||
ASIN = auto() | ||
ATAN = auto() | ||
COSH = auto() | ||
SINH = auto() | ||
TANH = auto() | ||
ACOSH = auto() | ||
ASINH = auto() | ||
ATANH = auto() | ||
|
||
|
||
FormulaNode: TypeAlias = Literal | Variable | Parameter | BinaryOp | UnaryOp | ||
|
||
|
||
@dataclass | ||
class Op: | ||
op: BinaryOp | UnaryOp | ||
children: tuple[FormulaNode] | ||
|
||
|
||
class FormulaDAG: | ||
def __init__(self, f: schema.Formula, inputs: list[schema.Variable]): | ||
cpp_formula = Formula.from_string(f.json(), [CPPVariable.from_string(v.json()) for v in inputs]) | ||
self.input_names = [v.name for v in inputs] | ||
self.node: FormulaNode = self._make_node(cpp_formula.ast) | ||
|
||
def evaluate(self, inputs: dict[str, Value]) -> Value: | ||
self.inputs = inputs | ||
res = self._eval_node(self.node) | ||
del self.inputs | ||
return res | ||
|
||
def _eval_node(self, node: FormulaNode) -> Value: | ||
match node: | ||
case Literal(value): | ||
return value | ||
case Variable(name): | ||
return self.inputs[name] | ||
case Op(op=op, children=children): | ||
match op: | ||
case BinaryOp.EQUAL: | ||
return self._eval_node(children[0]) == self._eval_node(children[1]) | ||
case BinaryOp.NOTEQUAL: | ||
return self._eval_node(children[0]) != self._eval_node(children[1]) | ||
case BinaryOp.GREATER: | ||
return self._eval_node(children[0]) > self._eval_node(children[1]) | ||
case BinaryOp.LESS: | ||
return self._eval_node(children[0]) < self._eval_node(children[1]) | ||
case BinaryOp.GREATEREQ: | ||
return self._eval_node(children[0]) >= self._eval_node(children[1]) | ||
case BinaryOp.LESSEQ: | ||
return self._eval_node(children[0]) <= self._eval_node(children[1]) | ||
case BinaryOp.MINUS: | ||
return self._eval_node(children[0]) - self._eval_node(children[1]) | ||
case BinaryOp.PLUS: | ||
return self._eval_node(children[0]) + self._eval_node(children[1]) | ||
case BinaryOp.DIV: | ||
return self._eval_node(children[0]) / self._eval_node(children[1]) | ||
case BinaryOp.TIMES: | ||
return self._eval_node(children[0]) * self._eval_node(children[1]) | ||
case BinaryOp.POW: | ||
return self._eval_node(children[0]) ** self._eval_node(children[1]) | ||
case BinaryOp.ATAN2: | ||
return jnp.atan2(self._eval_node(children[0]), self._eval_node(children[1])) | ||
case BinaryOp.MAX: | ||
return max(self._eval_node(children[0]), self._eval_node(children[1])) | ||
case BinaryOp.MIN: | ||
return min(self._eval_node(children[0]), self._eval_node(children[1])) | ||
case _: | ||
msg = f"Type of formula node not recognized ({node}). This should never happen." | ||
raise RuntimeError(msg) | ||
|
||
def _make_node(self, ast: FormulaAst) -> FormulaNode: | ||
match ast.nodetype: | ||
case FormulaAst.NodeType.LITERAL: | ||
return Literal(ast.data) | ||
case FormulaAst.NodeType.VARIABLE: | ||
return Variable(self.input_names[ast.data]) | ||
case FormulaAst.NodeType.BINARY: | ||
match ast.data: | ||
# TODO reduce code duplication (code generation?) | ||
case FormulaAst.BinaryOp.EQUAL: | ||
return Op( | ||
op=BinaryOp.EQUAL, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.NOTEQUAL: | ||
return Op( | ||
op=BinaryOp.NOTEQUAL, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.GREATER: | ||
return Op( | ||
op=BinaryOp.GREATER, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.LESS: | ||
return Op( | ||
op=BinaryOp.LESS, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.GREATEREQ: | ||
return Op( | ||
op=BinaryOp.GREATEREQ, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.LESSEQ: | ||
return Op( | ||
op=BinaryOp.LESSEQ, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.MINUS: | ||
return Op( | ||
op=BinaryOp.MINUS, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.PLUS: | ||
return Op( | ||
op=BinaryOp.PLUS, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.DIV: | ||
return Op( | ||
op=BinaryOp.DIV, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.TIMES: | ||
return Op( | ||
op=BinaryOp.TIMES, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.POW: | ||
return Op( | ||
op=BinaryOp.POW, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.ATAN2: | ||
return Op( | ||
op=BinaryOp.ATAN2, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.MAX: | ||
return Op( | ||
op=BinaryOp.MAX, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case FormulaAst.BinaryOp.MIN: | ||
return Op( | ||
op=BinaryOp.MIN, | ||
children=(self._make_node(ast.children[0]), self._make_node(ast.children[1])), | ||
) | ||
case _: | ||
msg = f"Type of formula node not recognized ({ast.nodetype.name}). This should never happen." | ||
raise ValueError(msg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from typing import TypeAlias | ||
|
||
import jax | ||
import numpy as np | ||
|
||
# TODO: switch to use numpy.array_api.Array as _the_ array type. | ||
# Must wait for it to be out of experimental. | ||
# See https://numpy.org/doc/stable/reference/array_api.html. | ||
Array: TypeAlias = np.ndarray | jax.Array | ||
Value: TypeAlias = float | Array |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters