-
Notifications
You must be signed in to change notification settings - Fork 13
/
run_mujoco.py
61 lines (51 loc) · 2.36 KB
/
run_mujoco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python
from baselines.common import set_global_seeds, tf_util as U
from baselines import bench
import os.path as osp
import gym, logging
import pdb
from baselines import logger
import sys
def train(env_id, num_timesteps, seed, num_options, app, saves, wsaves, epoch, dc):
    """Train a policy with ``pposgd_simple.learn`` on a monitored Gym environment.

    Args:
        env_id: Gym environment ID handed to ``gym.make``.
        num_timesteps: Passed to the learner as ``max_timesteps``.
        seed: Applied through ``set_global_seeds`` and ``env.seed``, and also
            forwarded to the learner.
        num_options: Number of options; only 1 and 2 are accepted. It selects
            the optimisation batch size (64 for 1, 32 for 2) and is forwarded
            to both the policy and the learner.
        app: Forwarded unchanged to ``pposgd_simple.learn``.
        saves: Forwarded unchanged to ``pposgd_simple.learn``.
        wsaves: Forwarded unchanged to ``pposgd_simple.learn``.
        epoch: Forwarded unchanged to ``pposgd_simple.learn``.
        dc: Forwarded unchanged to both the policy constructor and the learner.

    Raises:
        SystemExit: If ``num_options`` is neither 1 nor 2. The message is
            written to stderr and the process exit status is non-zero.
    """
    # Validate before any expensive setup so that a bad argument does not
    # leave a TensorFlow session, an environment or a monitor file behind.
    if num_options == 1:
        optimsize = 64
    elif num_options == 2:
        optimsize = 32
    else:
        # Passing the message to sys.exit reports failure (status 1) instead
        # of the success status a bare sys.exit() would produce.
        sys.exit("Only two options or primitive actions is currently supported.")

    # Imported lazily, as in the original, so importing this module stays cheap.
    from baselines.ppo1 import mlp_policy, pposgd_simple

    # NOTE(review): the session is entered and never exited; presumably it is
    # meant to stay the default session for the process lifetime - confirm.
    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(seed)
    env = gym.make(env_id)

    def policy_fn(name, ob_space, ac_space):
        """Build the MLP policy used by the learner (2 hidden layers of 64)."""
        return mlp_policy.MlpPolicy(
            name=name,
            ob_space=ob_space,
            ac_space=ac_space,
            hid_size=64,
            num_hid_layers=2,
            num_options=num_options,
            dc=dc,
        )

    # When no logger directory is configured, the falsy get_dir() result is
    # passed to Monitor instead of a file path.
    env = bench.Monitor(
        env,
        logger.get_dir() and osp.join(logger.get_dir(), "monitor.json"),
    )
    try:
        env.seed(seed)
        gym.logger.setLevel(logging.WARN)
        pposgd_simple.learn(
            env,
            policy_fn,
            max_timesteps=num_timesteps,
            timesteps_per_batch=2048,
            clip_param=0.2,
            entcoeff=0.0,
            optim_epochs=10,
            optim_stepsize=3e-4,
            optim_batchsize=optimsize,
            gamma=0.99,
            lam=0.95,
            schedule='constant',
            num_options=num_options,
            app=app,
            saves=saves,
            wsaves=wsaves,
            epoch=epoch,
            seed=seed,
            dc=dc,
        )
    finally:
        # Release the environment even when training raises.
        env.close()
def main():
    """Parse the command-line flags and launch a one-million-timestep training run.

    Flags: ``--env``, ``--seed``, ``--opt``, ``--app``, ``--saves``,
    ``--wsaves``, ``--epoch`` and ``--dc``; each is forwarded to ``train``.
    """
    import argparse

    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ('--env', dict(help='environment ID', default='Hopper-v1')),
        ('--seed', dict(help='RNG seed', type=int, default=1)),
        ('--opt', dict(help='number of options', type=int, default=2)),
        ('--app', dict(help='Append to folder name', type=str, default='')),
        ('--saves', dict(dest='saves', action='store_true', default=False)),
        ('--wsaves', dict(dest='wsaves', action='store_true', default=False)),
        ('--epoch', dict(help='Epoch', type=int, default=-1)),
        ('--dc', dict(type=float, default=0.)),
    )
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)

    cli = arg_parser.parse_args()
    train(
        cli.env,
        num_timesteps=1e6,
        seed=cli.seed,
        num_options=cli.opt,
        app=cli.app,
        saves=cli.saves,
        wsaves=cli.wsaves,
        epoch=cli.epoch,
        dc=cli.dc,
    )
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()