refactor(unity): optimize unity wrapper (#25)
- add `UnitySingleBehaviorInfo` in indexs.py
- remove BasicActionWrapper
- remove redundant identifiers
StepNeverStop committed Jan 1, 2021
1 parent b92ed95 commit 747b3fa
Showing 5 changed files with 81 additions and 99 deletions.
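For orientation, here is a minimal sketch of the caller-facing change for a single-behavior environment, assembled only from the call sites changed below:

# Before this commit: every call was keyed by the first behavior name.
ret = env.reset(reset_config={})[env.first_bn]
action = model.choose_action(s=ret.corrected_vector, visual_s=ret.corrected_visual)
ret = env.step({env.first_bn: action}, step_config={})[env.first_bn]

# After this commit: a single-behavior env accepts and returns plain values;
# the dict-keyed form is only needed when env.is_multi_agents is True.
ret = env.reset(reset_config={})
action = model.choose_action(s=ret.corrected_vector, visual_s=ret.corrected_visual)
ret = env.step(action, step_config={})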
6 changes: 1 addition & 5 deletions rls/common/make_env.py
@@ -30,8 +30,7 @@ def make_gym_env(env_kargs: Dict):

def make_unity_env(env_kargs: Dict):
from rls.envs.unity_wrapper import (BasicUnityEnvironment,
ScaleVisualWrapper,
BasicActionWrapper)
ScaleVisualWrapper)

copied_env_kargs = deepcopy(env_kargs)
env = BasicUnityEnvironment(copied_env_kargs)
@@ -40,7 +39,4 @@ def make_unity_env(env_kargs: Dict):
if copied_env_kargs['obs_scale']:
env = ScaleVisualWrapper(env)

env = BasicActionWrapper(env)
logger.debug('Unity BasicActionWrapper success.')

return env
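A hedged usage sketch of the updated factory: only the 'obs_scale' key is visible in this diff, so every other entry of env_kargs below is a placeholder, not the library's documented schema.

from rls.common.make_env import make_unity_env

env_kargs = {
    'obs_scale': True,  # wrap visual observations with ScaleVisualWrapper
    # ... remaining Unity launch options from the YAML config (placeholders) ...
}
env = make_unity_env(env_kargs)
# env is a BasicUnityEnvironment, optionally wrapped by ScaleVisualWrapper;
# no BasicActionWrapper is applied anymore.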
14 changes: 7 additions & 7 deletions rls/common/train/unity.py
@@ -52,7 +52,7 @@ def unity_train(env, model,

for episode in range(begin_episode, max_train_episode):
model.reset()
ret = env.reset(reset_config={})[env.first_bn]
ret = env.reset(reset_config={})
s = ret.corrected_vector
visual_s = ret.corrected_visual
dones_flag = np.zeros(n, dtype=float)
@@ -63,7 +63,7 @@
while True:
step += 1
action = model.choose_action(s=s, visual_s=visual_s)
ret = env.step({env.first_bn: action}, step_config={})[env.first_bn]
ret = env.step(action, step_config={})

model.store_data(
s=s,
@@ -139,16 +139,16 @@ def unity_no_op(env, model,
if pre_fill_steps == 0:
return
model.reset()
ret = env.reset(reset_config={})[env.first_bn]
ret = env.reset(reset_config={})
s = ret.corrected_vector
visual_s = ret.corrected_visual

for _ in trange(0, pre_fill_steps, n, unit_scale=n, ncols=80, desc=desc, bar_format=bar_format):
if prefill_choose:
action = model.choose_action(s=s, visual_s=visual_s)
else:
action = env.random_action()[env.first_bn]
ret = env.step({env.first_bn: action}, step_config={})[env.first_bn]
action = env.random_action()
ret = env.step(action, step_config={})
model.no_op_store(
s=s,
visual_s=visual_s,
@@ -171,13 +171,13 @@ def unity_inference(env, model,

for episode in range(episodes):
model.reset()
ret = env.reset(reset_config={})[env.first_bn]
ret = env.reset(reset_config={})
while True:
action = model.choose_action(s=ret.corrected_vector,
visual_s=ret.corrected_visual,
evaluation=True)
model.partial_reset(ret.done)
ret = env.step({env.first_bn: action}, step_config={})[env.first_bn]
ret = env.step(action, step_config={})


def ma_unity_no_op(env, model,
2 changes: 1 addition & 1 deletion rls/common/trainer.py
@@ -196,7 +196,7 @@ def initialize_unity(self):

def initialize_multi_unity(self):
# multi agents with unity
assert self.env.behavior_num > 1, 'if using ma* algorithms, number of brains must larger than 1'
assert self.env.is_multi_agents, 'assert self.env.is_multi_agents'

if 'Nstep' in self.buffer_args['type'] or 'Episode' in self.buffer_args['type']:
self.buffer_args[self.buffer_args['type']]['agents_num'] = self.env_args['env_num'] # TODO: check whether this needs to be set according to env_num
151 changes: 65 additions & 86 deletions rls/envs/unity_wrapper/wrappers.py
@@ -5,34 +5,22 @@
import numpy as np

from copy import deepcopy
from collections import (deque,
defaultdict)
from collections import defaultdict
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel
from mlagents_envs.base_env import (
ActionTuple,
ActionSpec # TODO
)
from mlagents_envs.base_env import (ActionTuple,
ActionSpec) # TODO

from rls.utils.logging_utils import get_logger
from rls.utils.display import colorize
logger = get_logger(__name__)

try:
import cv2
cv2.ocl.setUseOpenCL(False)
except:
logger.warning(colorize('opencv-python is needed to train visual-based model.', color='yellow'))
pass

from rls.common.yaml_ops import load_yaml
from rls.utils.np_utils import get_discrete_action_list
from rls.utils.indexs import (SingleAgentEnvArgs,
MultiAgentEnvArgs,
UnitySingleAgentReturn)
from rls.envs.unity_wrapper.core import (BasicWrapper,
ObservationWrapper,
from rls.envs.unity_wrapper.core import (ObservationWrapper,
ActionWrapper)


@@ -79,90 +67,94 @@ def reset(self, **kwargs):
for k, v in kwargs.get('reset_config', {}).items():
self._side_channels['float_properties_channel'].set_float_parameter(k, v)
self.env.reset()
return self.get_obs()
obs = self.get_obs()
return obs if self.is_multi_agents else obs[self.first_bn]

def step(self, actions, **kwargs):
'''
params: actions, a dict or an np.ndarray. If actions is not a dict, the
actions are applied to the first behavior controller.
'''
for k, v in kwargs.get('step_config', {}).items():
self._side_channels['float_properties_channel'].set_float_parameter(k, v)
for k, v in actions.items():
if self.is_continuous[k]:
self.predesigned_actiontuples[k].add_continuous(v)

actions = deepcopy(actions)
if self.is_multi_agents:
assert isinstance(actions, dict)
for k, v in actions.items():
if self.is_continuous[k]:
self.empty_actiontuples[k].add_continuous(v)
else:
self.empty_actiontuples[k].add_discrete(self.discrete_action_lists[k][v])
self.env.set_actions(k, self.empty_actiontuples[k])
else:
# TODO: optimize
if self.is_continuous[self.first_bn]:
self.empty_actiontuples[self.first_bn].add_continuous(actions)
else:
self.predesigned_actiontuples[k].add_discrete(v)
self.env.set_actions(k, self.predesigned_actiontuples[k])
self.empty_actiontuples[self.first_bn].add_discrete(self.discrete_action_lists[self.first_bn][actions])
self.env.set_actions(self.first_bn, self.empty_actiontuples[self.first_bn])

self.env.step()
return self.get_obs()
obs = self.get_obs()
return obs if self.is_multi_agents else obs[self.first_bn]

def initialize_environment(self):
'''
Initialize the environment and collect the necessary information, such as state and action dimensions.
'''

# get the names of all behaviors in Unity
self.behavior_names = list(self.env.behavior_specs.keys())
self.first_bn = first_bn = self.behavior_names[0]
# NOTE: to create folders named after each behavior, the question marks in behavior names must be replaced  TODO: optimize
self.fixed_behavior_names = list(map(lambda x: x.replace('?', '_'), self.behavior_names))
self.first_fbn = self.fixed_behavior_names[0]

self.behavior_num = len(self.behavior_names)
self.is_multi_agents = self.behavior_num > 1

self.vector_idxs = {}
self.vector_dims = {}
self.visual_idxs = {}
self.visual_sources = {}
self.visual_resolutions = {}
self.s_dim = {}
self.a_dim = {}
self.is_multi_agents = len(self.behavior_names) > 1
self.first_bn = self.behavior_names[0]
self.first_fbn = self.first_bn.replace('?', '_')

self.behavior_agents, self.behavior_ids = self._get_real_agent_numbers_and_ids() # get how many agents each environment controls

self.vector_idxs = defaultdict(list)
self.vector_dims = defaultdict(list)
self.visual_idxs = defaultdict(list)
self.visual_sources = defaultdict(int)
self.visual_resolutions = defaultdict(list)
self.s_dim = defaultdict(int)
self.a_dim = defaultdict(int)
self.discrete_action_lists = {}
self.is_continuous = {}
self.continuous_sizes = {}
self.discrete_branchess = {}
self.discrete_sizes = {}
self.empty_actiontuples = {}

for bn, spec in self.env.behavior_specs.items():
# vector inputs
self.vector_idxs[bn] = [i for i, g in enumerate(spec.observation_shapes) if len(g) == 1]
self.vector_dims[bn] = [g[0] for g in spec.observation_shapes if len(g) == 1]
for i, shape in enumerate(spec.observation_shapes):
if len(shape) == 1:
self.vector_idxs[bn].append(i)
self.vector_dims[bn].append(shape[0])
elif len(shape) == 3:
self.visual_idxs[bn].append(i)
self.visual_resolutions[bn].append(list(shape)) # TODO: support visual inputs of multiple different sizes; currently only one type of visual input is supported
else:
raise ValueError("shape of observation cannot be understood.")
self.s_dim[bn] = sum(self.vector_dims[bn])
# visual inputs
self.visual_idxs[bn] = [i for i, g in enumerate(spec.observation_shapes) if len(g) == 3]
self.visual_sources[bn] = len(self.visual_idxs[bn])
for g in spec.observation_shapes:
if len(g) == 3:
self.visual_resolutions[bn] = list(g)
break
else:
self.visual_resolutions[bn] = []
# actions
# continuous
self.continuous_sizes[bn] = spec.action_spec.continuous_size
# discrete
self.discrete_branchess[bn] = spec.action_spec.discrete_branches
self.discrete_sizes[bn] = len(self.discrete_branchess[bn])

if self.continuous_sizes[bn] > 0 and self.discrete_sizes[bn] > 0:
raise NotImplementedError("doesn't support continuous and discrete actions simultaneously for now.")
elif self.continuous_sizes[bn] > 0:
self.a_dim[bn] = int(np.asarray(self.continuous_sizes[bn]).prod())

action_spec = spec.action_spec
if action_spec.is_continuous:
self.a_dim[bn] = action_spec.continuous_size
self.discrete_action_lists[bn] = None
self.is_continuous[bn] = True
else:
self.a_dim[bn] = int(np.asarray(self.discrete_branchess[bn]).prod())
self.discrete_action_lists[bn] = get_discrete_action_list(self.discrete_branchess[bn])
elif action_spec.is_discrete:
self.a_dim[bn] = int(np.asarray(action_spec.discrete_branches).prod())
self.discrete_action_lists[bn] = get_discrete_action_list(action_spec.discrete_branches)
self.is_continuous[bn] = False
else:
raise NotImplementedError("doesn't support continuous and discrete actions simultaneously for now.")

self.behavior_agents, self.behavior_ids = self._get_real_agent_numbers_and_ids() # get how many agents each environment controls
self.predesigned_actiontuples = {}
for bn in self.behavior_names:
self.predesigned_actiontuples[bn] = self.env.behavior_specs[bn].action_spec.random_action(n_agents=self.behavior_agents[bn])
self.empty_actiontuples[bn] = action_spec.empty_action(n_agents=self.behavior_agents[bn])

if self.is_multi_agents:
self.behavior_controls = {}
self.behavior_controls = defaultdict(int)
for bn in self.behavior_names:
self.behavior_controls[bn] = int(bn.split('#')[0])
self.env_copys = self.behavior_agents[first_bn] // self.behavior_controls[first_bn]
self.env_copys = self.behavior_agents[self.first_bn] // self.behavior_controls[self.first_bn]

@property
def EnvSpec(self):
@@ -195,9 +187,9 @@ def _get_real_agent_numbers_and_ids(self):
for _ in range(10):
for bn in self.behavior_names:
d, t = self.env.get_steps(bn)
behavior_agents[bn] = max(behavior_agents[bn], len(d))
# TODO: check whether t affects this
if len(d) > len(behavior_ids[bn]):
behavior_agents[bn] = len(d)
behavior_ids[bn] = d.agent_id
self.env.set_actions(bn, self.env.behavior_specs[bn].action_spec.random_action(n_agents=len(d)))
self.env.step()
@@ -336,16 +328,3 @@ def func(self, vis):
for j in range(cameras):
vis[i, j] *= 255
return np.asarray(vis).astype(np.uint8)


class BasicActionWrapper(ActionWrapper):

def __init__(self, env):
super().__init__(env)

def action(self, actions):
actions = deepcopy(actions)
for bn in self.behavior_names:
if not self.is_continuous[bn]:
actions[bn] = self.discrete_action_lists[bn][actions[bn]]
return actions
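Why BasicActionWrapper could be deleted: the index-to-branch lookup it performed (actions[bn] = self.discrete_action_lists[bn][actions[bn]]) now happens inside BasicUnityEnvironment.step() itself. The sketch below re-creates that flattening from first principles; it is an assumption about what get_discrete_action_list produces, since that function's body is not part of this diff.

import numpy as np
from itertools import product

def flatten_discrete_branches(branches):
    # Assumed behaviour: enumerate every combination of branch choices so that a
    # single categorical index can address a multi-branch discrete action space,
    # e.g. branches (2, 3) -> 6 rows: [0 0], [0 1], [0 2], [1 0], [1 1], [1 2].
    return np.asarray(list(product(*[range(b) for b in branches])))

action_list = flatten_discrete_branches((2, 3))
a_dim = int(np.asarray((2, 3)).prod())           # 6, matching self.a_dim[bn] above
per_agent_indices = np.array([4, 0])             # flat indices chosen by the model for 2 agents
branch_actions = action_list[per_agent_indices]  # [[1 1], [0 0]], shape (n_agents, n_branches)

With such a lookup table, step() only needs empty_actiontuples[bn].add_discrete(discrete_action_lists[bn][v]), so models keep emitting flat integer indices on both the single- and multi-behavior paths.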
7 changes: 7 additions & 0 deletions rls/utils/indexs.py
@@ -14,6 +14,13 @@
MultiAgentEnvArgs = namedtuple('MultiAgentEnvArgs',
SingleAgentEnvArgs._fields + ('behavior_controls',))

UnitySingleBehaviorInfo = namedtuple('UnitySingleBehaviorInfo',
[
'behavior_name',
'n_agents_control',
'is_continuous'
])

UnitySingleAgentReturn = namedtuple('UnitySingleAgentReturn',
[
'vector',
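A small sketch of constructing the new namedtuple; the field names come from the diff above, while the values are purely illustrative (how the wrapper populates it is not shown in this commit).

from rls.utils.indexs import UnitySingleBehaviorInfo

info = UnitySingleBehaviorInfo(behavior_name='3DBall?team=0',  # illustrative value
                               n_agents_control=12,
                               is_continuous=True)
print(info.behavior_name, info.n_agents_control, info.is_continuous)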
