class ReplayBufferCentralState(object):
-    def __init__(self, size):
+    def __init__(self, size, ob_space, st_space, n_agents):
        """Create Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        """
-        self._storage = []
+        # Pre-allocate one array per transition field instead of a Python list of tuples.
+        self._obses = np.zeros((size,) + (n_agents,) + ob_space.shape, dtype=ob_space.dtype)
+        self._next_obses = np.zeros((size,) + (n_agents,) + ob_space.shape, dtype=ob_space.dtype)
+        self._rewards = np.zeros(size)
+        self._actions = np.zeros((size,) + (n_agents,), dtype=np.int32)
+        self._dones = np.zeros(size, dtype=bool)
+        self._states = np.zeros((size,) + st_space.shape, dtype=st_space.dtype)
+
        self._maxsize = size
        self._next_idx = 0
+        self._curr_size = 0

    def __len__(self):
-        return len(self._storage)
+        return self._curr_size

    def add(self, obs_t, action, state_t, reward, obs_tp1, done):
-        data = (obs_t, action, state_t, reward, obs_tp1, done)
+        self._curr_size = min(self._curr_size + 1, self._maxsize)
+
+        self._obses[self._next_idx] = obs_t
+        self._next_obses[self._next_idx] = obs_tp1
+        self._rewards[self._next_idx] = reward
+        self._actions[self._next_idx] = action
+        self._dones[self._next_idx] = done
+        self._states[self._next_idx] = state_t

-        if self._next_idx >= len(self._storage):
-            self._storage.append(data)
-        else:
-            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

+    def _get(self, idx):
+        return self._obses[idx], self._actions[idx], self._states[idx], self._rewards[idx], self._next_obses[idx], self._dones[idx]

    def _encode_sample(self, idxes):
-        obses_t, actions, states_t, rewards, obses_tp1, dones = [], [], [], [], [], []
+        batch_size = len(idxes)
+        obses_t, actions, states_t, rewards, obses_tp1, dones = [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size
+        it = 0
        for i in idxes:
-            data = self._storage[i]
-            obs_t, action, state_t, reward, obs_tp1, done = data
-            obses_t.append(np.array(obs_t, copy=False))
-            actions.append(np.array(action, copy=False))
-            states_t.append(np.array(state_t, copy=False))
-            rewards.append(reward)
-            obses_tp1.append(np.array(obs_tp1, copy=False))
-            dones.append(done)
+            data = self._get(i)
+            obs_t, action, state, reward, obs_tp1, done = data
+            obses_t[it] = np.array(obs_t, copy=False)
+            actions[it] = np.array(action, copy=False)
+            states_t[it] = np.array(state, copy=False)
+            rewards[it] = reward
+            obses_tp1[it] = np.array(obs_tp1, copy=False)
+            dones[it] = done
+            it += 1
        return np.array(obses_t), np.array(actions), np.array(states_t), np.array(rewards), np.array(obses_tp1), np.array(dones)

    def sample(self, batch_size):
@@ -62,44 +81,61 @@ def sample(self, batch_size):
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        """
-        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
+        idxes = [random.randint(0, self._curr_size - 1) for _ in range(batch_size)]
        return self._encode_sample(idxes)

+
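For context, a minimal usage sketch of the new constructor signature (an editor's illustration, not part of this commit). It assumes gym-style spaces, since the constructor only reads .shape and .dtype from ob_space and st_space; the Space stand-in, the shapes, the agent count, and the sample values are all illustrative.

import numpy as np
from collections import namedtuple

# Hypothetical stand-in for a gym-style space; the buffer only reads .shape and .dtype.
Space = namedtuple("Space", ["shape", "dtype"])

ob_space = Space(shape=(30,), dtype=np.float32)    # per-agent observation space (illustrative)
st_space = Space(shape=(120,), dtype=np.float32)   # global central-state space (illustrative)

buf = ReplayBufferCentralState(size=50000, ob_space=ob_space, st_space=st_space, n_agents=3)

obs_t = np.zeros((3, 30), dtype=np.float32)        # stacked per-agent observations
state_t = np.zeros(120, dtype=np.float32)          # shared central state
buf.add(obs_t, action=[1, 0, 2], state_t=state_t, reward=0.5, obs_tp1=obs_t, done=False)

if len(buf) >= 32:                                 # sample only once enough transitions are stored
    obs_b, act_b, st_b, rew_b, next_obs_b, done_b = buf.sample(32)
    # shapes: (32, 3, 30), (32, 3), (32, 120), (32,), (32, 3, 30), (32,)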
class ReplayBuffer(object):
-    def __init__(self, size):
+    def __init__(self, size, ob_space, n_agents):
        """Create Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        """
-        self._storage = []
+        self._obses = np.zeros((size,) + (n_agents,) + ob_space.shape, dtype=ob_space.dtype)
+        self._next_obses = np.zeros((size,) + (n_agents,) + ob_space.shape, dtype=ob_space.dtype)
+        self._rewards = np.zeros(size)
+        self._actions = np.zeros((size,) + (n_agents,), dtype=np.int32)
+        self._dones = np.zeros(size, dtype=bool)
+
        self._maxsize = size
        self._next_idx = 0
+        self._curr_size = 0

    def __len__(self):
-        return len(self._storage)
+        return self._curr_size

    def add(self, obs_t, action, reward, obs_tp1, done):
-        data = (obs_t, action, reward, obs_tp1, done)

-        if self._next_idx >= len(self._storage):
-            self._storage.append(data)
-        else:
-            self._storage[self._next_idx] = data
+        self._curr_size = min(self._curr_size + 1, self._maxsize)
+
+        self._obses[self._next_idx] = obs_t
+        self._next_obses[self._next_idx] = obs_tp1
+        self._rewards[self._next_idx] = reward
+        self._actions[self._next_idx] = action
+        self._dones[self._next_idx] = done
+
        self._next_idx = (self._next_idx + 1) % self._maxsize

+    def _get(self, idx):
+        return self._obses[idx], self._actions[idx], self._rewards[idx], self._next_obses[idx], self._dones[idx]
+
    def _encode_sample(self, idxes):
-        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
+        batch_size = len(idxes)
+        obses_t, actions, rewards, obses_tp1, dones = [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size
+        it = 0
        for i in idxes:
-            data = self._storage[i]
+            data = self._get(i)
            obs_t, action, reward, obs_tp1, done = data
-            obses_t.append(np.array(obs_t, copy=False))
-            actions.append(np.array(action, copy=False))
-            rewards.append(reward)
-            obses_tp1.append(np.array(obs_tp1, copy=False))
-            dones.append(done)
+            obses_t[it] = np.array(obs_t, copy=False)
+            actions[it] = np.array(action, copy=False)
+            rewards[it] = reward
+            obses_tp1[it] = np.array(obs_tp1, copy=False)
+            dones[it] = done
+            it += 1
        return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)

    def sample(self, batch_size):
@@ -122,12 +158,12 @@ def sample(self, batch_size):
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        """
-        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
+        idxes = [random.randint(0, self._curr_size - 1) for _ in range(batch_size)]
        return self._encode_sample(idxes)


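A short sketch of the circular-overwrite behaviour that the new _next_idx and _curr_size bookkeeping preserves (an editor's illustration, not part of this commit; Space is again a hypothetical gym-style stand-in): once more than size transitions have been added, len() saturates at the capacity and the oldest slots are overwritten in place.

import numpy as np
from collections import namedtuple

Space = namedtuple("Space", ["shape", "dtype"])            # hypothetical gym-style space stand-in
buf = ReplayBuffer(size=4, ob_space=Space(shape=(2,), dtype=np.float32), n_agents=1)

for t in range(6):                                         # six adds into a buffer of capacity four
    o = np.full((1, 2), t, dtype=np.float32)
    buf.add(o, action=[0], reward=float(t), obs_tp1=o, done=False)

print(len(buf))          # 4: _curr_size saturates at _maxsize
print(buf._rewards)      # [4. 5. 2. 3.]: rewards 4 and 5 overwrote the oldest slots 0 and 1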
class PrioritizedReplayBuffer(ReplayBuffer):
-    def __init__(self, size, alpha):
+    def __init__(self, size, alpha, ob_space, n_agents):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
@@ -141,7 +177,7 @@ def __init__(self, size, alpha):
        --------
        ReplayBuffer.__init__
        """
-        super(PrioritizedReplayBuffer, self).__init__(size)
+        super(PrioritizedReplayBuffer, self).__init__(size, ob_space, n_agents)
        assert alpha >= 0
        self._alpha = alpha
@@ -162,7 +198,7 @@ def add(self, *args, **kwargs):
    def _sample_proportional(self, batch_size):
        res = []
-        p_total = self._it_sum.sum(0, len(self._storage) - 1)
+        p_total = self._it_sum.sum(0, self._curr_size - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
@@ -208,11 +244,11 @@ def sample(self, batch_size, beta):
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
-        max_weight = (p_min * len(self._storage)) ** (-beta)
+        max_weight = (p_min * self._curr_size) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
-            weight = (p_sample * len(self._storage)) ** (-beta)
+            weight = (p_sample * self._curr_size) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
@@ -234,8 +270,8 @@ def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
-            assert 0 <= idx < len(self._storage)
+            assert 0 <= idx < self._curr_size
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

-            self._max_priority = max(self._max_priority, priority)
+            self._max_priority = max(self._max_priority, priority)
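Finally, a sketch of how the prioritized buffer is typically driven from a training loop (an editor's illustration, not part of this commit). It assumes sample() keeps the usual baselines convention of returning the encoded batch followed by the importance weights and the sampled indices, and compute_td_errors stands for whatever TD-error computation the training code provides.

import numpy as np
from collections import namedtuple

Space = namedtuple("Space", ["shape", "dtype"])            # hypothetical gym-style space stand-in
pbuf = PrioritizedReplayBuffer(size=50000, alpha=0.6,
                               ob_space=Space(shape=(30,), dtype=np.float32), n_agents=3)

# ... environment interaction fills pbuf via pbuf.add(...) ...

def train_step(beta):
    # beta is usually annealed from around 0.4 toward 1.0 over the course of training.
    obs, act, rew, next_obs, done, weights, idxes = pbuf.sample(32, beta)
    td_errors = compute_td_errors(obs, act, rew, next_obs, done, weights)  # hypothetical helper
    new_priorities = np.abs(td_errors) + 1e-6              # keep priorities strictly positive
    pbuf.update_priorities(idxes, new_priorities)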