
Commit

v4.1.2 fix(unity): fixed bugs in continuous training and unity training. (#34, #25)

1. updated `setup.py`
2. removed redundant packages
3. fixed bugs in unity wrapper
4. fixed bugs in agent models that occurred in continuous-action training tasks
5. fixed bugs in class `MLP`
StepNeverStop committed Jul 29, 2021
1 parent 1bb7729 commit f95a510
Showing 12 changed files with 54 additions and 35 deletions.
2 changes: 1 addition & 1 deletion rls/_metadata.py
@@ -8,7 +8,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '4'
_MINOR_VERSION = '1'
-_PATCH_VERSION = '1'
+_PATCH_VERSION = '2'

# Example: '0.4.2'
__version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
6 changes: 5 additions & 1 deletion rls/algorithms/single/a2c.py
@@ -33,7 +33,11 @@ def __init__(self,
actor_lr=5.0e-4,
critic_lr=1.0e-3,
network_settings={
-'actor_continuous': [32, 32],
+'actor_continuous': {
+    'hidden_units': [64, 64],
+    'condition_sigma': False,
+    'log_std_bound': [-20, 2]
+},
'actor_discrete': [32, 32],
'critic': [32, 32]
},
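Note on the `actor_continuous` change (the same edit recurs in `ac.py`, `pg.py`, `ppo.py`, and `trpo.py` below): the setting grows from a bare layer list into a dict that also configures the Gaussian head. Judging from the `models.py` hunks further down, `condition_sigma` toggles between a state-dependent log-std output and a single learned constant, and `log_std_bound` clamps the log standard deviation. A minimal self-contained sketch of such a head in plain PyTorch (the `GaussianActor` name and layer choices are illustrative, not the repo's `MLP`-based implementation):

```python
import torch
import torch.nn as nn

class GaussianActor(nn.Module):
    """Illustrative continuous-action head driven by the new settings dict."""

    def __init__(self, obs_dim, act_dim,
                 hidden_units=(64, 64),
                 condition_sigma=False,
                 log_std_bound=(-20, 2)):
        super().__init__()
        self.log_std_min, self.log_std_max = log_std_bound
        self.condition_sigma = condition_sigma
        layers, ins = [], obs_dim
        for h in hidden_units:
            layers += [nn.Linear(ins, h), nn.ReLU()]
            ins = h
        self.share = nn.Sequential(*layers)
        self.mu = nn.Linear(ins, act_dim)
        if condition_sigma:
            self.log_std = nn.Linear(ins, act_dim)   # state-dependent sigma
        else:
            self.log_std = nn.Parameter(-0.5 * torch.ones(1, act_dim))  # learned constant

    def forward(self, x):
        feat = self.share(x)
        mu = torch.tanh(self.mu(feat))
        log_std = self.log_std(feat) if self.condition_sigma else self.log_std
        log_std = log_std.clamp(self.log_std_min, self.log_std_max)
        return mu, log_std.expand_as(mu)
```

Constructing it as `GaussianActor(8, 2, **network_settings['actor_continuous'])` would mirror how the new defaults are meant to be consumed.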
6 changes: 5 additions & 1 deletion rls/algorithms/single/ac.py
@@ -31,7 +31,11 @@ def __init__(self,
actor_lr=5.0e-4,
critic_lr=1.0e-3,
network_settings={
-'actor_continuous': [32, 32],
+'actor_continuous': {
+    'hidden_units': [64, 64],
+    'condition_sigma': False,
+    'log_std_bound': [-20, 2]
+},
'actor_discrete': [32, 32],
'critic': [32, 32]
},
6 changes: 5 additions & 1 deletion rls/algorithms/single/pg.py
@@ -30,7 +30,11 @@ def __init__(self,
lr=5.0e-4,
epoch=5,
network_settings={
-'actor_continuous': [32, 32],
+'actor_continuous': {
+    'hidden_units': [32, 32],
+    'condition_sigma': False,
+    'log_std_bound': [-20, 2]
+},
'actor_discrete': [32, 32]
},
**kwargs):
22 changes: 14 additions & 8 deletions rls/algorithms/single/ppo.py
@@ -79,6 +79,8 @@ def __init__(self,
network_settings: Dict = {
'share': {
'continuous': {
+'condition_sigma': False,
+'log_std_bound': [-20, 2],
'share': [32, 32],
'mu': [32, 32],
'v': [32, 32]
@@ -89,7 +91,11 @@ def __init__(self,
'v': [32, 32]
}
},
-'actor_continuous': [32, 32],
+'actor_continuous': {
+    'hidden_units': [64, 64],
+    'condition_sigma': False,
+    'log_std_bound': [-20, 2]
+},
'actor_discrete': [32, 32],
'critic': [32, 32]
},
@@ -107,7 +113,7 @@ def __init__(self,
self.kl_reverse = kl_reverse
self.kl_target = kl_target
self.kl_alpha = kl_alpha
-self.kl_coef = t.tensor(kl_coef).float()
+self.kl_coef = kl_coef
self.extra_coef = extra_coef
self.vf_coef = vf_coef
self.max_grad_norm = max_grad_norm
@@ -229,13 +235,13 @@ def _train(data, cell_states):
early_step = 0
if self.share_net:
for i in range(self.policy_epoch):
-actor_loss, critic_loss, entropy, kl = self.train_share(data, cell_states, self.kl_coef)
+actor_loss, critic_loss, entropy, kl = self.train_share(data, cell_states)
if self.use_early_stop and kl > self.kl_stop:
early_step = i
break
else:
for i in range(self.policy_epoch):
-actor_loss, entropy, kl = self.train_actor(data, cell_states, self.kl_coef)
+actor_loss, entropy, kl = self.train_actor(data, cell_states)
if self.use_early_stop and kl > self.kl_stop:
early_step = i
break
Expand Down Expand Up @@ -283,7 +289,7 @@ def _train(data, cell_states):
})

@iTensor_oNumpy
-def train_share(self, BATCH, cell_states, kl_coef):
+def train_share(self, BATCH, cell_states):
feat, _ = self.rep_net(BATCH.obs, cell_state=cell_states['obs'])
if self.is_continuous:
mu, log_std, value = self.net(feat)
@@ -327,7 +333,7 @@ def train_share(self, BATCH, cell_states, kl_coef):
td_square = td_error.square()

if self.use_kl_loss:
-kl_loss = kl_coef * kl
+kl_loss = self.kl_coef * kl
actor_loss += kl_loss

if self.use_extra_loss:
@@ -340,7 +346,7 @@ def train_share(self, BATCH, cell_states, kl_coef):
return actor_loss, value_loss, entropy, kl

@iTensor_oNumpy
-def train_actor(self, BATCH, cell_states, kl_coef):
+def train_actor(self, BATCH, cell_states):
feat, _ = self.rep_net(BATCH.obs, cell_state=cell_states['obs'])
if self.is_continuous:
mu, log_std = self.actor(feat)
@@ -368,7 +374,7 @@ def train_actor(self, BATCH, cell_states, kl_coef):
actor_loss = -(clipped_surrogate.mean() + self.ent_coef * entropy)

if self.use_kl_loss:
-kl_loss = kl_coef * kl
+kl_loss = self.kl_coef * kl
actor_loss += kl_loss
if self.use_extra_loss:
extra_loss = self.extra_coef * t.maximum(t.zeros_like(kl), kl - self.kl_cutoff).square()
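The `kl_coef` edits all point one way: the coefficient now lives on `self` as a plain Python number instead of being pre-wrapped with `t.tensor(...).float()` and threaded through the `@iTensor_oNumpy`-decorated training methods as an extra argument. A bare float multiplies against a tensor KL estimate without any wrapping; a quick sanity check (values illustrative):

```python
import torch as t

kl = t.tensor(0.013)    # per-batch KL estimate, as computed in train_share/train_actor
kl_coef = 0.5           # plain Python float stored on self
print(kl_coef * kl)     # tensor(0.0065) -- float * tensor broadcasts fine
```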
6 changes: 5 additions & 1 deletion rls/algorithms/single/trpo.py
@@ -92,7 +92,11 @@ def __init__(self,
epsilon=0.2,
critic_lr=1e-3,
network_settings={
-'actor_continuous': [32, 32],
+'actor_continuous': {
+    'hidden_units': [64, 64],
+    'condition_sigma': False,
+    'log_std_bound': [-20, 2]
+},
'actor_discrete': [32, 32],
'critic': [32, 32]
},
2 changes: 1 addition & 1 deletion rls/envs/unity/env.py
@@ -17,7 +17,7 @@ class UnityEnv(EnvBase):
def __init__(self,
obs_scale=False,
**kwargs):
-self.env = BasicUnityEnvironment(kwargs)
+self.env = BasicUnityEnvironment(**kwargs)
if obs_scale:
self.env = ScaleVisualWrapper(env)

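The one-character `env.py` fix is the classic dict-versus-unpacking bug: `BasicUnityEnvironment(kwargs)` hands the entire dict to the first positional parameter, while `BasicUnityEnvironment(**kwargs)` spreads it into keyword arguments. A toy reproduction (the `connect` function here is hypothetical):

```python
def connect(worker_id=0, file_name=None, port=5005):
    return worker_id, file_name, port

kwargs = {'worker_id': 3, 'port': 6006}

# Bug: the whole dict lands in the first positional slot.
print(connect(kwargs))    # ({'worker_id': 3, 'port': 6006}, None, 5005)

# Fix: ** unpacks the dict into keyword arguments.
print(connect(**kwargs))  # (3, None, 6006)
```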
23 changes: 11 additions & 12 deletions rls/envs/unity/wrappers/wrappers.py
@@ -27,13 +27,13 @@
class BasicUnityEnvironment(object):

def __init__(self,
+worker_id=0,
file_name=None,
port=5005,
render=False,
seed=42,
-worker_id=0,
timeout_wait=60,
-env_copys=1,
+env_copys=12,
env_name='3DBall',
real_done=True,
initialize_config={},
@@ -50,15 +50,14 @@ def __init__(self,
self._n_copys = env_copys
self._real_done = real_done

-self._side_channels = self.initialize_all_side_channels(initialize_config)
+self._side_channels = self.initialize_all_side_channels(initialize_config, engine_config)
env_kwargs = dict(seed=seed,
worker_id=worker_id,
timeout_wait=timeout_wait,
side_channels=list(self._side_channels.values()))  # register all initialized side channels

-if file_nameis not None:
+if file_name is not None:
env_dict = load_config('rls/configs/unity/env_dict.yaml')
-env_kwargs.update(file_name=file_name
+env_kwargs.update(file_name=file_name,
base_port=port,
no_graphics=not render,
additional_args=[
@@ -109,13 +108,13 @@ def initialize_environment(self):
self.behavior_agents[bn] = len(ds)
self.behavior_ids[bn] = ds.agent_id_to_index

-for i, shape in enumerate(spec.observation_shapes):
-    if len(shape) == 1:
+for i, obs_spec in enumerate(spec.observation_specs):  # TODO: optimize
+    if len(obs_spec.shape) == 1:
         self.vector_idxs[bn].append(i)
-        self.vector_dims[bn].append(shape[0])
-    elif len(shape) == 3:
+        self.vector_dims[bn].append(obs_spec.shape[0])
+    elif len(obs_spec.shape) == 3:
         self.visual_idxs[bn].append(i)
-        self.visual_dims[bn].append(list(shape))
+        self.visual_dims[bn].append(list(obs_spec.shape))
else:
raise ValueError("shape of observation cannot be understood.")
self.vector_info_type[bn] = generate_obs_dataformat(n_copys=self.behavior_agents[bn],
@@ -253,7 +252,7 @@ def get_obs(self, behavior_names=None, only_obs=False):
all_reward.append(reward[idxs])
# all_info.append(dict(max_step=info_max_step[idxs]))
if only_obs:
-return corrected_obs
+return all_corrected_obs
else:
rets = []
for corrected_obs, obs, reward in zip(all_corrected_obs, all_obs, all_reward):
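The loop rewrite in `initialize_environment` tracks an ML-Agents API change: `BehaviorSpec.observation_shapes` (a list of plain shape tuples) was superseded by `observation_specs`, a list of spec objects exposing a `.shape` attribute. The dispatch itself is unchanged: 1-D shapes are treated as vector observations, 3-D shapes as visual ones. A standalone sketch, with a namedtuple standing in for the real `ObservationSpec` class and illustrative shapes:

```python
from collections import namedtuple

ObservationSpec = namedtuple('ObservationSpec', ['shape'])  # stand-in for the ml-agents class

specs = [ObservationSpec((8,)),          # 1-D => vector observation
         ObservationSpec((84, 84, 3))]   # 3-D => visual observation (H, W, C)

vector_idxs, visual_idxs = [], []
for i, obs_spec in enumerate(specs):
    if len(obs_spec.shape) == 1:
        vector_idxs.append(i)
    elif len(obs_spec.shape) == 3:
        visual_idxs.append(i)

print(vector_idxs, visual_idxs)  # [0] [1]
```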
2 changes: 1 addition & 1 deletion rls/nn/mlps.py
@@ -33,7 +33,7 @@ def __init__(self,
self.add_module(f'{layer}_{i}', Layer_REGISTER[layer](_in, _out))
self.add_module(f'{act_fn}_{i}', Act_REGISTER[act_fn]())

-input_dim = outs[-1] or input_dim
+input_dim = outs[-1] if len(outs) > 0 else input_dim
if output_shape:
self.add_module('out_layer', Layer_REGISTER[layer](input_dim, output_shape))
if out_act:
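This `MLP` fix is what lets `models.py` (next file) build heads with an empty hidden-layer list, as in `MLP(ins, [], ...)`: the old `outs[-1] or input_dim` was presumably meant as a fallback for an empty `outs`, but indexing runs before `or`, so it raises `IndexError` instead of falling through. A minimal reproduction:

```python
outs = []

try:
    input_dim = outs[-1] or 64        # old form: indexing fails before the fallback
except IndexError:
    print('outs[-1] raises on an empty list; `or` never gets a chance')

input_dim = outs[-1] if len(outs) > 0 else 64   # fixed form guards first
print(input_dim)  # 64
```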
10 changes: 5 additions & 5 deletions rls/nn/models.py
@@ -43,11 +43,11 @@ def __init__(self, vector_dim, output_shape, network_settings):
ins = network_settings['hidden_units'][-1]
else:
ins = vector_dim
-self.mu = MLP(ins, output_shape=output_shape, out_act='tanh')
+self.mu = MLP(ins, [], output_shape=output_shape, out_act='tanh')
if self.condition_sigma:
self.log_std = MLP(ins, [], output_shape=output_shape)
else:
-self.log_std = -0.5 * t.nn.Parameter(t.ones((1, output_shape)), requires_grad=True)
+self.log_std = t.nn.Parameter(-0.5 * t.ones((1, output_shape)))

def forward(self, x):
x = self.share(x)
@@ -56,7 +56,7 @@ def forward(self, x):
log_std = self.log_std(x)
else:
log_std = self.log_std
-log_std.clamp_(self.log_std_min, self.log_std_max)
+log_std = log_std.clamp(self.log_std_min, self.log_std_max)
batch_size = mu.shape[0]
if batch_size:
log_std = log_std.repeat(batch_size, 1) # [1, N] => [B, N]
@@ -335,7 +335,7 @@ def __init__(self, vector_dim, output_shape, network_settings):
if self.condition_sigma:
self.log_std = MLP(ins, [], output_shape=output_shape)
else:
-self.log_std = -0.5 * t.nn.Parameter(t.ones((1, output_shape)), requires_grad=True)
+self.log_std = t.nn.Parameter(-0.5 * t.ones((1, output_shape)))

def forward(self, x):
x = self.share(x)
@@ -349,7 +349,7 @@ def forward(self, x):
batch_size = mu.shape[0]
if batch_size:
log_std = log_std.repeat(batch_size, 1) # [1, N] => [B, N]
-log_std.clamp_(self.log_std_min, self.log_std_max)
+log_std = log_std.clamp(self.log_std_min, self.log_std_max)
return (mu, log_std, v)


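Both `models.py` fixes are standard PyTorch pitfalls. Multiplying a freshly built `nn.Parameter` by `-0.5` yields an ordinary non-leaf tensor, so the module never registers a parameter and the optimizer never updates the log-std; wrapping the already-scaled tensor gives a proper leaf parameter instead. Likewise, in-place `clamp_` on a leaf tensor that requires grad raises a `RuntimeError`, which is plausibly the continuous-training crash this commit fixes; the out-of-place `clamp` keeps autograd happy. A short demonstration (the `Head` module is illustrative):

```python
import torch as t
import torch.nn as nn

class Head(nn.Module):
    def __init__(self, n=3):
        super().__init__()
        # Broken: the multiply returns a plain non-leaf tensor, so nothing is
        # registered and list(self.parameters()) would come back empty:
        # self.log_std = -0.5 * nn.Parameter(t.ones((1, n)), requires_grad=True)

        # Fixed: wrap the final value; a registered leaf Parameter.
        self.log_std = nn.Parameter(-0.5 * t.ones((1, n)))

head = Head()
print(list(head.parameters()))   # one (1, 3) parameter filled with -0.5

try:
    head.log_std.clamp_(-20, 2)  # in-place op on a grad-requiring leaf
except RuntimeError as err:
    print('clamp_ fails:', err)

log_std = head.log_std.clamp(-20, 2)  # out-of-place clamp works and keeps gradients
```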
1 change: 0 additions & 1 deletion rls/utils/np_utils.py
@@ -2,7 +2,6 @@
# encoding: utf-8

import itertools
-import scipy.signal
import numpy as np


3 changes: 1 addition & 2 deletions setup.py
@@ -84,8 +84,7 @@
'tqdm',
'tensorboard',
'colored_traceback',
-'pyglet',
-'imageio'
+# 'imageio'
],
extras_require=extras,
)
