-
Notifications
You must be signed in to change notification settings - Fork 49
/
utils.py
35 lines (27 loc) · 1.45 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
class ReplayBuffer:
    """List-backed store of 7-field transitions with uniform random sampling.

    Each transition is a length-7 sequence laid out as
    (state, action, reward, next_state, goal, gamma, done).

    The buffer is trimmed lazily: ``add`` never evicts anything; the oldest
    fifth of the entries is dropped at the start of ``sample`` once the stored
    count exceeds ``max_size``.
    """

    # Number of fields every stored transition must have.
    _NUM_FIELDS = 7

    def __init__(self, max_size=5e5):
        """Create an empty buffer.

        Args:
            max_size: Entry count above which ``sample`` trims the buffer.
                Accepts a float such as ``5e5``; it is truncated to ``int``.
        """
        self.buffer = []
        self.max_size = int(max_size)
        self.size = 0

    def add(self, transition):
        """Append one transition to the end of the buffer.

        Args:
            transition: Length-7 sequence of
                (state, action, reward, next_state, goal, gamma, done).

        Raises:
            AssertionError: If ``transition`` does not have exactly 7 items.
        """
        assert len(transition) == 7, "transition must have length = 7"
        # transition is tuple of (state, action, reward, next_state, goal, gamma, done)
        self.buffer.append(transition)
        self.size += 1

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            A 7-tuple of NumPy arrays
            ``(states, actions, rewards, next_states, goals, gamma, dones)``,
            each stacking the corresponding field of the drawn transitions
            along a new leading axis.

        Raises:
            ValueError: Propagated from ``np.random.randint`` when the buffer
                is empty and ``batch_size`` is non-zero.
        """
        # delete 1/5th of the buffer when full
        if self.size > self.max_size:
            del self.buffer[0:self.size // 5]
            self.size = len(self.buffer)

        indexes = np.random.randint(0, len(self.buffer), size=batch_size)
        batch = [self.buffer[i] for i in indexes]

        # np.asarray copies only when it has to. The previous
        # np.array(..., copy=False) raises ValueError on NumPy >= 2.0 whenever
        # a copy is unavoidable (Python scalars, lists), which broke sampling.
        # Building per field (rather than zip(*batch)) keeps seven outputs
        # even when the batch is empty.
        return tuple(
            np.array([np.asarray(transition[field]) for transition in batch])
            for field in range(self._NUM_FIELDS)
        )