perf: optimized code style (#45, #41, #39, #35, #25)
1. removed the SARL off-policy algorithm `pd_ddpg`, since it is not mainstream
2. updated README
3. removed `iql` and added the script `IndependentMA.py` instead to implement independent multi-agent algorithms
4. optimized summary writing (see the sketch after this list)
5. moved `NamedDict` from `rls.common.config` to `rls.common.specs`
6. updated the example config
7. updated `.gitignore`
8. added property `is_multi` to identify whether a training task is SARL or MARL, for both Unity and Gym
9. restructured the inheritance relationships between algorithms and their superclasses
10. replaced `1.e+18` in YAML files with a large integer, since an integer is wanted there rather than a float
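
Item 4 consolidates the former `writer_summary` and `write_training_summaries` methods into a single `write_summaries` method (see the `rls/algos/base/base.py` diff below). A minimal sketch of the resulting call pattern, with a hypothetical `MyAlgo` class and hypothetical scalar keys; the method body is approximated from the signature shown in the diff:

```python
from typing import Dict, Optional, Union

import torch as t
from torch.utils.tensorboard import SummaryWriter


class MyAlgo:  # hypothetical stand-in for an RLs algorithm class
    def __init__(self, log_dir: str, no_save: bool = False):
        self.no_save = no_save
        self.writer = None if no_save else SummaryWriter(log_dir)

    # Single entry point replacing the former writer_summary /
    # write_training_summaries pair; body approximated from the diff below.
    def write_summaries(self,
                        global_step: Union[int, t.Tensor],
                        summaries: Dict = {},
                        writer: Optional[SummaryWriter] = None) -> None:
        if self.no_save:
            return
        writer = writer or self.writer
        for k, v in summaries.items():
            writer.add_scalar(k, v, global_step=global_step)


algo = MyAlgo(log_dir='./log')
algo.write_summaries(global_step=100, summaries={'LOSS/critic_loss': 0.42})
```

As the remaining diffs below show, the on-policy, off-policy, multi-agent, and HIRO learners all go through this single method.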
StepNeverStop committed Jul 28, 2021
1 parent 742da08 commit a3c989a
Showing 30 changed files with 329 additions and 680 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -106,5 +106,6 @@ venv.bak/
test/
test.py
.vscode/
data/
videos/
unitylog.txt
8 changes: 3 additions & 5 deletions README.md
@@ -157,7 +156,6 @@ For now, these algorithms are available:
| Rainbow || ||| rainbow |
| DPG ||||| dpg |
| DDPG ||||| ddpg |
| PD-DDPG ||||| pd_ddpg |
| TD3 ||||| td3 |
| SAC(has V network) ||||| sac_v |
| SAC ||||| sac |
@@ -169,7 +168,6 @@ For now, these algorithms are available:
| IOC ||||| ioc |
| HIRO ||| | | hiro |
| CURL |||| | curl |
| IQL || || | iql |
| VDN || || | vdn |
| MADDPG |||| | maddpg |

@@ -178,7 +176,7 @@ For now, these algorithms are available:
```python
"""
usage: run.py [-h] [-c COPYS] [--seed SEED] [-r] [-p {gym,unity}]
[-a {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,pd_ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn,iql}]
[-a {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn}]
[-d DEVICE] [-i] [-l LOAD_PATH] [-m MODELS] [-n NAME] [-s SAVE_FREQUENCY] [--apex {learner,worker,buffer,evaluator}] [--config-file CONFIG_FILE]
[--store-dir STORE_DIR] [--episode-length EPISODE_LENGTH] [--prefill-steps PREFILL_STEPS] [--prefill-choose] [--hostname] [--no-save] [--info INFO]
[-e ENV] [-f FILE_NAME]
@@ -191,7 +189,7 @@ optional arguments:
-r, --render whether render game interface
-p {gym,unity}, --platform {gym,unity}
specify the platform of training environment
-a {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,pd_ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn,iql}, --algorithm {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,pd_ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn,iql}
-a {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn}, --algorithm {pg,trpo,ppo,a2c,cem,aoc,ppoc,qs,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn,curl,oc,ioc,hiro,maddpg,vdn}
specify the training algorithm
-d DEVICE, --device DEVICE
specify the device that operate Torch.Tensor
@@ -225,7 +223,7 @@ optional arguments:
"""
Example:
python run.py
python run.py --config-file 'rls/configs/examples/gym_config.yaml'
python run.py --config-file rls/configs/examples/gym_config.yaml
python run.py -p gym -a dqn -e CartPole-v0 -c 12 -n dqn_cartpole --no-save
python run.py -p unity -a ppo -n run_with_unity
python run.py -p unity --file-name /root/env/3dball.app -a sac -n run_with_execution_file
45 changes: 0 additions & 45 deletions rls/algos/__init__.py
@@ -252,27 +252,6 @@
'''
)

register(
name='pd_ddpg',
folder='single',
is_multi=False,
algo_class='PD_DDPG',
policy_mode='off-policy',
logo='''
     OOOOOOOO        OOOOOOOOOO                          OOOOOOOOOO        OOOOOOOOOO          OOOOOOOO            OOOOO O    
      OOOOOOOO        OOOOOOOOOO                          OOOOOOOOOO        OOOOOOOOOO          OOOOOOOO         OOOOOOOOO    
       OO  OOOO        OO    OOOO                          OO    OOOO        OO    OOOO          OO  OOOO       OOOO    OO    
       OO   OOO        OO     OOOO                         OO     OOOO       OO     OOOO         OO   OOO      OOOO      O    
       OO  OOOO        OO     OOOO    OOO   OOO  OOO       OO     OOOO       OO     OOOO         OO  OOOO      OOO            
       OOOOOOO         OO      OOO    OOOO OOOO  OOOO      OO      OOO       OO      OOO         OOOOOOO       OOO    OOOOOO  
       OOOOOO          OO      OOO    OOOO OOOO  OOOO      OO      OOO       OO      OOO         OOOOOO        OOO     OOOOO  
       OO              OO     OOOO    OOOO  OOO  OOO       OO     OOOO       OO     OOOO         OO            OOO      OO    
       OO              OO     OOO                          OO     OOO        OO     OOO          OO            OOO      OO    
       OO              OO   OOOOO                          OO   OOOOO        OO   OOOOO          OO             OOO     OO    
     OOOOOO          OOOOOOOOOOO                         OOOOOOOOOOO       OOOOOOOOOOO         OOOOOO           OOOOO OOOO    
'''
)

register(
name='td3',
folder='single',
@@ -750,27 +729,3 @@
         OO
'''
)

register(
name='iql',
algo_class='IQL',
policy_mode='off-policy',
folder='multi',
is_multi=True,
logo='''
      OOOOO             OOOOOO          OOOOO         
       OOO             OOOOOOOOO         OOO          
        OO            OOOO   OOO          OO          
        OO            OOO    OOOO         OO          
        OO           OOO      OOO         OO          
        OO           OOO      OOO         OO          
        OO           OOO      OOO         OO          
        OO           OOO      OOO         OO          
        OO            OOO     OOO         OO     O    
        OO            OOOO   OOO          OO    OO    
      OOOOO            OOOOOOOOO        OOOOOOOOO     
                        OOOOOO                        
                          OOOO                        
                            OOOO
'''
)
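
For context, the registration entries removed above follow the same `register(...)` pattern used by the algorithms that remain in this file, now carrying the `is_multi` flag from item 8 of the commit message. A minimal sketch of such an entry, assuming the module-level `register` helper shown in the calls above (not a runnable standalone script); the `my_algo` name and `MyAlgo` class are hypothetical and the ASCII `logo` banner is abbreviated:

```python
# Hypothetical entry in rls/algos/__init__.py, mirroring the register(...)
# calls above. `is_multi` tells the trainer whether the algorithm is
# single-agent (SARL) or multi-agent (MARL).
register(
    name='my_algo',           # key passed to `run.py -a my_algo`
    folder='single',          # 'single' for SARL algorithms, 'multi' for MARL
    is_multi=False,           # property added in this commit
    algo_class='MyAlgo',      # class name resolved inside `folder`
    policy_mode='off-policy',
    logo='''MY ALGO'''        # ASCII banner, abbreviated here
)
```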
22 changes: 5 additions & 17 deletions rls/algos/base/base.py
@@ -40,7 +40,7 @@ def __init__(self,
self.device = device
logger.info(colorize(f"PyTorch Tensor Device: {self.device}"))

self.cp_dir, self.log_dir, self.excel_dir = [os.path.join(base_dir, i) for i in ['model', 'log', 'excel']]
self.cp_dir, self.log_dir = [os.path.join(base_dir, i) for i in ['model', 'log']]

if not self.no_save:
check_or_create(self.cp_dir, 'checkpoints(models)')
@@ -121,22 +121,10 @@ def write_training_info(self, data: Dict) -> NoReturn:
with open(f'{self.base_dir}/step.json', 'w') as f:
json.dump(data, f)

def writer_summary(self,
global_step: Union[int, t.Tensor],
summaries: Dict = {},
writer: Optional[SummaryWriter] = None) -> NoReturn:
"""
record the data used to show in the tensorboard
"""
if not self.no_save:
writer = writer or self.writer
for k, v in summaries.items():
writer.add_scalar('AGENT/' + k, v, global_step=global_step)

def write_training_summaries(self,
global_step: Union[int, t.Tensor],
summaries: Dict = {},
writer: Optional[SummaryWriter] = None) -> NoReturn:
def write_summaries(self,
global_step: Union[int, t.Tensor],
summaries: Dict = {},
writer: Optional[SummaryWriter] = None) -> NoReturn:
'''
write tf summaries showing in tensorboard.
'''
2 changes: 1 addition & 1 deletion rls/algos/base/ma_off_policy.py
@@ -137,5 +137,5 @@ def _learn(self, function_dict: Dict = {}) -> NoReturn:
self.summaries[k].update(v)

# -------------------------------------- write summaries to tensorboard
self.write_training_summaries(self.global_step, self.summaries)
self.write_summaries(self.global_step, self.summaries)
# --------------------------------------
21 changes: 7 additions & 14 deletions rls/algos/base/ma_policy.py
@@ -95,21 +95,14 @@ def _get_actions(self, obs, is_training: bool = True) -> Any:
'''
raise NotImplementedError

def writer_summary(self, global_step: Union[int, t.Tensor], summaries) -> NoReturn:
"""
record the data used to show in the tensorboard
"""
for i, summary in enumerate(summaries):
super().writer_summary(global_step, summaries=summary, writer=self.writers[i])

def write_training_summaries(self,
global_step: Union[int, t.Tensor],
summaries: Dict,
writer=None) -> NoReturn:
def write_summaries(self,
global_step: Union[int, t.Tensor],
summaries: Dict,
writer=None) -> NoReturn:
'''
write tf summaries showing in tensorboard.
'''
super().write_training_summaries(global_step, summaries=summaries.get('model', {}), writer=self.writer)
if 'model' in summaries.keys():
super().write_summaries(global_step, summaries=summaries.pop('model'), writer=self.writer)
for i, summary in summaries.items():
if i != 'model': # TODO: Optimization
super().write_training_summaries(global_step, summaries=summary, writer=self.writers[i])
super().write_summaries(global_step, summaries=summary, writer=self.writers[i])
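
From the loop above, the multi-agent `write_summaries` expects its `summaries` argument to be keyed by agent index, with an optional `'model'` entry routed to the shared writer. A minimal sketch of a call site under that reading; `policy` and the scalar names are hypothetical:

```python
# Hypothetical call on a multi-agent policy instance; entries keyed 0 and 1
# are written through policy.writers[0] / policy.writers[1], while the
# optional 'model' entry goes to the shared policy.writer.
summaries = {
    'model': {'LOSS/mixer_loss': 0.10},  # shared scalars (optional)
    0: {'LOSS/actor_loss': 0.31},        # per-agent scalars for agent 0
    1: {'LOSS/actor_loss': 0.27},        # per-agent scalars for agent 1
}
policy.write_summaries(global_step=1000, summaries=summaries)
```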
4 changes: 2 additions & 2 deletions rls/algos/base/off_policy.py
@@ -185,7 +185,7 @@ def _learn(self, function_dict: Dict = {}) -> NoReturn:
# --------------------------------------

# -------------------------------------- write summaries to tensorboard
self.write_training_summaries(self.global_step, self.summaries)
self.write_summaries(self.global_step, self.summaries)
# --------------------------------------

def _apex_learn(self, function_dict: Dict, data: BatchExperiences, priorities) -> np.ndarray:
@@ -209,7 +209,7 @@ def _apex_learn(self, function_dict: Dict, data: BatchExperiences, priorities) -

self._target_params_update()
self.summaries.update(_summary)
self.write_training_summaries(self.global_step, self.summaries)
self.write_summaries(self.global_step, self.summaries)

return np.squeeze(td_error.numpy())

Expand Down
2 changes: 1 addition & 1 deletion rls/algos/base/on_policy.py
@@ -78,6 +78,6 @@ def _learn(self, function_dict: Dict) -> NoReturn:
self.summaries.update(summaries)
self.summaries.update(_summary)

self.write_training_summaries(self.train_step, self.summaries)
self.write_summaries(self.train_step, self.summaries)

self.data.clear()
2 changes: 1 addition & 1 deletion rls/algos/hierarchical/hiro.py
@@ -280,7 +280,7 @@ def learn(self, **kwargs):
['LEARNING_RATE/high_actor_lr', self.high_actor_oplr.lr],
['LEARNING_RATE/high_critic_lr', self.high_critic_oplr.lr]
]))
self.write_training_summaries(self.global_step, self.summaries)
self.write_summaries(self.global_step, self.summaries)

@iTensor_oNumpy
def train_low(self, BATCH: Low_BatchExperiences):