fix&feat: optimized training framework. (#41, #34, #33, #31, #25)
1. fixed the n-step replay buffer
2. reconstructed the representation net
3. removed 'use_stack'
4. implemented multi-agent algorithms with shared parameters
5. optimized the agent network
StepNeverStop committed Jul 4, 2021
1 parent dbc6136 commit 0a8d247
Showing 42 changed files with 612 additions and 594 deletions.
42 changes: 21 additions & 21 deletions rls/algos/base/ma_off_policy.py
@@ -39,30 +39,30 @@ def initialize_data_buffer(self) -> NoReturn:
TODO: Annotation
'''
_buffer_args = {}
if self.use_rnn:
_type = 'EpisodeExperienceReplay'
# if self.use_rnn:
# _type = 'EpisodeExperienceReplay'
# _buffer_args.update(
# batch_size=self.episode_batch_size,
# capacity=self.episode_buffer_size,
# burn_in_time_step=self.burn_in_time_step,
# train_time_step=self.train_time_step,
# n_copys=self.n_copys
# )
# else:
_type = 'ExperienceReplay'
_buffer_args.update(
batch_size=self.batch_size,
capacity=self.buffer_size
)
# if self.use_priority:
# raise NotImplementedError("multi-agent algorithms do not yet support prioritized experience replay.")
if self.n_step > 1:
_type = 'NStep' + _type
_buffer_args.update(
batch_size=self.episode_batch_size,
capacity=self.episode_buffer_size,
burn_in_time_step=self.burn_in_time_step,
train_time_step=self.train_time_step,
n_step=self.n_step,
gamma=self.gamma,
n_copys=self.n_copys
)
else:
_type = 'ExperienceReplay'
_buffer_args.update(
batch_size=self.batch_size,
capacity=self.buffer_size
)
# if self.use_priority:
# raise NotImplementedError("multi-agent algorithms do not yet support prioritized experience replay.")
if self.n_step > 1:
_type = 'NStep' + _type
_buffer_args.update(
n_step=self.n_step,
gamma=self.gamma,
n_copys=self.n_copys
)

default_buffer_args = load_yaml(f'rls/configs/off_policy_buffer.yaml')['MultiAgentExperienceReplay'][_type]
default_buffer_args.update(_buffer_args)
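Note on the reworked buffer setup above: the RNN/episode branch is now commented out, so multi-agent off-policy algorithms always build a plain ExperienceReplay and only switch to the 'NStep'-prefixed variant when n_step > 1, passing gamma and n_copys along. The sketch below only illustrates the n-step folding such a wrapper performs under the standard n-step return definition; it is not rls' actual buffer code, and the function name is made up for illustration.

```python
import numpy as np

# Illustrative only (not rls' NStepExperienceReplay): an n-step wrapper folds
# the next n rewards into a single discounted return before storing a sample,
#   R_t = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1},
# and pairs it with (s_t, a_t, s_{t+n}, done).
def n_step_return(rewards, gamma, n_step):
    """rewards: the n_step rewards that follow (s_t, a_t)."""
    discounts = gamma ** np.arange(n_step)
    return float(np.sum(discounts * np.asarray(rewards[:n_step])))

print(n_step_return([1.0, 0.0, 2.0], gamma=0.99, n_step=3))  # 1 + 0 + 0.9801*2 ≈ 2.96
```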
13 changes: 2 additions & 11 deletions rls/algos/base/ma_policy.py
@@ -5,6 +5,7 @@
import tensorflow as tf

from abc import abstractmethod
from collections import defaultdict
from typing import (List,
Dict,
Union,
@@ -35,17 +36,7 @@ def __init__(self, envspecs: List[EnvGroupArgs], **kwargs):
self.max_train_step = int(kwargs.get('max_train_step', 1000))
self.delay_lr = bool(kwargs.get('decay_lr', True))

self.vector_net_kwargs = dict(kwargs.get('vector_net_kwargs', {}))
self.vector_net_kwargs['network_type'] = VectorNetworkType(self.vector_net_kwargs['network_type'])

self.visual_net_kwargs = dict(kwargs.get('visual_net_kwargs', {}))
self.visual_net_kwargs['network_type'] = VisualNetworkType(self.visual_net_kwargs['network_type'])

self.encoder_net_kwargs = dict(kwargs.get('encoder_net_kwargs', {}))

self.memory_net_kwargs = dict(kwargs.get('memory_net_kwargs', {}))
self.memory_net_kwargs['network_type'] = MemoryNetworkType(self.memory_net_kwargs['network_type'])
self.use_rnn = bool(self.memory_net_kwargs.get('use_rnn', False))
self.representation_net_params = dict(kwargs.get('representation_net_params', defaultdict(dict)))

self.writers = [self._create_writer(self.log_dir + f'_{i}') for i in range(self.n_agents_percopy)]
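The four per-component dicts (vector_net_kwargs, visual_net_kwargs, encoder_net_kwargs, memory_net_kwargs) are collapsed into a single representation_net_params dict here and again in policy.py below. A hypothetical shape of that dict is sketched next; only the 'use_rnn' key is confirmed by this diff, the nested keys are illustrative placeholders and the authoritative defaults live in the repository's config files.

```python
# Hypothetical example of the consolidated config (the nested keys are
# placeholders, not the library's confirmed schema):
representation_net_params = {
    'use_rnn': False,                                    # read by policy.py in this commit
    'vector_net_params': {'network_type': 'adaptive'},   # placeholder
    'visual_net_params': {'network_type': 'simple'},     # placeholder
    'memory_net_params': {'network_type': 'lstm'},       # placeholder
}
use_rnn = bool(representation_net_params.get('use_rnn', False))
```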

25 changes: 2 additions & 23 deletions rls/algos/base/off_policy.py
@@ -50,7 +50,7 @@ def initialize_data_buffer(self) -> NoReturn:
capacity=self.episode_buffer_size,
burn_in_time_step=self.burn_in_time_step,
train_time_step=self.train_time_step,
agents_num=self.n_copys
n_copys=self.n_copys
)
else:
_type = 'ExperienceReplay'
@@ -68,7 +68,7 @@ def initialize_data_buffer(self) -> NoReturn:
_buffer_args.update(
n_step=self.n_step,
gamma=self.gamma,
agents_num=self.n_copys
n_copys=self.n_copys
)

default_buffer_args = load_yaml(f'rls/configs/off_policy_buffer.yaml')[_type]
@@ -125,7 +125,6 @@ def _learn(self, function_dict: Dict = {}) -> NoReturn:
TODO: Annotation
'''
_summary = function_dict.get('summary_dict', {}) # dict of summaries to write to tensorboard
_use_stack = function_dict.get('use_stack', False)

if self.data.can_sample:
self.intermediate_variable_reset()
@@ -160,15 +159,6 @@ def _learn(self, function_dict: Dict = {}) -> NoReturn:
_isw = tf.constant(value=1., dtype=self._tf_data_type)
# --------------------------------------

# -------------------------------------- if an RNN is used, stack the s and s' observations together
if _use_stack:
if self.use_rnn:
obs = ModelObservations.stack_rnn(data.obs, data.obs_, episode_batch_size=self.episode_batch_size)
else:
obs = ModelObservations.stack(data.obs, data.obs_)
data = data._replace(obs=obs)
# --------------------------------------

# -------------------------------------- main training step; returns the TD error (possibly used to update PER priorities) and the summaries to write to tensorboard
td_error, summaries = self._train(data, _isw, cell_state)
# --------------------------------------
@@ -200,15 +190,10 @@ def _apex_learn(self, function_dict: Dict, data: BatchExperiences, priorities) -
TODO: Annotation
'''
_summary = function_dict.get('summary_dict', {}) # dict of summaries to write to tensorboard
_use_stack = function_dict.get('use_stack', False)

self.intermediate_variable_reset()
data = self._data_process2dict(data=data)

if _use_stack:
obs = [tf.concat([o, o_], axis=0) for o, o_ in zip(data.obs, data.obs_)] # [B, N] => [2*B, N]
data = data._replace(obs=data.obs.__class__._make(obs))

cell_state = (None,)

if self.use_curiosity:
@@ -231,14 +216,8 @@ def _apex_cal_td(self, data: BatchExperiences, function_dict: Dict = {}) -> np.n
'''
TODO: Annotation
'''
_use_stack = function_dict.get('use_stack', False)

data = self._data_process2dict(data=data)

if _use_stack:
obs = [tf.concat([o, o_], axis=0) for o, o_ in zip(data.obs, data.obs_)] # [B, N] => [2*B, N]
data = data._replace(obs=data.obs.__class__._make(obs))

cell_state = (None,)
td_error = self._cal_td(data, cell_state)
return np.squeeze(td_error.numpy())
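For context on the removed use_stack path (item 3 of the commit message): it concatenated obs and obs_ along the batch axis so a single forward pass of the representation net encoded both the current and the next observation. A standalone sketch of that trick, using dummy tensors rather than rls' ModelObservations, is below.

```python
import tensorflow as tf

# What the removed use_stack path did, in isolation (dummy tensors, not rls code):
# stack s and s' along the batch axis, [B, N] and [B, N] -> [2B, N], so one
# forward pass of the representation net encodes both.
obs = tf.random.normal((4, 8))    # current observations, [B, N]
obs_ = tf.random.normal((4, 8))   # next observations,    [B, N]
stacked = tf.concat([obs, obs_], axis=0)
print(stacked.shape)              # (8, 8), i.e. [2B, N]
```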
2 changes: 1 addition & 1 deletion rls/algos/base/on_policy.py
@@ -22,7 +22,7 @@ def __init__(self, envspec, **kwargs):
self.rnn_time_step = int(kwargs.get('rnn_time_step', 8))

def initialize_data_buffer(self, store_data_type=BatchExperiences, sample_data_type=BatchExperiences) -> NoReturn:
self.data = DataBuffer(n_copys=self.n_copys, rnn_cell_nums=self.cell_nums,
self.data = DataBuffer(n_copys=self.n_copys, rnn_cell_nums=self.rnn_cell_nums,
batch_size=self.batch_size, rnn_time_step=self.rnn_time_step,
store_data_type=store_data_type, sample_data_type=sample_data_type)

47 changes: 16 additions & 31 deletions rls/algos/base/policy.py
@@ -5,6 +5,7 @@
import tensorflow as tf

from abc import abstractmethod
from collections import defaultdict
from typing import (Union,
List,
Callable,
@@ -18,10 +19,7 @@
from rls.nn.learningrate import ConsistentLearningRate
from rls.utils.vector_runing_average import (DefaultRunningAverage,
SimpleRunningAverage)
from rls.utils.specs import (EnvGroupArgs,
VectorNetworkType,
VisualNetworkType,
MemoryNetworkType)
from rls.utils.specs import EnvGroupArgs
from rls.utils.build_networks import DefaultRepresentationNetwork
from rls.nn.modules import CuriosityModel

@@ -47,47 +45,27 @@ def __init__(self, envspec: EnvGroupArgs, **kwargs):
self.max_train_step = int(kwargs.get('max_train_step', 1000))
self.delay_lr = bool(kwargs.get('decay_lr', True))

self.vector_net_kwargs = dict(kwargs.get('vector_net_kwargs', {}))
self.vector_net_kwargs['network_type'] = VectorNetworkType(self.vector_net_kwargs['network_type'])
self.representation_net_params = dict(kwargs.get('representation_net_params', defaultdict(dict)))
self.use_rnn = bool(self.representation_net_params.get('use_rnn', False))

self.visual_net_kwargs = dict(kwargs.get('visual_net_kwargs', {}))
self.visual_net_kwargs['network_type'] = VisualNetworkType(self.visual_net_kwargs['network_type'])

self.encoder_net_kwargs = dict(kwargs.get('encoder_net_kwargs', {}))

self.memory_net_kwargs = dict(kwargs.get('memory_net_kwargs', {}))
self.rnn_type = self.memory_net_kwargs['network_type'] = MemoryNetworkType(self.memory_net_kwargs['network_type'])
self.use_rnn = bool(self.memory_net_kwargs.get('use_rnn', False))
self.cell_nums = 2 if self.rnn_type == MemoryNetworkType.LSTM else 1

self._representation_net = self._create_representation_net(name='_representation_net')
self._representation_net = DefaultRepresentationNetwork(name='_representation_net',
obs_spec=self.obs_spec,
representation_net_params=self.representation_net_params)

self.use_curiosity = bool(kwargs.get('use_curiosity', False))
if self.use_curiosity:
self.curiosity_eta = float(kwargs.get('curiosity_reward_eta'))
self.curiosity_lr = float(kwargs.get('curiosity_lr'))
self.curiosity_beta = float(kwargs.get('curiosity_beta'))
self.curiosity_model = CuriosityModel(self.obs_spec,
self.vector_net_kwargs,
self.visual_net_kwargs,
self.encoder_net_kwargs,
self.memory_net_kwargs,
self.representation_net_params,
self.is_continuous,
self.a_dim,
eta=self.curiosity_eta,
lr=self.curiosity_lr,
beta=self.curiosity_beta)
self._all_params_dict.update(curiosity_model=self.curiosity_model)

def _create_representation_net(self, name: str = 'default'):
# TODO: Added changeable command
return DefaultRepresentationNetwork(obs_spec=self.obs_spec,
name=name,
vector_net_kwargs=self.vector_net_kwargs,
visual_net_kwargs=self.visual_net_kwargs,
encoder_net_kwargs=self.encoder_net_kwargs,
memory_net_kwargs=self.memory_net_kwargs)

def init_lr(self, lr: float) -> Callable:
if self.delay_lr:
return tf.keras.optimizers.schedules.PolynomialDecay(lr, self.max_train_step, 1e-10, power=1.0)
@@ -104,9 +82,16 @@ def reset(self) -> NoReturn:
'''reset model for each new episode.'''
self.cell_state = self.next_cell_state = self.initial_cell_state(batch=self.n_copys)

@property
def rnn_cell_nums(self):
if self.use_rnn:
return self._representation_net.memory_net.cell_nums
else:
return 0

def initial_cell_state(self, batch: int) -> Tuple[tf.Tensor]:
if self.use_rnn:
return tuple(tf.zeros((batch, self._representation_net.h_dim), dtype=tf.float32) for _ in range(self.cell_nums))
return self._representation_net.memory_net.initial_cell_state(batch=batch)
return (None,)

def get_cell_state(self) -> Tuple[Optional[tf.Tensor]]:
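policy.py no longer hard-codes cell_nums (2 for LSTM, 1 otherwise) or builds the zero states itself; both are delegated to the representation net's memory net via the new rnn_cell_nums property and initial_cell_state. A minimal sketch of what that delegated call presumably does, mirroring the removed hard-coded logic, follows; the helper name and signature here are illustrative, not rls' API.

```python
import tensorflow as tf

# Sketch mirroring the removed hard-coded logic (illustrative helper, not the
# actual memory_net API): an LSTM carries two zero states (h, c), a GRU one.
def make_initial_cell_state(batch: int, hidden_units: int, cell_nums: int):
    return tuple(tf.zeros((batch, hidden_units), dtype=tf.float32)
                 for _ in range(cell_nums))

lstm_state = make_initial_cell_state(batch=4, hidden_units=16, cell_nums=2)  # (h, c)
gru_state = make_initial_cell_state(batch=4, hidden_units=16, cell_nums=1)   # (h,)
```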
14 changes: 8 additions & 6 deletions rls/algos/hierarchical/aoc.py
@@ -133,7 +133,8 @@ def choose_action(self, obs, evaluation=False):
@tf.function
def _get_action(self, obs, cell_state, options):
with tf.device(self.device):
(q, pi, beta), cell_state = self.net(obs, cell_state=cell_state) # [B, P], [B, P, A], [B, P], [B, P]
ret = self.net(obs, cell_state=cell_state) # [B, P], [B, P, A], [B, P], [B, P]
(q, pi, beta) = ret['value']
options_onehot = tf.one_hot(options, self.options_num, dtype=tf.float32) # [B, P]
options_onehot_expanded = tf.expand_dims(options_onehot, axis=-1) # [B, P, 1]
pi = tf.reduce_sum(pi * options_onehot_expanded, axis=1) # [B, A]
@@ -153,7 +154,7 @@ def _get_action(self, obs, cell_state, options):
beta_probs = tf.reduce_sum(beta * options_onehot, axis=1) # [B, P] => [B,]
beta_dist = tfp.distributions.Bernoulli(probs=beta_probs)
new_options = tf.where(beta_dist.sample() < 1, options, max_options) # <1: keep the current option; =1: switch option
return sample_op, value, log_prob, beta_adv, new_options, max_options, cell_state
return sample_op, value, log_prob, beta_adv, new_options, max_options, ret['cell_state']

def store_data(self, exps: BatchExperiences):
# self._running_average()
@@ -169,10 +170,11 @@ def store_data(self, exps: BatchExperiences):
def _get_value(self, obs, options, cell_state):
options = tf.cast(options, tf.int32)
with tf.device(self.device):
(q, _, _), cell_state = self.net(obs, cell_state=cell_state)
ret = self.net(obs, cell_state=cell_state)
(q, _, _) = ret['value']
options_onehot = tf.one_hot(options, self.options_num, dtype=tf.float32) # [B, P]
value = q_o = tf.reduce_sum(q * options_onehot, axis=-1, keepdims=True) # [B, 1]
return value, cell_state
return value, ret['cell_state']

def calculate_statistics(self):
init_value, self.cell_state = self._get_value(self.data.last_data('obs_'), self.data.last_data('options'), cell_state=self.cell_state)
@@ -223,8 +225,8 @@ def train(self, BATCH, cell_state, kl_coef):
options = tf.cast(BATCH.options, tf.int32)
with tf.device(self.device):
with tf.GradientTape() as tape:
(q, pi, beta), cell_state = self.net(BATCH.obs, cell_state=cell_state) # [B, P], [B, P, A], [B, P], [B, P]

ret = self.net(BATCH.obs, cell_state=cell_state) # [B, P], [B, P, A], [B, P], [B, P]
(q, pi, beta) = ret['value']
options_onehot = tf.one_hot(options, self.options_num, dtype=tf.float32) # [B, P]
options_onehot_expanded = tf.expand_dims(options_onehot, axis=-1) # [B, P, 1]
last_options_onehot = tf.one_hot(last_options, self.options_num, dtype=tf.float32) # [B,] => [B, P]
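The AOC changes above reflect the "optimized agent network" item: network calls now return a dict rather than an (outputs, cell_state) tuple, and call sites read ret['value'] / ret['cell_state'] (and ret['feat'] in ioc.py below). The sketch below shows the return structure as inferred from these call sites, built from dummy tensors; it is not the definitive rls API.

```python
import tensorflow as tf

# Return structure inferred from the call sites in this diff (dummy tensors;
# a sketch, not the definitive rls API).
B, P, A = 4, 3, 2                       # batch, options, actions
ret = {
    'value': (tf.zeros((B, P)),         # q
              tf.zeros((B, P, A)),      # pi
              tf.zeros((B, P))),        # beta
    'feat': tf.zeros((B, 16)),          # encoded observation features
    'cell_state': (None,),              # recurrent state to carry forward
}
(q, pi, beta) = ret['value']            # unpacked exactly as _get_action()/train() do
```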
31 changes: 15 additions & 16 deletions rls/algos/hierarchical/ioc.py
@@ -64,8 +64,7 @@ def _create_net(name, representation_net=None): return ValueNetwork(
value_net_kwargs=dict(output_shape=self.options_num, network_settings=network_settings['q'])
)
self.q_net = _create_net('q_net', self._representation_net)
self._representation_target_net = self._create_representation_net('_representation_target_net')
self.q_target_net = _create_net('q_target_net', self._representation_target_net)
self.q_target_net = _create_net('q_target_net', self._representation_net._copy())

self.intra_option_net = ValueNetwork(
name='intra_option_net',
@@ -134,9 +133,9 @@ def choose_action(self, obs, evaluation=False):
@tf.function
def _get_action(self, obs, cell_state, options):
with tf.device(self.device):
feat, cell_state = self._representation_net(obs, cell_state=cell_state)
q = self.q_net.value_net(feat) # [B, P]
pi = self.intra_option_net.value_net(feat) # [B, P, A]
ret = self.q_net(obs, cell_state=cell_state)
q = ret['value'] # [B, P]
pi = self.intra_option_net.value_net(ret['feat']) # [B, P, A]
options_onehot = tf.one_hot(options, self.options_num, dtype=tf.float32) # [B, P]
options_onehot_expanded = tf.expand_dims(options_onehot, axis=-1) # [B, P, 1]
pi = tf.reduce_sum(pi * options_onehot_expanded, axis=1) # [B, A]
@@ -148,10 +147,10 @@ def _get_action(self, obs, cell_state, options):
pi = pi / self.boltzmann_temperature
dist = tfp.distributions.Categorical(logits=pi) # [B, ]
a = dist.sample()
interests = self.interest_net.value_net(feat) # [B, P]
interests = self.interest_net.value_net(ret['feat']) # [B, P]
op_logits = interests * q # [B, P] or tf.nn.softmax(q)
new_options = tfp.distributions.Categorical(logits=op_logits).sample()
return a, new_options, cell_state
return a, new_options, ret['cell_state']

def _target_params_update(self):
if self.global_step % self.assign_interval == 0:
@@ -176,21 +175,21 @@ def _train(self, BATCH, isw, cell_state):
options = tf.cast(BATCH.options, tf.int32)
with tf.device(self.device):
with tf.GradientTape(persistent=True) as tape:
feat, _ = self._representation_net(BATCH.obs, cell_state=cell_state)
feat_, _ = self._representation_target_net(BATCH.obs_, cell_state=cell_state)
q = self.q_net.value_net(feat) # [B, P]
pi = self.intra_option_net.value_net(feat) # [B, P, A]
beta = self.termination_net.value_net(feat) # [B, P]
q_next = self.q_target_net.value_net(feat_) # [B, P], [B, P, A], [B, P]
beta_next = self.termination_net.value_net(feat_) # [B, P]
interests = self.interest_net.value_net(feat) # [B, P]
ret = self.q_net(BATCH.obs, cell_state=cell_state)
ret_ = self.q_target_net(BATCH.obs_, cell_state=cell_state)
q = ret['value'] # [B, P]
pi = self.intra_option_net.value_net(ret['feat']) # [B, P, A]
beta = self.termination_net.value_net(ret['feat']) # [B, P]
interests = self.interest_net.value_net(ret['feat']) # [B, P]
q_next = ret_['value'] # [B, P], [B, P, A], [B, P]
beta_next = self.termination_net.value_net(ret_['feat']) # [B, P]
options_onehot = tf.one_hot(options, self.options_num, dtype=tf.float32) # [B,] => [B, P]

q_s = qu_eval = tf.reduce_sum(q * options_onehot, axis=-1, keepdims=True) # [B, 1]
beta_s_ = tf.reduce_sum(beta_next * options_onehot, axis=-1, keepdims=True) # [B, 1]
q_s_ = tf.reduce_sum(q_next * options_onehot, axis=-1, keepdims=True) # [B, 1]
if self.double_q:
q_ = self.q_net.value_net(feat) # [B, P], [B, P, A], [B, P]
q_ = self.q_net(BATCH.obs_, cell_state=cell_state)['value'] # [B, P], [B, P, A], [B, P]
max_a_idx = tf.one_hot(tf.argmax(q_, axis=-1), self.options_num, dtype=tf.float32) # [B, P] => [B, ] => [B, P]
q_s_max = tf.reduce_sum(q_next * max_a_idx, axis=-1, keepdims=True) # [B, 1]
else:
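The double-Q branch in ioc.py now queries the online q_net on the next observation BATCH.obs_ instead of on the current-state features, which matches the standard double Q-learning target: the online net selects the best next option, the target net evaluates it. A self-contained sketch of that target with dummy tensors follows (a sketch, not ioc.py itself).

```python
import tensorflow as tf

# Double Q-learning target as computed above, with dummy tensors:
# the online net picks the argmax option at s', the target net supplies its value.
B, P = 4, 3
q_online_next = tf.random.normal((B, P))   # q_net(s')        -> used for argmax only
q_target_next = tf.random.normal((B, P))   # q_target_net(s') -> used for the value
max_idx = tf.one_hot(tf.argmax(q_online_next, axis=-1), P, dtype=tf.float32)
q_s_max = tf.reduce_sum(q_target_next * max_idx, axis=-1, keepdims=True)  # [B, 1]
```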