Skip to content

Commit

Permalink
refactor(unity): optimize unity wrapper (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
StepNeverStop committed Jan 6, 2021
1 parent 7dad01d commit 719ccf4
Showing 1 changed file with 15 additions and 34 deletions.
49 changes: 15 additions & 34 deletions rls/envs/unity_wrapper/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def initialize_environment(self):
self.first_bn = self.behavior_names[0]
self.first_fbn = self.first_bn.replace('?', '_')

self.behavior_agents, self.behavior_ids = self._get_real_agent_numbers_and_ids() # 得到每个环境控制几个智能体

self.behavior_agents = defaultdict(int)
self.behavior_ids = defaultdict(dict)
self.vector_idxs = defaultdict(list)
self.vector_dims = defaultdict(list)
self.visual_idxs = defaultdict(list)
Expand All @@ -122,7 +122,12 @@ def initialize_environment(self):
self.vector_info_type = {}
self.visual_info_type = {}

self.env.reset()
for bn, spec in self.env.behavior_specs.items():
d, t = self.env.get_steps(bn)
self.behavior_agents[bn] = len(d)
self.behavior_ids[bn] = d.agent_id_to_index

for i, shape in enumerate(spec.observation_shapes):
if len(shape) == 1:
self.vector_idxs[bn].append(i)
Expand Down Expand Up @@ -183,29 +188,6 @@ def EnvSpec(self):
n_agents=self.behavior_agents[self.first_bn]
)

def _get_real_agent_numbers_and_ids(self):
'''获取环境中真实的智能体数量和对应的id'''
self.env.reset()
behavior_agents = defaultdict(int)
behavior_ids = defaultdict(lambda: np.empty(0))
# 10 step
for _ in range(10):
for bn in self.behavior_names:
d, t = self.env.get_steps(bn)
# TODO: 检查t是否影响
if len(d) > len(behavior_ids[bn]):
behavior_agents[bn] = len(d)
behavior_ids[bn] = d.agent_id
self.env.set_actions(bn, self.env.behavior_specs[bn].action_spec.random_action(n_agents=len(d)))
self.env.step()

for bn in self.behavior_names:
behavior_ids[bn] = {_id: _idx for _id, _idx in zip(behavior_ids[bn], range(len(behavior_ids[bn])))}

self.env.reset()

return behavior_agents, behavior_ids

def get_obs(self, behavior_names=None):
'''
解析环境反馈的信息,将反馈信息分为四部分:向量、图像、奖励、done信号
Expand All @@ -216,20 +198,19 @@ def get_obs(self, behavior_names=None):
n = self.behavior_agents[bn]
ids = self.behavior_ids[bn]
ps = []
d, t = self.env.get_steps(bn)
if len(t):
ps.append(t)

if len(d) != 0 and len(d) != n:
raise ValueError(f'agents number error. Expected 0 or {n}, received {len(d)}')

# some of environments done, but some of not
while len(d) != n:
self.env.step()
while True:
d, t = self.env.get_steps(bn)
if len(t):
ps.append(t)

if len(d) == n:
break
elif len(d) == 0:
self.env.step() # some of environments done, but some of not
else:
raise ValueError(f'agents number error. Expected 0 or {n}, received {len(d)}')

corrected_obs, reward = d.obs, d.reward
obs = deepcopy(corrected_obs) # corrected_obs应包含正确的用于决策动作的下一状态
done = np.full(n, False)
Expand Down

0 comments on commit 719ccf4

Please sign in to comment.