
Commit 438a4a5

Author: Christian Schroeder de Witt (committed)
Merge branch 'master' of github.com:schroederdewitt/rl_games into master
2 parents bf11583 + bfaee18, commit 438a4a5

28 files changed: +1753 -456 lines

algos_tf14/a2c_discrete.py (+139 -153)

Large diffs are not rendered by default.

algos_tf14/dqnagent.py (+6 -4)

@@ -11,7 +11,7 @@
 from common.categorical import CategoricalQ

 class DQNAgent:
-    def __init__(self, sess, base_name, observation_space, action_space, config, logger):
+    def __init__(self, sess, base_name, observation_space, action_space, config, logger, central_state_space=None):
         observation_shape = observation_space.shape
         actions_num = action_space.n
         self.config = config
@@ -60,12 +60,14 @@ def __init__(self, sess, base_name, observation_space, action_space, config, log
         self.v_max = self.config['v_max']
         self.delta_z = (self.v_max - self.v_min) / (self.atoms_num - 1)
         self.all_z = tf.range(self.v_min, self.v_max + self.delta_z, self.delta_z)
-        self.categorical = CategoricalQ(self.atoms_num, self.v_min, self.v_max)
+        self.categorical = CategoricalQ(self.atoms_num, self.v_min, self.v_max)
+
+        self.n_agents = self.env.env_info['n_agents']

         if not self.is_prioritized:
-            self.exp_buffer = experience.ReplayBuffer(config['replay_buffer_size'])
+            self.exp_buffer = experience.ReplayBuffer(config['replay_buffer_size'], observation_space, self.n_agents)
         else:
-            self.exp_buffer = experience.PrioritizedReplayBuffer(config['replay_buffer_size'], config['priority_alpha'])
+            self.exp_buffer = experience.PrioritizedReplayBuffer(config['replay_buffer_size'], config['priority_alpha'], observation_space, self.n_agents)
         self.sample_weights_ph = tf.placeholder(tf.float32, shape= [None,] , name='sample_weights')

         self.obs_ph = tf.placeholder(observation_space.dtype, shape=(None,) + self.state_shape , name = 'obs_ph')
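For orientation, a minimal sketch of how the extended constructor could be wired up by a runner. The build_dqn_agent helper, the base_name value, and the env_info key names are assumptions for illustration, not part of this commit; only the DQNAgent signature above is taken from the diff.

    from algos_tf14.dqnagent import DQNAgent

    def build_dqn_agent(sess, env_info, config, logger):
        # central_state_space defaults to None, so single-agent setups can omit it.
        return DQNAgent(
            sess,
            base_name='dqn_run',                                  # illustrative run name
            observation_space=env_info['observation_space'],
            action_space=env_info['action_space'],
            config=config,
            logger=logger,
            central_state_space=env_info.get('state_space'),     # assumed key name
        )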

algos_tf14/iqlagent.py (+415)

Large diffs are not rendered by default.

algos_tf14/model_builder.py (+1)

@@ -14,6 +14,7 @@ def __init__(self):
         self.model_factory.register_builder('continuous_a2c_lstm_logstd', lambda network, **kwargs : models.LSTMModelA2CContinuousLogStd(network))
         self.model_factory.register_builder('dqn', lambda network, **kwargs : models.AtariDQN(network))
         self.model_factory.register_builder('vdn', lambda network, **kwargs : models.VDN_DQN(network))
+        self.model_factory.register_builder('iql', lambda network, **kwargs : models.IQL_DQN(network))


         self.network_factory = object_factory.ObjectFactory()
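The new 'iql' key follows the same factory pattern as the existing 'dqn' and 'vdn' entries. Below is a small, self-contained sketch of that register_builder/create pattern; the real factory lives in common.object_factory, so treat this as an assumption-level illustration rather than the library code.

    class ObjectFactorySketch:
        def __init__(self):
            self._builders = {}

        def register_builder(self, name, builder):
            # builder is a callable, e.g. lambda network, **kwargs: models.IQL_DQN(network)
            self._builders[name] = builder

        def create(self, name, **kwargs):
            builder = self._builders.get(name)
            if builder is None:
                raise ValueError('Unknown builder: {}'.format(name))
            return builder(**kwargs)

    # Usage mirroring the diff: a config string selects which model wraps the network.
    factory = ObjectFactorySketch()
    factory.register_builder('iql', lambda network, **kwargs: ('IQL_DQN', network))
    model = factory.create('iql', network='q_network_stub')   # stub stands in for a built network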

algos_tf14/models.py (+120 -49)

Large diffs are not rendered by default.

algos_tf14/network_builder.py (+174 -122)

Large diffs are not rendered by default.

algos_tf14/vdnagent.py (+1 -1)

@@ -77,7 +77,7 @@ def __init__(self, sess, base_name, observation_space, action_space, config, log
         self.n_agents = self.env.env_info['n_agents']

         if not self.is_prioritized:
-            self.exp_buffer = experience.ReplayBufferCentralState(config['replay_buffer_size'])
+            self.exp_buffer = experience.ReplayBufferCentralState(config['replay_buffer_size'], observation_space, central_state_space, self.n_agents)
         else:
             raise NotImplementedError("Not implemented! PrioritizedReplayBuffer with CentralState")
             #self.exp_buffer = experience.PrioritizedReplayBufferCentralState(config['replay_buffer_size'], config['priority_alpha'])
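A brief usage sketch of the extended ReplayBufferCentralState signature, assuming Gym-style Box spaces for both the per-agent observations and the shared central state; the concrete shapes are made up for illustration.

    import numpy as np
    from gym import spaces
    from common.experience import ReplayBufferCentralState

    n_agents = 3
    ob_space = spaces.Box(low=0.0, high=1.0, shape=(16,), dtype=np.float32)
    central_state_space = spaces.Box(low=0.0, high=1.0, shape=(32,), dtype=np.float32)

    buf = ReplayBufferCentralState(5000, ob_space, central_state_space, n_agents)
    # Transitions are added per step and sampled in batches:
    # buf.add(obs_t, action, state_t, reward, obs_tp1, done)
    # obses, actions, states, rewards, next_obses, dones = buf.sample(batch_size=32)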

common/env_configurations.py (+17)

@@ -192,6 +192,12 @@ def create_flex(path):

     return env

+def create_staghunt(name, **kwargs):
+    from envs.stag_hunt import StagHuntEnv
+    frames = kwargs.pop('frames', 1)
+    print(kwargs)
+    return wrappers.BatchedFrameStack(StagHuntEnv(1, **kwargs), frames, transpose=False, flatten=True)
+
 def create_smac(name, **kwargs):
     from envs.smac_env import SMACEnv
     frames = kwargs.pop('frames', 1)
@@ -212,6 +218,13 @@ def create_smac_cnn(name, **kwargs):
     env = wrappers.BatchedFrameStack(env, frames, transpose=transpose)
     return env

+def create_staghunt_cnn(name, **kwargs):
+    from envs.stag_hunt import StagHuntEnv
+    env = StagHuntEnv(1, **kwargs)
+    frames = kwargs.pop('frames', 4)
+    transpose = kwargs.pop('transpose', False)
+    env = wrappers.BatchedFrameStack(env, frames, transpose=transpose)
+    return env

 configurations = {
     'CartPole-v1' : {
@@ -314,6 +327,10 @@ def create_smac_cnn(name, **kwargs):
         'env_creator' : lambda **kwargs : create_flex(FLEX_PATH + '/demo/gym/cfg/humanoid_hard.yaml'),
         'vecenv_type' : 'ISAAC'
     },
+    'staghunt': {
+        'env_creator': lambda **kwargs: create_staghunt_cnn(**kwargs),
+        'vecenv_type': 'RAY_SMAC'
+    },
     'smac' : {
         'env_creator' : lambda **kwargs : create_smac(**kwargs),
         'vecenv_type' : 'RAY_SMAC'
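A rough sketch of how a registered entry such as 'staghunt' might be consumed downstream; the create_env helper name is illustrative and not part of this commit, but the dictionary lookup mirrors how env_creator and vecenv_type are stored above.

    from common.env_configurations import configurations

    def create_env(name, **kwargs):
        cfg = configurations[name]
        env = cfg['env_creator'](**kwargs)    # e.g. create_staghunt_cnn(**kwargs)
        return env, cfg['vecenv_type']        # e.g. 'RAY_SMAC'

    # env, vec_type = create_env('staghunt', frames=4, transpose=False)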

common/experience.py (+77 -41)

@@ -5,41 +5,60 @@


 class ReplayBufferCentralState(object):
-    def __init__(self, size):
+    def __init__(self, size, ob_space, st_space, n_agents):
         """Create Replay buffer.
         Parameters
         ----------
         size: int
             Max number of transitions to store in the buffer. When the buffer
             overflows the old memories are dropped.
         """
-        self._storage = []
+
+        self._obses = np.zeros((size,) + (n_agents,) + ob_space.shape, dtype=ob_space.dtype)
+        self._next_obses = np.zeros((size,) + (n_agents,) + ob_space.shape, dtype=ob_space.dtype)
+        self._rewards = np.zeros(size)
+        self._actions = np.zeros((size,) + (n_agents,), dtype=np.int32)
+        self._dones = np.zeros(size, dtype=np.bool)
+        self._states = np.zeros((size,) + st_space.shape, dtype=st_space.dtype)
+
         self._maxsize = size
         self._next_idx = 0
+        self._curr_size = 0

     def __len__(self):
-        return len(self._storage)
+        return self._curr_size

     def add(self, obs_t, action, state_t, reward, obs_tp1, done):
-        data = (obs_t, action, state_t, reward, obs_tp1, done)
+        # print("CAlled")
+        self._curr_size = min(self._curr_size + 1, self._maxsize)
+
+        self._obses[self._next_idx] = obs_t
+        self._next_obses[self._next_idx] = obs_tp1
+        self._rewards[self._next_idx] = reward
+        self._actions[self._next_idx] = action
+        self._dones[self._next_idx] = done
+        self._states[self._next_idx] = state_t

-        if self._next_idx >= len(self._storage):
-            self._storage.append(data)
-        else:
-            self._storage[self._next_idx] = data
         self._next_idx = (self._next_idx + 1) % self._maxsize
+        # print(self._curr_size)
+
+    def _get(self, idx):
+        return self._obses[idx], self._actions[idx], self._states[idx], self._rewards[idx], self._next_obses[idx], self._dones[idx]

     def _encode_sample(self, idxes):
-        obses_t, actions, states_t, rewards, obses_tp1, dones = [], [], [], [], [], []
+        batch_size = len(idxes)
+        obses_t, actions, states_t, rewards, obses_tp1, dones = [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size
+        it = 0
         for i in idxes:
-            data = self._storage[i]
-            obs_t, action, state_t, reward, obs_tp1, done = data
-            obses_t.append(np.array(obs_t, copy=False))
-            actions.append(np.array(action, copy=False))
-            states_t.append(np.array(state_t, copy=False))
-            rewards.append(reward)
-            obses_tp1.append(np.array(obs_tp1, copy=False))
-            dones.append(done)
+            data = self._get(i)
+            obs_t, action, state, reward, obs_tp1, done = data
+            obses_t[it] = np.array(obs_t, copy=False)
+            actions[it] = np.array(action, copy=False)
+            states_t[it] = np.array(state, copy=False)
+            rewards[it] = reward
+            obses_tp1[it] = np.array(obs_tp1, copy=False)
+            dones[it] = done
+            it = it + 1
         return np.array(obses_t), np.array(actions), np.array(states_t), np.array(rewards), np.array(obses_tp1), np.array(dones)

     def sample(self, batch_size):
@@ -62,44 +81,61 @@ def sample(self, batch_size):
             done_mask[i] = 1 if executing act_batch[i] resulted in
             the end of an episode and 0 otherwise.
         """
-        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
+        # print(self._curr_size)
+        idxes = [random.randint(0, self._curr_size - 1) for _ in range(batch_size)]
         return self._encode_sample(idxes)

+
 class ReplayBuffer(object):
-    def __init__(self, size):
+    def __init__(self, size, ob_space, n_agents):
         """Create Replay buffer.
         Parameters
         ----------
         size: int
             Max number of transitions to store in the buffer. When the buffer
             overflows the old memories are dropped.
         """
-        self._storage = []
+        self._obses = np.zeros((size,) + (n_agents,) + ob_space.shape, dtype=ob_space.dtype)
+        self._next_obses = np.zeros((size,) + (n_agents,) + ob_space.shape, dtype=ob_space.dtype)
+        self._rewards = np.zeros(size)
+        self._actions = np.zeros((size,) + (n_agents,), dtype=np.int32)
+        self._dones = np.zeros(size, dtype=np.bool)
+
         self._maxsize = size
         self._next_idx = 0
+        self._curr_size = 0

     def __len__(self):
-        return len(self._storage)
+        return self._curr_size

     def add(self, obs_t, action, reward, obs_tp1, done):
-        data = (obs_t, action, reward, obs_tp1, done)

-        if self._next_idx >= len(self._storage):
-            self._storage.append(data)
-        else:
-            self._storage[self._next_idx] = data
+        self._curr_size = min(self._curr_size + 1, self._maxsize )
+
+        self._obses[self._next_idx] = obs_t
+        self._next_obses[self._next_idx] = obs_tp1
+        self._rewards[self._next_idx] = reward
+        self._actions[self._next_idx] = action
+        self._dones[self._next_idx] = done
+
         self._next_idx = (self._next_idx + 1) % self._maxsize

+    def _get(self, idx):
+        return self._obses[idx], self._actions[idx], self._rewards[idx], self._next_obses[idx], self._dones[idx]
+
     def _encode_sample(self, idxes):
-        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
+        batch_size = len(idxes)
+        obses_t, actions, rewards, obses_tp1, dones = [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size, [None] * batch_size
+        it = 0
         for i in idxes:
-            data = self._storage[i]
+            data = self._get(i)
             obs_t, action, reward, obs_tp1, done = data
-            obses_t.append(np.array(obs_t, copy=False))
-            actions.append(np.array(action, copy=False))
-            rewards.append(reward)
-            obses_tp1.append(np.array(obs_tp1, copy=False))
-            dones.append(done)
+            obses_t[it] = np.array(obs_t, copy=False)
+            actions[it] = np.array(action, copy=False)
+            rewards[it] = reward
+            obses_tp1[it] = np.array(obs_tp1, copy=False)
+            dones[it] = done
+            it = it + 1
         return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)

     def sample(self, batch_size):
@@ -122,12 +158,12 @@ def sample(self, batch_size):
             done_mask[i] = 1 if executing act_batch[i] resulted in
             the end of an episode and 0 otherwise.
         """
-        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
+        idxes = [random.randint(0, self._curr_size - 1) for _ in range(batch_size)]
         return self._encode_sample(idxes)


 class PrioritizedReplayBuffer(ReplayBuffer):
-    def __init__(self, size, alpha):
+    def __init__(self, size, alpha, ob_space, n_agents):
         """Create Prioritized Replay buffer.
         Parameters
         ----------
@@ -141,7 +177,7 @@ def __init__(self, size, alpha):
         --------
         ReplayBuffer.__init__
         """
-        super(PrioritizedReplayBuffer, self).__init__(size)
+        super(PrioritizedReplayBuffer, self).__init__(size, ob_space, n_agents)
         assert alpha >= 0
         self._alpha = alpha

@@ -162,7 +198,7 @@ def add(self, *args, **kwargs):

     def _sample_proportional(self, batch_size):
         res = []
-        p_total = self._it_sum.sum(0, len(self._storage) - 1)
+        p_total = self._it_sum.sum(0, self._curr_size - 1)
         every_range_len = p_total / batch_size
         for i in range(batch_size):
             mass = random.random() * every_range_len + i * every_range_len
@@ -208,11 +244,11 @@ def sample(self, batch_size, beta):

         weights = []
         p_min = self._it_min.min() / self._it_sum.sum()
-        max_weight = (p_min * len(self._storage)) ** (-beta)
+        max_weight = (p_min * self._curr_size) ** (-beta)

         for idx in idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
-            weight = (p_sample * len(self._storage)) ** (-beta)
+            weight = (p_sample * self._curr_size) ** (-beta)
             weights.append(weight / max_weight)
         weights = np.array(weights)
         encoded_sample = self._encode_sample(idxes)
@@ -234,8 +270,8 @@ def update_priorities(self, idxes, priorities):
         assert len(idxes) == len(priorities)
         for idx, priority in zip(idxes, priorities):
             assert priority > 0
-            assert 0 <= idx < len(self._storage)
+            assert 0 <= idx < self._curr_size
             self._it_sum[idx] = priority ** self._alpha
             self._it_min[idx] = priority ** self._alpha

-        self._max_priority = max(self._max_priority, priority)
+        self._max_priority = max(self._max_priority, priority)
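To make the new array-backed layout concrete, here is a small usage sketch of ReplayBuffer, assuming a Gym-style Box observation space; the shapes follow the (size, n_agents, *obs_shape) storage introduced in this diff, and the specific numbers are illustrative only.

    import numpy as np
    from gym import spaces
    from common.experience import ReplayBuffer

    n_agents = 2
    ob_space = spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
    buf = ReplayBuffer(1000, ob_space, n_agents)

    obs = np.zeros((n_agents, 8), dtype=np.float32)
    next_obs = np.ones((n_agents, 8), dtype=np.float32)
    actions = np.array([0, 3], dtype=np.int32)    # one discrete action per agent
    buf.add(obs, actions, reward=1.0, obs_tp1=next_obs, done=False)

    obses_t, acts, rews, obses_tp1, dones = buf.sample(batch_size=4)
    # obses_t.shape == (4, n_agents, 8), acts.shape == (4, n_agents)

Compared with the previous list-of-tuples storage, preallocating NumPy arrays keeps memory fixed at the buffer size and avoids per-transition Python object overhead; the _curr_size counter replaces len(self._storage) wherever an index bound is needed.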

common/vecenv.py (+16 -2)

@@ -34,7 +34,12 @@ def reset(self):
 class RayWorker:
     def __init__(self, config_name, config):
         self.env = configurations[config_name]['env_creator'](**config)
-        self.obs = self.env.reset()
+
+        res = self.env.reset()
+        if isinstance(res, tuple):
+            self.obs, self.central_state = res
+        else:
+            self.obs = res

     def step(self, action):
         next_state, reward, is_done, info = self.env.step(action)
@@ -136,7 +141,16 @@ def step(self, actions):
             newrewards.append(crewards)
             newdones.append(cdones)
             newinfos.append(cinfos)
-        return np.concatenate(newobs, axis=0), np.concatenate(newrewards, axis=0), np.concatenate(newdones, axis=0), newinfos
+        #print("newobs: ", newobs)
+        #print("newrewards: ", newrewards)
+        #print("newdones: ", newdones)
+        #print("newinfos:", newinfos)
+        #raise Exception()
+        ro = np.concatenate(newobs, axis=0)
+        rr = np.concatenate(newrewards, axis=0)
+        rd = np.concatenate(newdones, axis=0)
+        ri = newinfos
+        return ro, rr, rd, ri

     def has_action_masks(self):
         return True
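The tuple-aware reset in RayWorker can be read as the small helper below; the split_reset name is illustrative only, but the branching matches the diff: environments that also return a central state yield a tuple, while single-output environments return only observations.

    def split_reset(result):
        # Environments with a centralised state return (obs, central_state);
        # others return just obs, so report None for the missing state.
        if isinstance(result, tuple):
            obs, central_state = result
            return obs, central_state
        return result, None

    # obs, central_state = split_reset(env.reset())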
