from collections import namedtuple
from typing import Dict
import math

# A point (or a direction) maps variable names to float values.
Point = Dict[str, float]


class Expr:
    def eval(self, point: Point) -> float:
        """Evaluate the expr @ the given point.
        :param point: Dict[str, float]. Maps variable names to their values.
        :returns float:
        """
        cache = {}
        return self._eval(point, cache)

    def _eval(self, point: Point, cache: dict) -> float:
        """Evaluate the expr, filling out a cache that maps each sub-expression
        (keyed by id) to its evaluated value.
        We can't just use functools.lru_cache here because Point (dictionaries)
        aren't hashable.
        """
        raise NotImplementedError

    def forward_diff(self, direction: Point, point: Point) -> float:
        """Evaluate the directional derivative along a direction @ a point via
        forward-mode automatic differentiation.
        :param direction: Dict[str, float]. Maps variable names to their
            components of the direction vector.
        :param point: Dict[str, float]. Maps variable names to their values.
        :returns float:
        """
        cache = {}
        self._eval(point, cache)
        return self._forward_diff(direction, point, cache)

    def _forward_diff(self, direction: Point, point: Point, cache) -> float:
        raise NotImplementedError

    def reverse_diff(self, point: Point) -> Point:
        """Evaluate the gradient @ a point via reverse-mode automatic
        differentiation.
        Internally dispatches to subclass-specific `_reverse_diff`.
        :param point: Dict[str, float]. Maps variable names to their values.
        :returns Dict[str, float]: Returns gradient @ point.
        """
        cache = {}
        self._eval(point, cache)
        gradient = {key: 0.0 for key in point}
        self._reverse_diff(point, 1.0, gradient, cache)
        return gradient

    def _reverse_diff(self, point: Point, adjoint: float,
                      gradient: Point, cache):
        raise NotImplementedError

    # Operator overloads so expressions can be composed with ordinary
    # Python arithmetic syntax.
    def __add__(self, other):
        return Add(self, other)

    def __sub__(self, other):
        return Subtract(self, other)

    def __mul__(self, other):
        return Multiply(self, other)

    def __truediv__(self, other):
        return Divide(self, other)

    def __pow__(self, other):
        return Pow(self, other)


class Variable(Expr, namedtuple("Variable", ["name"])):
    """A named input variable; its value is looked up in the point."""

    def _eval(self, point, cache):
        cache[id(self)] = point[self.name]
        return point[self.name]

    def _forward_diff(self, direction, point, cache):
        return direction[self.name]

    def _reverse_diff(self, point, adjoint, gradient, cache):
        gradient[self.name] += adjoint


class Constant(Expr, namedtuple("Constant", ["value"])):
    """A fixed value; its derivative is zero in every direction."""

    def _eval(self, point, cache):
        cache[id(self)] = self.value
        return self.value

    def _forward_diff(self, direction, point, cache):
        return 0

    def _reverse_diff(self, point, adjoint, gradient, cache):
        pass


class Add(Expr, namedtuple("Add", ["expr1", "expr2"])):
    def _eval(self, point, cache):
        if id(self) not in cache:
            eval1, eval2 = self.expr1._eval, self.expr2._eval
            cache[id(self)] = eval1(point, cache) + eval2(point, cache)
        return cache[id(self)]

    def _forward_diff(self, direction, point, cache):
        return (self.expr1._forward_diff(direction, point, cache) +
                self.expr2._forward_diff(direction, point, cache))

    def _reverse_diff(self, point, adjoint, gradient, cache):
        self.expr1._reverse_diff(point, adjoint, gradient, cache)
        self.expr2._reverse_diff(point, adjoint, gradient, cache)


class Subtract(Expr, namedtuple("Subtract", ["expr1", "expr2"])):
    def _eval(self, point, cache):
        if id(self) not in cache:
            eval1, eval2 = self.expr1._eval, self.expr2._eval
            cache[id(self)] = eval1(point, cache) - eval2(point, cache)
        return cache[id(self)]

    def _forward_diff(self, direction, point, cache):
        return (self.expr1._forward_diff(direction, point, cache) -
                self.expr2._forward_diff(direction, point, cache))

    def _reverse_diff(self, point, adjoint, gradient, cache):
        self.expr1._reverse_diff(point, adjoint, gradient, cache)
        self.expr2._reverse_diff(point, -adjoint, gradient, cache)


class Multiply(Expr, namedtuple("Multiply", ["expr1", "expr2"])):
    def _eval(self, point, cache):
        if id(self) not in cache:
            eval1, eval2 = self.expr1._eval, self.expr2._eval
            cache[id(self)] = eval1(point, cache) * eval2(point, cache)
        return cache[id(self)]

    def _forward_diff(self, direction, point, cache):
        # Product rule: (fg)' = f'g + fg'
        return (self.expr2._eval(point, cache) *
                self.expr1._forward_diff(direction, point, cache) +
                self.expr1._eval(point, cache) *
                self.expr2._forward_diff(direction, point, cache))

    def _reverse_diff(self, point, adjoint, gradient, cache):
        self.expr1._reverse_diff(point, adjoint * self.expr2._eval(point, cache),
                                 gradient, cache)
        self.expr2._reverse_diff(point, adjoint * self.expr1._eval(point, cache),
                                 gradient, cache)


class Divide(Expr, namedtuple("Divide", ["expr1", "expr2"])):
    def _eval(self, point, cache):
        if id(self) not in cache:
            eval1, eval2 = self.expr1._eval, self.expr2._eval
            cache[id(self)] = eval1(point, cache) / eval2(point, cache)
        return cache[id(self)]

    def _forward_diff(self, direction, point, cache):
        # Quotient rule: (f/g)' = (gf' - fg') / g**2
        high = cache[id(self.expr1)]
        low = cache[id(self.expr2)]
        dhigh = self.expr1._forward_diff(direction, point, cache)
        dlow = self.expr2._forward_diff(direction, point, cache)

        return (low * dhigh - high * dlow) / low ** 2

    def _reverse_diff(self, point, adjoint, gradient, cache):
        high = cache[id(self.expr1)]
        low = cache[id(self.expr2)]
        self.expr1._reverse_diff(point, adjoint / low, gradient, cache)
        self.expr2._reverse_diff(point, -adjoint * high / low ** 2, gradient,
                                 cache)


class Pow(Expr, namedtuple("Pow", ["expr1", "expr2"])):
    def _eval(self, point, cache):
        if id(self) not in cache:
            eval1, eval2 = self.expr1._eval, self.expr2._eval
            cache[id(self)] = eval1(point, cache) ** eval2(point, cache)
        return cache[id(self)]

    def _forward_diff(self, direction, point, cache):
        base = cache[id(self.expr1)]
        exp = cache[id(self.expr2)]
        dbase = self.expr1._forward_diff(direction, point, cache)
        dexp = self.expr2._forward_diff(direction, point, cache)

        if base == 0:  # avoid math.log's domain error (ValueError) at 0
            return 0
        else:
            # D_x[f ** g] = D_x[exp(g ln f)] = exp(g ln f) D_x[g ln f]
            #             = exp(g ln f) (g' ln f + gf' / f)
            #             = f ** (g - 1) (fg' ln f + gf')
            return (base ** (exp - 1) *
                    (exp * dbase + base * dexp * math.log(base)))

    def _reverse_diff(self, point, adjoint, gradient, cache):
        base = cache[id(self.expr1)]
        exp = cache[id(self.expr2)]

        if base == 0:  # mirror the forward-mode guard against math.log(0)
            return

        self.expr1._reverse_diff(point, adjoint * exp * base ** (exp - 1),
                                 gradient, cache)
        self.expr2._reverse_diff(point, adjoint * math.log(base) * base ** exp,
                                 gradient, cache)
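

# What follows is not part of the original module: a minimal finite-difference
# helper, sketched here so the analytic gradients above can be sanity-checked
# numerically. The name `numerical_gradient` and the step size `eps` are
# illustrative choices, not an established API.
def numerical_gradient(expr: Expr, point: Point, eps: float = 1e-6) -> Point:
    """Approximate the gradient of `expr` @ `point` with central differences."""
    grad = {}
    for name in point:
        up = dict(point, **{name: point[name] + eps})
        down = dict(point, **{name: point[name] - eps})
        grad[name] = (expr.eval(up) - expr.eval(down)) / (2 * eps)
    return grad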
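

# Example usage (a sketch, not from the original file): build
# f(x, y) = x*y + x**2, then compare eval, forward-mode, reverse-mode,
# and the finite-difference check at an arbitrary point.
if __name__ == "__main__":
    x, y = Variable("x"), Variable("y")
    f = x * y + x ** Constant(2)

    point = {"x": 3.0, "y": 2.0}

    print(f.eval(point))                                # 3*2 + 3**2 = 15
    print(f.forward_diff({"x": 1.0, "y": 0.0}, point))  # df/dx = y + 2x = 8
    print(f.reverse_diff(point))                        # {"x": 8, "y": 3}
    print(numerical_gradient(f, point))                 # ~{"x": 8, "y": 3}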