
Commit

fix(multi-agents): implemented n-step replay buffer for multi-agent training. (#41, #25, #31)

1. renamed the variable "is_lg_batch_size" to "can_sample"
2. optimized the Unity wrapper
3. optimized the multi-agent replay buffers
StepNeverStop committed Jul 2, 2021
1 parent e4c5da7 commit dbc6136
Showing 12 changed files with 191 additions and 55 deletions.
45 changes: 41 additions & 4 deletions rls/algos/base/ma_off_policy.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# encoding: utf-8

import importlib
import numpy as np
import tensorflow as tf

@@ -22,17 +23,53 @@ def __init__(self, envspecs, **kwargs):
super().__init__(envspecs=envspecs, **kwargs)

self.buffer_size = int(kwargs.get('buffer_size', 10000))

self.n_step = int(kwargs.get('n_step', 1))
self.gamma = self.gamma ** self.n_step

self.burn_in_time_step = int(kwargs.get('burn_in_time_step', 10))
self.train_time_step = int(kwargs.get('train_time_step', 10))
self.episode_batch_size = int(kwargs.get('episode_batch_size', 32))
self.episode_buffer_size = int(kwargs.get('episode_buffer_size', 10000))

self.train_times_per_step = int(kwargs.get('train_times_per_step', 1))

def initialize_data_buffer(self) -> NoReturn:
'''
TODO: Annotation
'''
_buffer_args = dict(n_agents=self.n_agents_percopy, batch_size=self.batch_size, capacity=self.buffer_size)
default_buffer_args = load_yaml(f'rls/configs/off_policy_buffer.yaml')['MultiAgentExperienceReplay']
_buffer_args = {}
if self.use_rnn:
_type = 'EpisodeExperienceReplay'
_buffer_args.update(
batch_size=self.episode_batch_size,
capacity=self.episode_buffer_size,
burn_in_time_step=self.burn_in_time_step,
train_time_step=self.train_time_step,
n_copys=self.n_copys
)
else:
_type = 'ExperienceReplay'
_buffer_args.update(
batch_size=self.batch_size,
capacity=self.buffer_size
)
# if self.use_priority:
# raise NotImplementedError("multi agent algorithms now not support prioritized experience replay.")
if self.n_step > 1:
_type = 'NStep' + _type
_buffer_args.update(
n_step=self.n_step,
gamma=self.gamma,
n_copys=self.n_copys
)

default_buffer_args = load_yaml(f'rls/configs/off_policy_buffer.yaml')['MultiAgentExperienceReplay'][_type]
default_buffer_args.update(_buffer_args)
self.data = MultiAgentExperienceReplay(**default_buffer_args)

self.data = MultiAgentExperienceReplay(n_agents=self.n_agents_percopy,
single_agent_buffer_class=getattr(importlib.import_module(f'rls.memories.single_replay_buffers'), _type),
buffer_config=default_buffer_args)

def store_data(self, expss: List[BatchExperiences]) -> NoReturn:
"""
@@ -82,7 +119,7 @@ def _learn(self, function_dict: Dict = {}) -> NoReturn:
TODO: Annotation
'''

if self.data.is_lg_batch_size:
if self.data.can_sample:
self.intermediate_variable_reset()
data = self.get_transitions()

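A minimal sketch of the buffer-selection logic added above: the buffer type name is composed from the RNN and n-step settings, and the matching constructor arguments are collected before the class is resolved from rls.memories.single_replay_buffers. The YAML lookup, gamma, and the burn-in arguments are omitted here, so treat the defaults as illustrative.

# Sketch only: mirrors how initialize_data_buffer picks a single-agent buffer
# class name and its constructor arguments (class resolution and config
# loading are left out).
def pick_buffer(use_rnn: bool, n_step: int,
                batch_size: int = 32, buffer_size: int = 10000,
                episode_batch_size: int = 32, episode_buffer_size: int = 10000,
                n_copys: int = 1):
    if use_rnn:
        _type = 'EpisodeExperienceReplay'
        _buffer_args = dict(batch_size=episode_batch_size,
                            capacity=episode_buffer_size,
                            n_copys=n_copys)
    else:
        _type = 'ExperienceReplay'
        _buffer_args = dict(batch_size=batch_size, capacity=buffer_size)
    if n_step > 1:
        _type = 'NStep' + _type
        _buffer_args.update(n_step=n_step, n_copys=n_copys)
    return _type, _buffer_args

print(pick_buffer(use_rnn=False, n_step=4))
# ('NStepExperienceReplay', {'batch_size': 32, 'capacity': 10000, 'n_step': 4, 'n_copys': 1})

One instance of the resolved class is then created per agent and wrapped by MultiAgentExperienceReplay, so every agent gets its own underlying buffer with identical settings.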
4 changes: 3 additions & 1 deletion rls/algos/base/ma_policy.py
@@ -45,6 +45,7 @@ def __init__(self, envspecs: List[EnvGroupArgs], **kwargs):

self.memory_net_kwargs = dict(kwargs.get('memory_net_kwargs', {}))
self.memory_net_kwargs['network_type'] = MemoryNetworkType(self.memory_net_kwargs['network_type'])
self.use_rnn = bool(self.memory_net_kwargs.get('use_rnn', False))

self.writers = [self._create_writer(self.log_dir + f'_{i}') for i in range(self.n_agents_percopy)]

@@ -110,4 +111,5 @@ def write_training_summaries(self,
'''
super().write_training_summaries(global_step, summaries=summaries.get('model', {}), writer=self.writer)
for i, summary in summaries.items():
super().write_training_summaries(global_step, summaries=summary, writer=self.writers[i])
if i != 'model': # TODO: Optimization
super().write_training_summaries(global_step, summaries=summary, writer=self.writers[i])
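The guard added to write_training_summaries keeps the shared 'model' summary out of the per-agent writers. A toy illustration of the intended routing, using made-up metric values:

# Hypothetical summaries dict: the 'model' entry holds shared metrics and is
# written with the common writer, while integer keys map to per-agent writers,
# so the per-agent loop has to skip 'model'.
summaries = {'model': {'loss': 0.10}, 0: {'actor_loss': 0.21}, 1: {'actor_loss': 0.33}}

shared = summaries.get('model', {})                                # -> self.writer
per_agent = {i: s for i, s in summaries.items() if i != 'model'}   # -> self.writers[i]
print(shared)
print(per_agent)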
9 changes: 6 additions & 3 deletions rls/algos/base/off_policy.py
@@ -24,17 +24,20 @@ class Off_Policy(Policy):
def __init__(self, envspec, **kwargs):
super().__init__(envspec=envspec, **kwargs)
self.buffer_size = int(kwargs.get('buffer_size', 10000))
self.use_priority = kwargs.get('use_priority', False)

self.n_step = int(kwargs.get('n_step', 1))
self.gamma = self.gamma ** self.n_step

self.use_priority = kwargs.get('use_priority', False)
self.use_isw = bool(kwargs.get('use_isw', False))
self.train_times_per_step = int(kwargs.get('train_times_per_step', 1))

self.burn_in_time_step = int(kwargs.get('burn_in_time_step', 10))
self.train_time_step = int(kwargs.get('train_time_step', 10))
self.episode_batch_size = int(kwargs.get('episode_batch_size', 32))
self.episode_buffer_size = int(kwargs.get('episode_buffer_size', 10000))

self.train_times_per_step = int(kwargs.get('train_times_per_step', 1))

def initialize_data_buffer(self) -> NoReturn:
'''
TODO: Annotation
@@ -124,7 +127,7 @@ def _learn(self, function_dict: Dict = {}) -> NoReturn:
_summary = function_dict.get('summary_dict', {})  # dict of summaries written to TensorBoard
_use_stack = function_dict.get('use_stack', False)

if self.data.is_lg_batch_size:
if self.data.can_sample:
self.intermediate_variable_reset()
data = self.get_transitions()

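Both base classes now raise the discount to gamma ** n_step, which is consistent with an n-step return target: the buffer accumulates n_step rewards and the bootstrap value is discounted by gamma ** n_step. A small worked example with made-up numbers:

# Illustrative n-step target; the rewards and bootstrap value are invented.
gamma, n_step = 0.99, 4
rewards = [1.0, 0.5, 0.0, 2.0]    # r_t ... r_{t+n_step-1}, accumulated by the n-step buffer
bootstrap = 3.0                   # value estimate for the state reached n_step steps later

n_step_reward = sum((gamma ** k) * r for k, r in enumerate(rewards))
target = n_step_reward + (gamma ** n_step) * bootstrap   # hence self.gamma ** self.n_step
print(round(target, 4))           # 6.3174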
2 changes: 1 addition & 1 deletion rls/algos/hierarchical/hiro.py
@@ -247,7 +247,7 @@ def get_subgoal(self, s):
def learn(self, **kwargs):
self.train_step = kwargs.get('train_step')
for i in range(self.train_times_per_step):
if self.data_low.is_lg_batch_size and self.data_high.is_lg_batch_size:
if self.data_low.can_sample and self.data_high.can_sample:
self.intermediate_variable_reset()
low_data = self.get_transitions(self.data_low)
high_data = self.get_transitions(self.data_high)
4 changes: 2 additions & 2 deletions rls/common/train/unity.py
@@ -146,8 +146,8 @@ def unity_inference(env, model,
def ma_unity_no_op(env, model,
pre_fill_steps: int,
prefill_choose: bool,
desc: str = 'Pre-filling',
real_done: bool = True) -> NoReturn:
real_done: bool,
desc: str = 'Pre-filling') -> NoReturn:
assert isinstance(pre_fill_steps, int) and pre_fill_steps >= 0, 'no_op.steps must have type of int and larger than/equal 0'
n = env._n_copys

3 changes: 2 additions & 1 deletion rls/configs/algorithms.yaml
@@ -644,8 +644,9 @@ maddpg:
actor_lr: 5.0e-4
critic_lr: 1.0e-3
discrete_tau: 1.0
batch_size: 32
batch_size: 4
buffer_size: 100000
n_step: 4
network_settings:
actor_continuous: [64, 64]
actor_discrete: [64, 64]
30 changes: 24 additions & 6 deletions rls/configs/off_policy_buffer.yaml
@@ -1,24 +1,42 @@
ExperienceReplay: {}
ExperienceReplay: &ExperienceReplay {}

PrioritizedExperienceReplay:
PrioritizedExperienceReplay: &PrioritizedExperienceReplay
alpha: 0.6 # priority
beta: 0.4 # importance sampling ratio
epsilon: 0.01
global_v: false

NStepExperienceReplay:
NStepExperienceReplay: &NStepExperienceReplay
n_step: 4

NStepPrioritizedExperienceReplay:
NStepPrioritizedExperienceReplay: &NStepPrioritizedExperienceReplay
alpha: 0.6
beta: 0.4
epsilon: 0.01
global_v: false
n_step: 4

# off-policy with rnn
EpisodeExperienceReplay:
EpisodeExperienceReplay: &EpisodeExperienceReplay
burn_in_time_step: 20
train_time_step: 40 # null

MultiAgentExperienceReplay: {}
# Multi-Agents

MultiAgentCentralExperienceReplay: {} # TODO

MultiAgentExperienceReplay:
ExperienceReplay:
<<: *ExperienceReplay

PrioritizedExperienceReplay:
<<: *PrioritizedExperienceReplay

NStepExperienceReplay:
<<: *NStepExperienceReplay

NStepPrioritizedExperienceReplay:
<<: *NStepPrioritizedExperienceReplay

EpisodeExperienceReplay:
<<: *EpisodeExperienceReplay
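The per-type defaults above are now shared with the multi-agent section through YAML anchors and merge keys, so each entry under MultiAgentExperienceReplay picks up the corresponding single-agent defaults without duplication. A quick check of how the merge expands (assumes PyYAML is available; the snippet is a trimmed copy of the config):

# The '<<: *anchor' merge key copies the anchored mapping's entries into the
# nested mapping, so the multi-agent entry inherits n_step: 4.
import yaml

snippet = '''
NStepExperienceReplay: &NStepExperienceReplay
  n_step: 4

MultiAgentExperienceReplay:
  NStepExperienceReplay:
    <<: *NStepExperienceReplay
'''
cfg = yaml.safe_load(snippet)
print(cfg['MultiAgentExperienceReplay']['NStepExperienceReplay'])   # {'n_step': 4}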
2 changes: 1 addition & 1 deletion rls/distribute/apex/buffer.py
@@ -30,7 +30,7 @@ def __init__(self, learner_ip, learner_port, buffer, lock):

def run(self):
while True:
if self.buffer.is_lg_batch_size:
if self.buffer.can_sample:
with self.lock:
exps, idxs = self.buffer.sample(return_index=True)
prios = self.buffer.get_IS_w().reshape(-1, 1)
43 changes: 32 additions & 11 deletions rls/envs/unity_wrapper/wrappers.py
@@ -198,7 +198,15 @@ def get_obs(self, behavior_names=None):
Parse the information returned by the environment into four parts: vectors, images, rewards, and done signals.
'''
behavior_names = behavior_names or self.behavior_names
rets = []

# TODO: optimization
whole_done = np.full(self._n_copys, False)
whole_info_max_step = np.full(self._n_copys, False)
whole_info_real_done = np.full(self._n_copys, False)
all_corrected_obs = []
all_obs = []
all_reward = []

for bn in behavior_names:
n = self.behavior_agents[bn]
ids = self.behavior_ids[bn]
@@ -235,17 +243,30 @@
reward = np.asarray(reward)
done = np.asarray(done)

rets.extend([
for idxs in self.batch_idx_for_behaviors[bn]:
whole_done = np.logical_or(whole_done, done[idxs])
whole_info_max_step = np.logical_or(whole_info_max_step, info_max_step[idxs])
whole_info_real_done = np.logical_or(whole_info_real_done, info_real_done[idxs])

all_corrected_obs.append(ModelObservations(vector=self.vector_info_type[bn](*[corrected_obs[vi][idxs] for vi in self.vector_idxs[bn]]),
visual=self.visual_info_type[bn](*[corrected_obs[vi][idxs] for vi in self.visual_idxs[bn]])))
all_obs.append(ModelObservations(vector=self.vector_info_type[bn](*[obs[vi][idxs] for vi in self.vector_idxs[bn]]),
visual=self.visual_info_type[bn](*[obs[vi][idxs] for vi in self.visual_idxs[bn]])))
all_reward.append(reward[idxs])
# all_info.append(dict(max_step=info_max_step[idxs], real_done=info_real_done[idxs]))

rets = []
for corrected_obs, obs, reward in zip(all_corrected_obs, all_obs, all_reward):
rets.append(
SingleModelInformation(
corrected_obs=ModelObservations(vector=self.vector_info_type[bn](*[corrected_obs[vi][idxs] for vi in self.vector_idxs[bn]]),
visual=self.visual_info_type[bn](*[corrected_obs[vi][idxs] for vi in self.visual_idxs[bn]])),
obs=ModelObservations(vector=self.vector_info_type[bn](*[obs[vi][idxs] for vi in self.vector_idxs[bn]]),
visual=self.visual_info_type[bn](*[obs[vi][idxs] for vi in self.visual_idxs[bn]])),
reward=reward[idxs],
done=done[idxs],
info=dict(max_step=info_max_step[idxs], real_done=info_real_done[idxs])
) for idxs in self.batch_idx_for_behaviors[bn]
])
corrected_obs=corrected_obs,
obs=obs,
reward=reward,
done=whole_done,
info=dict(max_step=whole_info_max_step,
real_done=whole_info_real_done)
)
)
return rets

def random_action(self, is_single=True):
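The rewritten get_obs first OR-merges each behavior's termination and max-step flags into copy-level arrays, then reuses those shared flags in every SingleModelInformation it returns. The core of that merge, shown on invented flag arrays:

# Invented per-behavior done flags, already gathered onto the copy axis; the
# wrapper combines them so a copy counts as done if any of its behaviors is done.
import numpy as np

n_copys = 4
whole_done = np.full(n_copys, False)
behavior_dones = [np.array([True, False, False, False]),
                  np.array([False, False, True, False])]
for done in behavior_dones:
    whole_done = np.logical_or(whole_done, done)
print(whole_done)   # [ True False  True False]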
24 changes: 24 additions & 0 deletions rls/memories/base_replay_buffer.py
@@ -31,8 +31,32 @@ def sample(self) -> Any:
def add(self, exps: BatchExperiences) -> Any:
pass

@property
def can_sample(self) -> bool:
return self._size > self.batch_size

def is_empty(self) -> bool:
return self._size == 0

def update(self, *args) -> Any:
pass


class MultiAgentReplayBuffer(ABC):
def __init__(self,
n_agents: int):
assert isinstance(n_agents, int) and n_agents >= 0, 'n_agents must be int and larger than 0'
self._n_agents = n_agents

@abstractmethod
def sample(self) -> Any:
pass

@abstractmethod
def add(self, expss: List[BatchExperiences]) -> Any:
pass

@property
@abstractmethod
def can_sample(self) -> bool:
pass
37 changes: 29 additions & 8 deletions rls/memories/multi_replay_buffers.py
@@ -2,14 +2,39 @@
import numpy as np

from typing import (List,
Dict,
NoReturn)

from rls.memories.base_replay_buffer import ReplayBuffer
from rls.memories.base_replay_buffer import (ReplayBuffer,
MultiAgentReplayBuffer)
from rls.utils.specs import (BatchExperiences,
NamedTupleStaticClass)


class MultiAgentExperienceReplay(ReplayBuffer):
class MultiAgentExperienceReplay(MultiAgentReplayBuffer):

def __init__(self,
n_agents: int,
single_agent_buffer_class: ReplayBuffer,
buffer_config: Dict = {}):
super().__init__(n_agents)
self._buffers = [single_agent_buffer_class(**buffer_config) for _ in range(self._n_agents)]

def add(self, expss: List[BatchExperiences]) -> NoReturn:
for exps, buffer in zip(expss, self._buffers):
buffer.add(exps)

def sample(self) -> List[BatchExperiences]:
idxs_info = self._buffers[0].generate_random_sample_idxs()
expss = [buffer.sample_from_idxs(idxs_info) for buffer in self._buffers]
return expss

@property
def can_sample(self):
return self._buffers[0].can_sample


class MultiAgentCentralExperienceReplay(ReplayBuffer):
def __init__(self,
n_agents: int,
batch_size: int,
@@ -28,10 +53,6 @@ def add(self, expss: List[BatchExperiences]) -> NoReturn:
self._store_op(i, exp)
self.update_rb_after_add()

# for i, exps in enumerate(expss):
# for exp in NamedTupleStaticClass.unpack(exps):
# self._store_op(i, exp)

def _store_op(self, i, exp: BatchExperiences) -> NoReturn:
self._buffers[i, self._data_pointer] = exp
# self.update_rb_after_add()
@@ -40,7 +61,7 @@ def sample(self) -> BatchExperiences:
'''
change [[s, a, r],[s, a, r]] to [[s, s],[a, a],[r, r]]
'''
n_sample = self.batch_size if self.is_lg_batch_size else self._size
n_sample = self.batch_size if self.can_sample else self._size
idx = np.random.randint(0, self._size, n_sample)
t = self._buffers[:, idx]
return [NamedTupleStaticClass.pack(_t.tolist()) for _t in t]
@@ -64,7 +85,7 @@ def size(self) -> int:
return self._size

@property
def is_lg_batch_size(self) -> bool:
def can_sample(self) -> bool:
return self._size > self.batch_size

@property
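MultiAgentExperienceReplay fans out to one single-agent buffer per agent and samples every buffer with the index set drawn from agent 0's buffer, which keeps the joint batch time-aligned across agents. A self-contained sketch of that idea with a stand-in single-agent buffer (the real classes live in rls.memories):

import random
from typing import List

# Stand-in for a single-agent buffer: same sampling interface the committed
# multi-agent class relies on (generate_random_sample_idxs / sample_from_idxs /
# can_sample), but backed by a plain list so the sketch runs on its own.
class TinyAgentBuffer:
    def __init__(self, batch_size: int, capacity: int):
        self.batch_size, self.capacity = batch_size, capacity
        self._data: List = []

    def add(self, exp) -> None:
        self._data.append(exp)
        del self._data[:-self.capacity]   # drop oldest items beyond capacity

    def generate_random_sample_idxs(self) -> List[int]:
        return random.sample(range(len(self._data)), self.batch_size)

    def sample_from_idxs(self, idxs: List[int]):
        return [self._data[i] for i in idxs]

    @property
    def can_sample(self) -> bool:
        return len(self._data) > self.batch_size


buffers = [TinyAgentBuffer(batch_size=2, capacity=8) for _ in range(2)]   # 2 agents
for t in range(4):                        # store one experience per agent per step
    for i, buf in enumerate(buffers):
        buf.add(f'agent{i}-t{t}')

if buffers[0].can_sample:                 # same gate the multi-agent wrapper exposes
    idxs = buffers[0].generate_random_sample_idxs()
    joint_batch = [buf.sample_from_idxs(idxs) for buf in buffers]
    print(joint_batch)   # e.g. [['agent0-t3', 'agent0-t1'], ['agent1-t3', 'agent1-t1']]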
