Skip to content

Commit 554258c

Browse files
Merge pull request #2 from schroederdewitt/vdn_s
bug fix for state shape
2 parents d804a25 + e2a0833 commit 554258c

File tree

7 files changed

+10
-174
lines changed

7 files changed

+10
-174
lines changed

algos_tf14/vdnagent.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,10 @@ def __init__(self, sess, base_name, observation_space, action_space, config, log
7070
self.atoms_num = self.config['atoms_num']
7171
assert self.atoms_num == 1
7272

73-
self.state_shape = (self.env.env_info['state_shape'],)
73+
if central_state_space is not None:
74+
self.state_shape = central_state_space.shape
75+
else:
76+
raise NotImplementedError("central_state_space input to VDN is NONE!")
7477
self.n_agents = self.env.env_info['n_agents']
7578

7679
if not self.is_prioritized:
@@ -225,6 +228,7 @@ def play_steps(self, steps, epsilon=0.0):
225228
# Same reward, done for all agents
226229
reward = reward[0]
227230
is_done = all(is_done)
231+
state = state[0]
228232

229233
self.step_count += 1
230234
self.total_reward += reward

configs/vdn_3s5z_vs_3s6z.yaml

-84
This file was deleted.

configs/vdn_3s_vs_5z.yaml

-84
This file was deleted.

configs/whirl_baselines/vdn_3s5z_vs_3s6z.yaml

+2 -1
Original file line number | Diff line number | Diff line change
@@ -81,4 +81,5 @@ params:
8181
name: 3s5z_vs_3s6z
8282
frames: 4
8383
transpose: True
84-
random_invalid_step: False
84+
random_invalid_step: False
85+
use_central_state: True

configs/whirl_baselines/vdn_3s_vs_5z.yaml

+2 -1
Original file line number | Diff line number | Diff line change
@@ -81,4 +81,5 @@ params:
8181
name: 3s_vs_5z
8282
frames: 4
8383
transpose: True
84-
random_invalid_step: False
84+
random_invalid_step: False
85+
use_central_state: True

configs/whirl_baselines/vdn_MMM2.yaml

+1
Original file line number | Diff line number | Diff line change
@@ -82,3 +82,4 @@ params:
8282
frames: 4
8383
transpose: True
8484
random_invalid_step: False
85+
use_central_state: True

envs/smac_env.py

-3
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,6 @@ def _preproc_actions(self, actions):
4949
def get_state(self):
5050
return self.env.get_state()
5151

52-
def get_state(self):
53-
return self.env.get_state()
54-
5552
def step(self, actions):
5653
fixed_rewards = None
5754

0 commit comments

Comments
 (0)