Skip to content

Commit 644a68a

Browse files
author
wangb
committed
update autodiff
1 parent d860c7e commit 644a68a

File tree

4 files changed

+416
-0
lines changed

4 files changed

+416
-0
lines changed

autodiff/.gitignore

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
*.egg-info/
24+
.installed.cfg
25+
*.egg
26+
MANIFEST
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
.pytest_cache/
49+
50+
# Translations
51+
*.mo
52+
*.pot
53+
54+
# Django stuff:
55+
*.log
56+
.static_storage/
57+
.media/
58+
local_settings.py
59+
60+
# Flask stuff:
61+
instance/
62+
.webassets-cache
63+
64+
# Scrapy stuff:
65+
.scrapy
66+
67+
# Sphinx documentation
68+
docs/_build/
69+
70+
# PyBuilder
71+
target/
72+
73+
# Jupyter Notebook
74+
.ipynb_checkpoints
75+
76+
# pyenv
77+
.python-version
78+
79+
# celery beat schedule file
80+
celerybeat-schedule
81+
82+
# SageMath parsed files
83+
*.sage.py
84+
85+
# Environments
86+
.env
87+
.venv
88+
env/
89+
venv/
90+
ENV/
91+
env.bak/
92+
venv.bak/
93+
94+
# Spyder project settings
95+
.spyderproject
96+
.spyproject
97+
98+
# Rope project settings
99+
.ropeproject
100+
101+
# mkdocs documentation
102+
/site
103+
104+
# mypy
105+
.mypy_cache/

autodiff/.vscode/settings.json

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"python.unitTest.pyTestArgs": [
3+
"."
4+
],
5+
"python.unitTest.pyTestEnabled": true
6+
}

autodiff/autodiff.py

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
from collections import namedtuple
2+
import math
3+
4+
Point = "Dict[str, float]"
5+
6+
class Expr:
7+
def eval(self, point: Point) -> float:
8+
""" Evaluate the expr @ the given point.
9+
:param point: Dict[str, float]. Maps variable names to their value
10+
:returns float:
11+
"""
12+
cache = {}
13+
return self._eval(point, cache)
14+
15+
def _eval(self, point: Point, cache: dict) -> float:
16+
""" Fills out a cache mapping Expr objects to their evaluated value.
17+
We can't just use functools.lru_cache here because Point (dictionaries)
18+
aren't hashable.
19+
"""
20+
raise NotImplementedError
21+
22+
def forward_diff(self, direction: Point, point: Point) -> float:
23+
""" Evaulate the directional derivative of a direction @ a point via
24+
forward-mode automatic differentiation
25+
:param point: Dict[str, float]. Maps variable names to their value
26+
:param direction: Dict[str, float]. Maps variable names to their value
27+
:returns float:
28+
"""
29+
cache = {}
30+
self._eval(point, cache)
31+
return self._forward_diff(direction, point, cache)
32+
33+
def _forward_diff(self, direction: Point, point: Point, cache) -> float:
34+
raise NotImplementedError
35+
36+
def reverse_diff(self, point: Point) -> Point:
37+
""" Evaulate the gradient of a direction @ a point via
38+
reverse-mode automatic differentiation.
39+
Internally dispatches to subclass-specific `_reverse_diff`
40+
:param point: Dict[str, float]. Maps variable names to their value
41+
:returns Dict[str, float]: Returns gradient @ point
42+
"""
43+
cache = {}
44+
self._eval(point, cache)
45+
x = { key: 0 for key in point }
46+
self._reverse_diff(point, 1, x, cache)
47+
return x
48+
49+
def _reverse_diff(self, point: Point, adjoint: float,
50+
gradient: Point, cache):
51+
raise NotImplementedError
52+
53+
def __add__(self, other):
54+
return Add(self, other)
55+
56+
def __sub__(self, other):
57+
return Subtract(self, other)
58+
59+
def __mul__(self, other):
60+
return Multiply(self, other)
61+
62+
def __truediv__(self, other):
63+
return Divide(self, other)
64+
65+
def __pow__(self, other):
66+
return Pow(self, other)
67+
68+
class Variable(Expr, namedtuple("Variable", ["name"])):
69+
def _eval(self, point, cache):
70+
cache[id(self)] = point[self.name]
71+
return point[self.name]
72+
73+
def _forward_diff(self, direction, point, cache):
74+
return direction[self.name]
75+
76+
def _reverse_diff(self, point, adjoint, gradient, cache):
77+
gradient[self.name] += adjoint
78+
79+
class Constant(Expr, namedtuple("Constant", ["value"])):
80+
def _eval(self, point, cache):
81+
cache[id(self)] = self.value
82+
return self.value
83+
84+
def _forward_diff(self, direction, point, cache):
85+
return 0
86+
87+
def _reverse_diff(self, point, ajoint, gradient, cache):
88+
pass
89+
90+
class Add(Expr, namedtuple("Add", ["expr1", "expr2"])):
91+
def _eval(self, point, cache):
92+
if id(self) not in cache:
93+
eval1, eval2 = self.expr1._eval, self.expr2._eval
94+
cache[id(self)] = eval1(point, cache) + eval2(point, cache)
95+
return cache[id(self)]
96+
97+
def _forward_diff(self, direction, point, cache):
98+
return self.expr1._forward_diff(direction, point, cache) + self.expr2._forward_diff(direction, point, cache)
99+
100+
def _reverse_diff(self, point, adjoint, gradient, cache):
101+
self.expr1._reverse_diff(point, adjoint, gradient, cache)
102+
self.expr2._reverse_diff(point, adjoint, gradient, cache)
103+
104+
class Subtract(Expr, namedtuple("Subtract", ["expr1", "expr2"])):
105+
def _eval(self, point, cache):
106+
if id(self) not in cache:
107+
eval1, eval2 = self.expr1._eval, self.expr2._eval
108+
cache[id(self)] = eval1(point, cache) - eval2(point, cache)
109+
return cache[id(self)]
110+
111+
def _forward_diff(self, direction, point, cache):
112+
return self.expr1._forward_diff(direction, point, cache) - self.expr2._forward_diff(direction, point, cache)
113+
114+
def _reverse_diff(self, point, adjoint, gradient, cache):
115+
self.expr1._reverse_diff(point, adjoint, gradient, cache)
116+
self.expr2._reverse_diff(point, -adjoint, gradient, cache)
117+
118+
class Multiply(Expr, namedtuple("Multiply", ["expr1", "expr2"])):
119+
def _eval(self, point, cache):
120+
if id(self) not in cache:
121+
eval1, eval2 = self.expr1._eval, self.expr2._eval
122+
cache[id(self)] = eval1(point, cache) * eval2(point, cache)
123+
return cache[id(self)]
124+
125+
def _forward_diff(self, direction, point, cache):
126+
return self.expr2._eval(point, cache) * self.expr1._forward_diff(direction, point, cache) + self.expr1._eval(point, cache) * self.expr2._forward_diff(direction, point, cache)
127+
128+
def _reverse_diff(self, point, adjoint, gradient, cache):
129+
self.expr1._reverse_diff(point, adjoint * self.expr2._eval(point, cache), gradient, cache)
130+
self.expr2._reverse_diff(point, adjoint * self.expr1._eval(point, cache), gradient, cache)
131+
132+
class Divide(Expr, namedtuple("Divide", ["expr1", "expr2"])):
133+
def _eval(self, point, cache):
134+
if id(self) not in cache:
135+
eval1, eval2 = self.expr1._eval, self.expr2._eval
136+
cache[id(self)] = eval1(point, cache) / eval2(point, cache)
137+
return cache[id(self)]
138+
139+
def _forward_diff(self, direction, point, cache):
140+
high = cache[id(self.expr1)]
141+
low = cache[id(self.expr2)]
142+
dhigh = self.expr1._forward_diff(direction, point, cache)
143+
dlow = self.expr2._forward_diff(direction, point, cache)
144+
145+
return (low * dhigh - high * dlow) / low ** 2
146+
147+
def _reverse_diff(self, point, adjoint, gradient, cache):
148+
high = cache[id(self.expr1)]
149+
low = cache[id(self.expr2)]
150+
self.expr1._reverse_diff(point, adjoint / low, gradient, cache)
151+
self.expr2._reverse_diff(point, -adjoint * high / low ** 2, gradient,
152+
cache)
153+
154+
class Pow(Expr, namedtuple("Pow", ["expr1", "expr2"])):
155+
def _eval(self, point, cache):
156+
if id(self) not in cache:
157+
eval1, eval2 = self.expr1._eval, self.expr2._eval
158+
cache[id(self)] = eval1(point, cache) ** eval2(point, cache)
159+
return cache[id(self)]
160+
161+
def _forward_diff(self, direction, point, cache):
162+
base = cache[id(self.expr1)]
163+
exp = cache[id(self.expr2)]
164+
dbase = self.expr1._forward_diff(direction, point, cache)
165+
dexp = self.expr2._forward_diff(direction, point, cache)
166+
167+
if base == 0: # avoid MathDomainError
168+
return 0
169+
else:
170+
# D_x[f ** g] = D_x[exp(g ln f)] = exp(g ln f) D_x[g ln f]
171+
# = exp(g ln f) (g'ln f + gf' / f)
172+
# = f ** (g - 1) (fg' ln f + gf')
173+
174+
return (base ** (exp - 1) *
175+
(exp * dbase + base * dexp * math.log(base)))
176+
177+
178+
def _reverse_diff(self, point, adjoint, gradient, cache):
179+
base = cache[id(self.expr1)]
180+
exp = cache[id(self.expr2)]
181+
182+
self.expr1._reverse_diff(point, adjoint * exp * base ** (exp - 1),
183+
gradient, cache)
184+
self.expr2._reverse_diff(point, adjoint * math.log(base) * base ** exp,
185+
gradient, cache)

0 commit comments

Comments
 (0)