Skip to content

Commit

Permalink
fix(unity): fixed visual input training using Unity3D (#34, #25)
Browse files: browse the repository at this point in the history
  • Loading branch information
StepNeverStop committed Aug 27, 2021
1 parent 251ad97 commit 39d60c3
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion rls/configs/algorithms.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ sarl_off_policy: &sarl_off_policy
<<: *sarl_policy
n_time_step: 4 # n-step or rnn length
epochs: 1 # train multiple times per agent step
- batch_size: 256
+ batch_size: 16
buffer_size: 100000
use_priority: false

Expand Down
2 changes: 1 addition & 1 deletion rls/configs/unity/env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ render: false
worker_id: 0
timeout_wait: 60
env_name: 3DBall
- real_done: true
+ real_done: false

obs_scale: false # change [0, 1] to [0, 255]

Expand Down
5 changes: 3 additions & 2 deletions rls/envs/unity/wrappers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,9 @@ def get_obs(self, behavior_names=None, only_obs=False):
if len(self._vector_idxs[bn]) > 0:
_obs.update(
vector={f'vector_{i}': obs[vi] for i, vi in enumerate(self._vector_idxs[bn])})
- elif len(self._visual_idxs[bn]) > 0:
- _obs.updata(
+
+ if len(self._visual_idxs[bn]) > 0:
+ _obs.update(
visual={f'visual_{i}': obs[vi] for i, vi in enumerate(self._visual_idxs[bn])})
all_obs[bn] = _obs
all_reward[bn] = reward
Expand Down
6 changes: 4 additions & 2 deletions rls/nn/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ def __init__(self, visual_dim=[], visual_feature=128, network_type='nature'):

def forward(self, *visual_inputs):
# h, w, c => c, h, w
- visual_inputs = [vi.swapaxes(-1, -3).swapaxes(-1, -2)
+ batch = visual_inputs[0].shape[:-3]
+ batch_prod = np.prod(batch)
+ visual_inputs = [vi.view((-1,)+vi.shape[-3:]).swapaxes(-1, -3).swapaxes(-1, -2)
for vi in visual_inputs]
output = []
for dense_net, visual_s in zip(self.dense_nets, visual_inputs):
output.append(
- dense_net(visual_s)
+ dense_net(visual_s).view(batch+(-1,))
)
output = t.cat(output, -1)
return output
Expand Down
6 changes: 3 additions & 3 deletions rls/nn/represent_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def forward(self, obs, cell_state=None):
feat: [T, B, *] or [B, *]
'''
feat_list = []
+
if self.obs_spec.has_vector_observation:
feat_list.append(self.vector_net(*obs.vector.values()))
- elif self.obs_spec.has_visual_observation:
+
+ if self.obs_spec.has_visual_observation:
feat_list.append(self.visual_net(*obs.visual.values()))
- else:
- raise Exception("observation must not be empty.")

feat = t.cat(feat_list, -1) # [T, B, *] or [B, *]

Expand Down

0 comments on commit 39d60c3

Please sign in to comment.