rat.py
import numpy as np
from numpy.linalg import norm
from numpy.random import choice
from tqdm import tqdm
from constants import *
from utilities import get_random_point_in_disc
class PlaceCells:
    '''
    Set of place cells, whose centers are randomly distributed all over a watermaze.
    The activation of each cell (output) depends on the given position (input).
    '''

    centers = None
    current_activation = None
    previous_activation = None

    positions_over_watermaze = None
    activations_over_watermaze = None


    def __init__(self):
        self.set_random_centers()
        self.reset_activations()

        self.init_positions_over_watermaze()
        self.init_activations_over_watermaze()


    def reset_activations(self):
        self.current_activation = np.zeros((NB_PLACE_CELLS))
        self.previous_activation = np.zeros((NB_PLACE_CELLS))


    def set_random_centers(self):
        def get_random_center():
            return get_random_point_in_disc((X_ORIGIN, Y_ORIGIN), WATERMAZE_RADIUS)

        self.centers = np.array(
            [get_random_center() for _ in range(NB_PLACE_CELLS)]
        )


    def init_positions_over_watermaze(self):
        step = 0.05

        x_coords = np.arange(X_ORIGIN - WATERMAZE_RADIUS, X_ORIGIN + WATERMAZE_RADIUS + step, step)
        y_coords = np.arange(Y_ORIGIN - WATERMAZE_RADIUS, Y_ORIGIN + WATERMAZE_RADIUS + step, step)

        positions = np.array(np.meshgrid(x_coords, y_coords)).T.reshape(-1, 2)

        origin = np.array([X_ORIGIN, Y_ORIGIN])
        distances_from_origin = norm(origin - positions, axis = 1)

        self.positions_over_watermaze = positions[distances_from_origin <= WATERMAZE_RADIUS]


    def init_activations_over_watermaze(self):
        self.activations_over_watermaze = np.array([self.activation_at(pos) for pos in self.positions_over_watermaze])


    def activation_at(self, position):
        return np.exp(- np.square(norm(self.centers - position, axis = 1))
                      / (2 * PLACE_CELL_STD * PLACE_CELL_STD))


    def update_activations(self, new_position):
        np.copyto(self.previous_activation, self.current_activation)
        self.current_activation = self.activation_at(new_position)
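# Illustrative sketch, not part of the original module: each place cell i responds with a
# Gaussian tuning curve exp(-||center_i - position||^2 / (2 * PLACE_CELL_STD^2)), so its
# activation peaks at 1 on its own center and decays with distance. The helper name below
# is hypothetical; it only relies on names defined above (PlaceCells, PLACE_CELL_STD).
def _demo_place_cell_tuning():
    cells = PlaceCells()
    # Cell 0 evaluated at its own center gives the maximal activation (exp(0) == 1) ...
    at_center = cells.activation_at(cells.centers[0])[0]
    # ... and an activation close to 0 several standard deviations away
    far_away = cells.activation_at(cells.centers[0] + 5.0 * PLACE_CELL_STD)[0]
    assert at_center > far_away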
class Critic:
    '''
    Critic of the actor-critic model.
    It estimates the value function from a weighted sum of the place cell activations,
    whose weights can and should be updated after each action taken by the rat.
    '''

    weights = None


    def __init__(self):
        self.reset_weights()


    def reset_weights(self):
        self.weights = np.zeros((NB_PLACE_CELLS))


    def estimate_values_over_watermaze(self, place_cells):
        return [np.dot(activation, self.weights) for activation in place_cells.activations_over_watermaze]


    def update_weights(self, place_cells, reward):
        previous_value_estimate = np.dot(place_cells.previous_activation, self.weights)
        current_value_estimate = np.dot(place_cells.current_activation, self.weights)

        if reward == 1:
            current_value_estimate = 0.0

        error = reward + (TEMPORAL_DECAY * current_value_estimate) - previous_value_estimate
        self.weights += CRITIC_LEARNING_RATE * (error * place_cells.previous_activation)

        return error
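# Illustrative sketch, not part of the original module: the update above is a TD(0) rule,
# error = reward + TEMPORAL_DECAY * V(current) - V(previous), where V(s) is the dot product
# of the place cell activations with the critic weights and V is forced to 0 in the terminal
# (rewarded) state. The helper name below is hypothetical.
def _demo_critic_td_error():
    cells = PlaceCells()
    critic = Critic()
    # Visit two positions so that both the previous and the current activations are defined
    cells.update_activations(np.array([X_ORIGIN, Y_ORIGIN]))
    cells.update_activations(np.array([X_ORIGIN + 0.1, Y_ORIGIN]))
    # With zero-initialised weights, both value estimates are 0, so the error equals the reward
    assert critic.update_weights(cells, reward = 1) == 1.0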
class Actor:
    '''
    Actor of the actor-critic model.
    It computes action probabilities from a weighted sum of the place cell activations,
    whose weights can and should be updated after each action taken by the rat.
    '''

    weights = None

    actions = [
        "top_left",
        "top",
        "top_right",
        "right",
        "bottom_right",
        "bottom",
        "bottom_left",
        "left"
    ]


    def __init__(self):
        self.reset_weights()


    def reset_weights(self):
        self.weights = np.zeros((len(self.actions), NB_PLACE_CELLS))


    def compute_action_probabilities(self, place_cells):
        activations = np.dot(self.weights, place_cells.current_activation)

        # Subtract the maximum activation to prevent np.exp from overflowing
        # (it does not affect the final probabilities, since they are divided by the sum)
        max_activation = np.max(activations)
        softmax_activations = np.exp(2.0 * (activations - max_activation))

        #return softmax_activations / (softmax_activations.sum() + np.finfo(float).eps)
        return softmax_activations / softmax_activations.sum()


    def compute_action_probabilities_over_watermaze(self, place_cells):
        activations = np.array([np.dot(self.weights, activation) for activation in place_cells.activations_over_watermaze])

        max_activation = np.max(activations, axis = 1)
        softmax_activations = np.exp(2.0 * (activations.T - max_activation))

        return softmax_activations / softmax_activations.sum()


    def update_weights(self, place_cells, direction, error):
        # Only update the weights of the chosen direction
        direction_index = self.actions.index(direction)
        self.weights[direction_index, :] += ACTOR_LEARNING_RATE * (error * place_cells.previous_activation)
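# Illustrative sketch, not part of the original module: the actor maps the 8 direction
# scores to a softmax distribution (with an inverse temperature of 2.0), so the returned
# probabilities are non-negative and sum to 1. The helper name below is hypothetical.
def _demo_actor_uniform_start():
    cells = PlaceCells()
    cells.update_activations(np.array([X_ORIGIN, Y_ORIGIN]))

    actor = Actor()
    probabilities = actor.compute_action_probabilities(cells)

    # With zero-initialised weights, every direction is equally likely (1/8 each)
    assert np.allclose(probabilities, 1.0 / len(actor.actions))
    assert np.isclose(probabilities.sum(), 1.0)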
class Rat:
    '''
    Rat model based on the actor-critic model.
    It comprises place cells, a critic and an actor, which are used to
    simulate steps and trials (made up of steps) and to learn weights using RL.
    '''

    place_cells = None
    critic = None
    actor = None

    current_pos = None
    previous_pos = None

    starting_pos = [
        np.array([X_ORIGIN, Y_ORIGIN + (WATERMAZE_RADIUS * 0.9)]), # Top
        np.array([X_ORIGIN, Y_ORIGIN - (WATERMAZE_RADIUS * 0.9)]), # Bottom
        np.array([X_ORIGIN + (WATERMAZE_RADIUS * 0.9), Y_ORIGIN]), # Right
        np.array([X_ORIGIN - (WATERMAZE_RADIUS * 0.9), Y_ORIGIN])  # Left
    ]

    previous_pos_diff = None
    pos_diff_by_direction = {
        "top_left":     np.array([-0.35355,  0.35355]),
        "top":          np.array([ 0.0,      1.0]),
        "top_right":    np.array([ 0.35355,  0.35355]),
        "right":        np.array([ 1.0,      0.0]),
        "bottom_right": np.array([ 0.35355, -0.35355]),
        "bottom":       np.array([ 0.0,     -1.0]),
        "bottom_left":  np.array([-0.35355, -0.35355]),
        "left":         np.array([-1.0,      0.0])
    }


    def __init__(self):
        self.place_cells = PlaceCells()
        self.critic = Critic()
        self.actor = Actor()

        self.reset_position()


    def reset_position(self):
        self.current_pos = self.starting_pos[choice(len(self.starting_pos))].copy()
        self.previous_pos = None
        self.previous_pos_diff = np.array([0.0, 0.0])

        # Since the current/previous position have been reset,
        # the precomputed place cell activations must be reset as well
        self.place_cells.reset_activations()


    def reset(self):
        self.reset_position()

        # Reset all the weights
        self.critic.reset_weights()
        self.actor.reset_weights()


    def is_on_plateform(self, watermaze):
        return norm(self.current_pos - watermaze.plateform.center) <= watermaze.plateform.radius


    def move_to_next_pos(self, watermaze):
        # Save the current position as the previous one
        self.previous_pos = self.current_pos

        # Pick the direction at random according to the action probabilities
        probabilities = self.actor.compute_action_probabilities(self.place_cells)
        new_direction = choice(self.actor.actions, p = probabilities)

        # Compute the position difference
        new_pos_diff = self.pos_diff_by_direction[new_direction] * SWIMING_SPEED * TIME_PER_STEP
        pos_diff = (ACTION_MOMENTUM_RATIO * new_pos_diff) + ((1.0 - ACTION_MOMENTUM_RATIO) * self.previous_pos_diff)

        # If the new position is beyond the watermaze wall, reverse the new direction
        if norm(self.current_pos + pos_diff) >= WATERMAZE_RADIUS:
            pos_diff *= -1.0

        # Update the positions and the place cell activations
        self.current_pos += pos_diff
        self.previous_pos_diff = pos_diff

        self.place_cells.update_activations(self.current_pos)

        # Compute the reward for this move
        reward = 1 if self.is_on_plateform(watermaze) else 0

        return new_direction, reward


    def update_weights(self, direction, reward):
        error = self.critic.update_weights(self.place_cells, reward)
        self.actor.update_weights(self.place_cells, direction, error)

        return error


    def simulate_one_step(self, watermaze):
        direction, reward = self.move_to_next_pos(watermaze)
        error = self.update_weights(direction, reward)

        return direction, reward, error


    def simulate_n_steps(self, watermaze, nb_steps, show_progress_bar = True):
        log = {
            "direction": [],
            "reward": [],
            "error": [],
            "position": []
        }

        # Show a progress bar (over the steps) if required
        iterator = tqdm(range(nb_steps), desc = "Steps") if show_progress_bar else range(nb_steps)

        # Simulate all the steps
        for _ in iterator:
            direction, reward, error = self.simulate_one_step(watermaze)

            log["direction"].append(direction)
            log["reward"].append(reward)
            log["error"].append(error)
            log["position"].append(self.current_pos.copy())

        return log


    def simulate_one_trial(self, watermaze):
        log = {
            "direction": [],
            "reward": [],
            "error": [],
            "position": []
        }

        # Move the rat back to its initial position,
        # and iterate for at most TRIAL_TIMEOUT seconds (for the rat)
        self.reset_position()

        for _ in range(int(np.round(TRIAL_TIMEOUT / TIME_PER_STEP))):
            direction, reward, error = self.simulate_one_step(watermaze)

            log["direction"].append(direction)
            log["reward"].append(reward)
            log["error"].append(error)
            log["position"].append(self.current_pos.copy())

            # Stop the trial if the rat reaches the plateform (reward = 1)
            if reward == 1:
                break

        # Log some additional data at the end of the trial
        log["critic_values"] = self.critic.estimate_values_over_watermaze(self.place_cells)
        log["action_probabilities"] = self.actor.compute_action_probabilities_over_watermaze(self.place_cells)

        return log


    def simulate_n_trials(self, watermaze, nb_trials, show_progress_bar = True):
        logs = []

        # Show a progress bar (over the trials) if required
        iterator = tqdm(range(nb_trials), desc = "Trials") if show_progress_bar else range(nb_trials)

        # Simulate all the trials
        for _ in iterator:
            log = self.simulate_one_trial(watermaze)
            logs.append(log)

        return logs
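# Illustrative usage sketch, not part of the original module. A watermaze object is
# normally provided elsewhere in the project; here a minimal stand-in exposing only the
# two attributes used above (plateform.center and plateform.radius) is built with
# types.SimpleNamespace so the sketch stays self-contained. The platform position and
# radius chosen below are arbitrary assumptions made for this example.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical platform: radius 0.05, placed halfway between the pool centre and the wall
    plateform = SimpleNamespace(
        center = np.array([X_ORIGIN + (WATERMAZE_RADIUS * 0.5), Y_ORIGIN]),
        radius = 0.05
    )
    watermaze = SimpleNamespace(plateform = plateform)

    rat = Rat()
    logs = rat.simulate_n_trials(watermaze, nb_trials = 5)

    # The escape latency (number of steps per trial) should tend to decrease as the rat learns
    print("Steps per trial:", [len(log["reward"]) for log in logs])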