-
Notifications
You must be signed in to change notification settings - Fork 10
/
utils.py
194 lines (166 loc) · 7.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import tensorflow as tf
import numpy as np
import random
# Borrowed from Berkeley's CS 294 Hw 3
# https://github.com/berkeleydeeprlcourse/homework/tree/master/
def sample_n_unique(sampling_f, n):
    """Call `sampling_f` repeatedly until `n` distinct values have been drawn.

    Distinctness is decided with `==` (list membership), so the values only
    need to be equality-comparable, not hashable.

    Parameters
    ----------
    sampling_f: callable
        Zero-argument function producing one candidate per call.
    n: int
        How many distinct values to collect.

    Returns
    -------
    list
        The distinct values, in the order they were first drawn. The loop
        never terminates if `sampling_f` cannot produce `n` distinct values.
    """
    collected = []
    while len(collected) < n:
        value = sampling_f()
        if value in collected:
            # Duplicate draw: discard it and try again.
            continue
        collected.append(value)
    return collected
class ReplayBuffer(object):
    def __init__(self, size, frame_history_len):
        """This is a memory efficient implementation of the replay buffer.

        The specific memory optimizations used here are:
            - only store each frame once rather than k times,
              even if every observation normally consists of the k last frames
            - store frames as np.uint8 (it is most time-efficient to cast
              them back to float32 on the GPU, which minimizes memory
              transfer time)
            - store frame_t and frame_(t+1) in the same buffer.

        For the typical use case in Atari Deep RL, a buffer with 1M frames
        has a total memory footprint of 10^6 * 84 * 84 bytes ~= 7 gigabytes.

        Warning! Assumes that returning frames of zeros at the beginning
        of an episode, when there are fewer frames than `frame_history_len`,
        is acceptable.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        frame_history_len: int
            Number of consecutive frames stacked into each observation.
        """
        self.size = size
        self.frame_history_len = frame_history_len

        self.next_idx = 0       # slot the next store_frame call writes to
        self.num_in_buffer = 0  # number of filled slots (never exceeds size)

        # Storage is allocated lazily by the first store_frame call, once the
        # frame shape is known.
        self.obs = None
        self.action = None
        self.reward = None
        self.done = None

    def can_sample(self, batch_size):
        """Returns true if `batch_size` different transitions can be sampled from the buffer."""
        # The newest frame has no successor stored yet, so only
        # num_in_buffer - 1 complete transitions are available.
        return batch_size + 1 <= self.num_in_buffer

    def _oldest_idx(self):
        """Return the buffer slot holding the oldest frame still stored.

        Before the buffer wraps around this is 0; afterwards it is `next_idx`,
        the slot that will be overwritten next.
        """
        return (self.next_idx - self.num_in_buffer) % self.size

    def _encode_sample(self, idxes):
        """Assemble the transition batch for the given buffer slots.

        Each slot in `idxes` must have a stored successor frame; the successor
        of the last slot wraps around to slot 0.
        """
        next_idxes = [(idx + 1) % self.size for idx in idxes]
        obs_batch = np.concatenate([self._encode_observation(idx)[None] for idx in idxes], 0)
        act_batch = self.action[idxes]
        rew_batch = self.reward[idxes]
        next_obs_batch = np.concatenate(
            [self._encode_observation(idx)[None] for idx in next_idxes], 0)
        done_mask = self.done[idxes].astype(np.float32)
        return obs_batch, act_batch, rew_batch, next_obs_batch, done_mask

    def sample(self, batch_size):
        """Sample `batch_size` different transitions.

        The i-th sampled transition is the following:
        when observing `obs_batch[i]`, action `act_batch[i]` was taken,
        after which reward `rew_batch[i]` was received and the subsequent
        observation next_obs_batch[i] was observed, unless the episode
        was done, which is represented by `done_mask[i]`, equal to 1 if the
        episode ended as a result of that action.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.

        Returns
        -------
        obs_batch: np.array
            Array of shape
            (batch_size, img_h, img_w, img_c * frame_history_len)
            and dtype np.uint8 (or (batch_size, obs_dim) when 1-D
            observations are stored)
        act_batch: np.array
            Array of shape (batch_size,) and dtype np.int32
        rew_batch: np.array
            Array of shape (batch_size,) and dtype np.float32
        next_obs_batch: np.array
            Same shape and dtype as obs_batch
        done_mask: np.array
            Array of shape (batch_size,) and dtype np.float32
        """
        assert self.can_sample(batch_size)
        # Count positions from the oldest stored frame. Every position except
        # the newest one (num_in_buffer - 1) has its successor frame stored.
        # Excluding the newest position matters once the buffer has wrapped,
        # because the slot after it then holds the oldest, unrelated frame.
        oldest = self._oldest_idx()
        offsets = random.sample(range(self.num_in_buffer - 1), batch_size)
        idxes = [(oldest + offset) % self.size for offset in offsets]
        return self._encode_sample(idxes)

    def encode_recent_observation(self):
        """Return the most recent `frame_history_len` frames.

        Returns
        -------
        observation: np.array
            Array of shape (img_h, img_w, img_c * frame_history_len)
            and dtype np.uint8, where observation[:, :, i*img_c:(i+1)*img_c]
            encodes the frame at time `t - frame_history_len + i`
        """
        assert self.num_in_buffer > 0
        return self._encode_observation((self.next_idx - 1) % self.size)

    def _encode_observation(self, idx):
        """Stack the frames ending at slot `idx` (0 <= idx < size) into one observation.

        Frames before the oldest stored frame, or before an episode boundary,
        are replaced with zeros.
        """
        end_idx = idx + 1  # make noninclusive
        start_idx = end_idx - self.frame_history_len
        # Low-dimensional observations (e.g. RAM state) are not stacked:
        # directly return the latest entry.
        if len(self.obs.shape) == 2:
            return self.obs[end_idx - 1]
        # Never reach back past the oldest stored frame. Before the buffer
        # wraps this is slot 0. Afterwards the slots just before `next_idx`
        # hold the *newest* frames, which must not be stacked in front of the
        # oldest ones.
        available = (idx - self._oldest_idx()) % self.size + 1
        start_idx = max(start_idx, end_idx - available)
        # Do not stack frames across an episode boundary.
        for i in range(start_idx, end_idx - 1):
            if self.done[i % self.size]:
                start_idx = i + 1
        missing_context = self.frame_history_len - (end_idx - start_idx)
        # Zero-pad when context is missing, or index modulo size when the
        # window wraps around the physical start of the buffer.
        if start_idx < 0 or missing_context > 0:
            frames = [np.zeros_like(self.obs[0]) for _ in range(missing_context)]
            frames.extend(self.obs[i % self.size] for i in range(start_idx, end_idx))
            return np.concatenate(frames, 2)
        # Contiguous window: a single slice + reshape, avoiding per-frame copies.
        img_h, img_w = self.obs.shape[1], self.obs.shape[2]
        return self.obs[start_idx:end_idx].transpose(1, 2, 0, 3).reshape(img_h, img_w, -1)

    def store_frame(self, frame):
        """Store a single frame in the buffer at the next available index,
        overwriting old frames if necessary.

        Parameters
        ----------
        frame: np.array
            Array of shape (img_h, img_w, img_c) and dtype np.uint8,
            the frame to be stored

        Returns
        -------
        idx: int
            Index at which the frame is stored. To be used for `store_effect` later.
        """
        if self.obs is None:
            self.obs = np.empty([self.size] + list(frame.shape), dtype=np.uint8)
            self.action = np.empty([self.size], dtype=np.int32)
            self.reward = np.empty([self.size], dtype=np.float32)
            # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
            self.done = np.empty([self.size], dtype=bool)
        self.obs[self.next_idx] = frame

        ret = self.next_idx
        self.next_idx = (self.next_idx + 1) % self.size
        self.num_in_buffer = min(self.size, self.num_in_buffer + 1)
        return ret

    def store_effect(self, idx, action, reward, done):
        """Store the effects of the action taken after observing the frame
        stored at index idx. `store_frame` and `store_effect` are separate
        functions so that one can call `encode_recent_observation` in between.

        Parameters
        ----------
        idx: int
            Index in buffer of recently observed frame (returned by `store_frame`).
        action: int
            Action that was performed upon observing this frame.
        reward: float
            Reward that was received when the action was performed.
        done: bool
            True if the episode was finished after performing that action.
        """
        self.action[idx] = action
        self.reward[idx] = reward
        self.done[idx] = done