Commit
fix&refactor(unity): optimize unity wrapper, fix ddpg (#25, #34)
StepNeverStop committed Jan 6, 2021
1 parent fb9e5eb commit 7dad01d
Showing 4 changed files with 49 additions and 53 deletions.
2 changes: 1 addition & 1 deletion rls/algos/single/ddpg.py
@@ -23,7 +23,7 @@ def __init__(self,

ployak=0.995,
noise_type='ou',
-use_target_action_noise: False,
+use_target_action_noise=False,
gaussian_noise_sigma=0.2,
gaussian_noise_bound=0.2,
actor_lr=5.0e-4,
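The one-character change above, repeated in dpg.py and pd_ddpg.py below, is the ddpg fix: written as use_target_action_noise: False, the token after the colon is only a type annotation, so the parameter has no default value and any constructor call that omits it raises a TypeError; use_target_action_noise=False actually supplies the default. A minimal, self-contained illustration (not code from the repository):

def broken(use_target_action_noise: False):
    # ": False" is only an annotation; the parameter has no default value
    return use_target_action_noise

def fixed(use_target_action_noise=False):
    # "=False" actually supplies a default value
    return use_target_action_noise

try:
    broken()   # TypeError: broken() missing 1 required positional argument
except TypeError as err:
    print(err)

print(fixed())   # False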
2 changes: 1 addition & 1 deletion rls/algos/single/dpg.py
@@ -26,7 +26,7 @@ def __init__(self,

actor_lr=5.0e-4,
critic_lr=1.0e-3,
-use_target_action_noise: False,
+use_target_action_noise=False,
gaussian_noise_sigma=0.2,
gaussian_noise_bound=0.2,
discrete_tau=1.0,
2 changes: 1 addition & 1 deletion rls/algos/single/pd_ddpg.py
@@ -28,7 +28,7 @@ def __init__(self,
envspec,

ployak=0.995,
-use_target_action_noise: False,
+use_target_action_noise=False,
gaussian_noise_sigma=0.2,
gaussian_noise_bound=0.2,
actor_lr=5.0e-4,
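For context on the parameters shown in these signatures: use_target_action_noise, gaussian_noise_sigma and gaussian_noise_bound typically configure TD3-style target policy smoothing, i.e. clipped Gaussian noise added to the target policy's action. The sketch below shows that common pattern only; it is an assumption about intent, not this repository's implementation:

import numpy as np

def smoothed_target_action(target_action,
                           use_target_action_noise=False,
                           gaussian_noise_sigma=0.2,
                           gaussian_noise_bound=0.2):
    # Hypothetical helper: add clipped Gaussian noise to a target action in [-1, 1].
    if not use_target_action_noise:
        return target_action
    noise = np.random.normal(0.0, gaussian_noise_sigma, size=np.shape(target_action))
    noise = np.clip(noise, -gaussian_noise_bound, gaussian_noise_bound)
    return np.clip(target_action + noise, -1.0, 1.0)

print(smoothed_target_action(np.array([0.5, -0.3])))                                 # unchanged
print(smoothed_target_action(np.array([0.5, -0.3]), use_target_action_noise=True))   # perturbed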
96 changes: 46 additions & 50 deletions rls/envs/unity_wrapper/wrappers.py
@@ -68,8 +68,7 @@ def reset(self, **kwargs):
for k, v in kwargs.get('reset_config', {}).items():
self._side_channels['float_properties_channel'].set_float_parameter(k, v)
self.env.reset()
-obs = self.get_obs()
-return obs if self.is_multi_agents else obs[self.first_bn]
+return self.get_obs()

def step(self, actions, **kwargs):
'''
@@ -97,8 +96,7 @@ def step(self, actions, **kwargs):
self.env.set_actions(self.first_bn, self.empty_actiontuples[self.first_bn])

self.env.step()
-obs = self.get_obs()
-return obs if self.is_multi_agents else obs[self.first_bn]
+return self.get_obs()

def initialize_environment(self):
'''
@@ -208,60 +206,58 @@ def _get_real_agent_numbers_and_ids(self):

return behavior_agents, behavior_ids

-    def get_obs(self):
+    def get_obs(self, behavior_names=None):
         '''
         Parse the information fed back by the environment into four parts: vectors, images, rewards, and done signals.
         '''
+        behavior_names = behavior_names or self.behavior_names
         rets = {}
-        for bn in self.behavior_names:
-            rets[bn] = self.coordinate_information(bn)
-        return rets
-
-    def coordinate_information(self, bn):
-        '''
-        TODO: Annotation
-        '''
-        n = self.behavior_agents[bn]
-        ids = self.behavior_ids[bn]
-        ps = []
-        d, t = self.env.get_steps(bn)
-        if len(t):
-            ps.append(t)
-
-        if len(d) != 0 and len(d) != n:
-            raise ValueError(f'agents number error. Expected 0 or {n}, received {len(d)}')
-
-        # some environments are done, but some are not
-        while len(d) != n:
-            self.env.step()
-            d, t = self.env.get_steps(bn)
-            if len(t):
-                ps.append(t)
-
-        corrected_obs, reward = d.obs, d.reward
-        obs = deepcopy(corrected_obs)    # corrected_obs should contain the correct next state used for deciding actions
-        done = np.full(n, False)
-        info = dict(max_step=np.full(n, False), real_done=np.full(n, False))
-
-        for t in ps:    # TODO: to be optimized
-            _ids = np.asarray([ids[i] for i in t.agent_id], dtype=int)
-            info['max_step'][_ids] = t.interrupted    # terminated because the episode max step was reached
-            info['real_done'][_ids[~t.interrupted]] = True    # drop dones caused by max_step; record only dones caused by failure/success
-            reward[_ids] = t.reward
-            done[_ids] = True
-            # zip: vector, visual, ...
-            for _obs, _tobs in zip(obs, t.obs):
-                _obs[_ids] = _tobs
-
-        return SingleModelInformation(
-            corrected_obs=ModelObservations(vector=self.vector_info_type[bn](*[corrected_obs[vi] for vi in self.vector_idxs[bn]]),
-                                            visual=self.visual_info_type[bn](*[corrected_obs[vi] for vi in self.visual_idxs[bn]])),
-            obs=ModelObservations(vector=self.vector_info_type[bn](*[obs[vi] for vi in self.vector_idxs[bn]]),
-                                  visual=self.visual_info_type[bn](*[obs[vi] for vi in self.visual_idxs[bn]])),
-            reward=np.asarray(reward),
-            done=np.asarray(done),
-            info=info
-        )
+        for bn in behavior_names:
+            n = self.behavior_agents[bn]
+            ids = self.behavior_ids[bn]
+            ps = []
+            d, t = self.env.get_steps(bn)
+            if len(t):
+                ps.append(t)
+
+            if len(d) != 0 and len(d) != n:
+                raise ValueError(f'agents number error. Expected 0 or {n}, received {len(d)}')
+
+            # some environments are done, but some are not
+            while len(d) != n:
+                self.env.step()
+                d, t = self.env.get_steps(bn)
+                if len(t):
+                    ps.append(t)
+
+            corrected_obs, reward = d.obs, d.reward
+            obs = deepcopy(corrected_obs)    # corrected_obs should contain the correct next state used for deciding actions
+            done = np.full(n, False)
+            info = dict(max_step=np.full(n, False), real_done=np.full(n, False))
+
+            for t in ps:    # TODO: to be optimized
+                _ids = np.asarray([ids[i] for i in t.agent_id], dtype=int)
+                info['max_step'][_ids] = t.interrupted    # terminated because the episode max step was reached
+                info['real_done'][_ids[~t.interrupted]] = True    # drop dones caused by max_step; record only dones caused by failure/success
+                reward[_ids] = t.reward
+                done[_ids] = True
+                # zip: vector, visual, ...
+                for _obs, _tobs in zip(obs, t.obs):
+                    _obs[_ids] = _tobs
+
+            rets[bn] = SingleModelInformation(
+                corrected_obs=ModelObservations(vector=self.vector_info_type[bn](*[corrected_obs[vi] for vi in self.vector_idxs[bn]]),
+                                                visual=self.visual_info_type[bn](*[corrected_obs[vi] for vi in self.visual_idxs[bn]])),
+                obs=ModelObservations(vector=self.vector_info_type[bn](*[obs[vi] for vi in self.vector_idxs[bn]]),
+                                      visual=self.visual_info_type[bn](*[obs[vi] for vi in self.visual_idxs[bn]])),
+                reward=np.asarray(reward),
+                done=np.asarray(done),
+                info=info
+            )
+        if self.is_multi_agents:
+            return rets
+        else:
+            return rets[self.first_bn]

def random_action(self):
'''
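Net effect of the wrapper refactor above: reset() and step() no longer special-case single-behavior environments, and get_obs() accepts an optional list of behavior names and performs the multi-agent/single-agent dispatch itself. A small, self-contained mock (not code from the repository) of the resulting return-shape contract:

class MockUnityWrapper:
    # Only the dispatch logic mirrors the diff; everything else is a stand-in.
    def __init__(self, behavior_names, is_multi_agents):
        self.behavior_names = behavior_names
        self.first_bn = behavior_names[0]
        self.is_multi_agents = is_multi_agents

    def get_obs(self, behavior_names=None):
        behavior_names = behavior_names or self.behavior_names
        rets = {bn: f'obs-of-{bn}' for bn in behavior_names}   # stand-in for SingleModelInformation
        return rets if self.is_multi_agents else rets[self.first_bn]

    def reset(self):
        return self.get_obs()   # reset/step no longer unwrap the dict themselves

print(MockUnityWrapper(['Walker?team=0'], is_multi_agents=False).reset())        # 'obs-of-Walker?team=0'
print(MockUnityWrapper(['A?team=0', 'B?team=0'], is_multi_agents=True).reset())  # dict keyed by behavior name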
