-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_es.py
116 lines (98 loc) · 3.26 KB
/
run_es.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import sys
import signal
import subprocess
import argparse
import random
import yaml
from mpi4py import MPI
import torch
import numpy as np
import builder
from loops.worker_func import run_rollout
# Keyboard interrupt (Ctrl+C) handler: abort the whole MPI job, not just this process.
def sigterm_handler(signum, frame):
    """Abort every MPI process when SIGINT is received.

    Despite its historical name, this handler is registered for SIGINT below.
    ``comm.Abort()`` asks MPI to terminate the job attached to ``comm``; the
    trailing ``raise`` is only a fallback in case ``Abort`` ever returns.

    Args:
        signum: Signal number passed in by the ``signal`` module. It is named
            ``signum`` so the ``signal`` module is not shadowed.
        frame: Interrupted stack frame (unused).

    Raises:
        KeyboardInterrupt: if ``comm.Abort()`` returns.
    """
    # Flush before aborting: Abort tears the process down without giving Python
    # a chance to flush buffered stdout.
    print("abort all processes", flush=True)
    comm.Abort()
    raise KeyboardInterrupt


signal.signal(signal.SIGINT, sigterm_handler)
# MPI setting: the world communicator covers every process started by mpiexec.
# `rank` is this process's index in it; mpi_fork() may refresh it after relaunch.
comm = MPI.COMM_WORLD
rank = comm.rank
def mpi_fork(n):
    """Re-launch the current script under ``mpiexec`` with ``n`` processes.

    (Adapted from https://github.com/garymcintire/mpi_util/)

    Args:
        n: Total number of MPI processes to launch. With ``n <= 1`` nothing is
            relaunched and the caller simply continues as a "child".

    Returns:
        "parent" in the original launcher process, after the mpiexec job has
        finished. "child" in the relaunched MPI processes, where the module
        globals ``nworkers`` and ``rank`` are set from ``comm``.

    Raises:
        subprocess.CalledProcessError: if the ``mpiexec`` job exits non-zero.
    """
    global nworkers, rank
    if n <= 1:
        return "child"
    if os.getenv("IN_MPI") is None:
        env = os.environ.copy()
        # Pin MKL/OpenMP to one thread per rank so n processes don't oversubscribe
        # the cores; IN_MPI marks the relaunched copies so they don't fork again.
        env.update(MKL_NUM_THREADS="1", OMP_NUM_THREADS="1", IN_MPI="1")
        # Build the command once so the logged command is exactly what runs
        # ("-u" makes the children's stdout/stderr unbuffered).
        cmd = ["mpiexec", "-n", str(n), sys.executable, "-u"] + sys.argv
        print(cmd)
        subprocess.check_call(cmd, env=env)
        return "parent"
    nworkers = comm.Get_size()
    rank = comm.Get_rank()
    print("assigning the rank and nworkers", nworkers, rank)
    return "child"
def set_seed(seed):
    """Seed the global RNGs of Python's `random`, NumPy and PyTorch with `seed`."""
    # The three generators are independent, so the seeding order is irrelevant.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def main(config, seed, n_workers, generation_num, eval_ep_num, log, save_model_period):
    """Coordinator (rank 0) entry point: build the training loop and run it.

    A throwaway environment is built only to read its agent ids and name. It is
    closed before the network and loop are constructed.

    Args:
        config: Parsed YAML config. Its "env" and "network" sections go to the
            builders; the whole dict is passed to ``builder.build_loop``.
        seed: Seed applied to random/NumPy/PyTorch via ``set_seed``.
        n_workers: Number of worker processes, forwarded to the loop.
        generation_num: Max number of generation iterations, forwarded to the loop.
        eval_ep_num: Model evaluations per iteration, forwarded to the loop.
        log: Whether wandb logging is enabled, forwarded to the loop.
        save_model_period: Save the model every this many iterations.
    """
    set_seed(seed)
    # Probe environment: only used to discover agent ids and the env name.
    env = builder.build_env(config["env"], rank)
    try:
        agent_ids = env.get_agent_ids()
        env_name = env.name
    finally:
        # Release the env even if the probe above raises.
        env.close()
    del env
    network = builder.build_network(config["network"])
    loop = builder.build_loop(
        config,
        network,
        agent_ids,
        env_name,
        generation_num,
        n_workers,
        eval_ep_num,
        log,
        save_model_period,
    )
    loop.run()
def worker(seed, env_cfg, network_cfg):
    """Worker (non-zero rank) entry point.

    Seed the RNGs, then hand a freshly built environment and network to
    ``run_rollout``.
    """
    set_seed(seed)
    # Call arguments evaluate left to right, so the env is still built before
    # the network (keeps RNG consumption order unchanged).
    run_rollout(
        builder.build_env(env_cfg, rank),
        builder.build_network(network_cfg),
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg-path", type=str, default="conf/neat/cartpole.yaml", help="config file to run.")
    parser.add_argument("--seed", type=int, default=0, help="random seed.")
    parser.add_argument("--n-workers", type=int, default=11, help="number of worker processes (one extra coordinator process is launched).")
    parser.add_argument("--generation-num", type=int, default=10000, help="max number of generation iteration.")
    parser.add_argument("--eval-ep-num", type=int, default=5, help="number of model evaluation per iteration.")
    parser.add_argument("--log", action="store_true", help="wandb log")
    parser.add_argument("--save-model-period", type=int, default=10, help="save model for every n iteration.")
    args = parser.parse_args()

    # Relaunch under mpiexec with n_workers + 1 ranks (rank 0 coordinates).
    # The original launcher only waits for that job, then shuts itself down.
    if mpi_fork(args.n_workers + 1) == "parent":
        print("abort all processes.", flush=True)
        comm.Abort()
        sys.exit()

    # NOTE(review): FullLoader is kept for compatibility with existing configs.
    # The path comes from the command line, so do not point it at untrusted files.
    with open(args.cfg_path, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    if rank == 0:
        main(
            config,
            args.seed,
            args.n_workers,
            args.generation_num,
            args.eval_ep_num,
            args.log,
            args.save_model_period,
        )
    else:
        worker(args.seed, config["env"], config["network"])