main.py
import os
import json
import argparse

from unityagents import UnityEnvironment

import session
import utils
from noise import OrnsteinUhlenbeckNoise
from experience import ReplayBuffer
from agent import DDPGAgent


def main():
    parser = argparse.ArgumentParser(description="Run DDPG agent with given config")
    parser.add_argument("-c",
                        "--config",
                        type=str,
                        metavar="",
                        required=True,
                        help="Config file name - file must be available as .json in ./configs")
    args = parser.parse_args()

    # Load the run configuration from ./configs/<config>.json
    with open(os.path.join(".", "configs", args.config), "r") as read_file:
        config = json.load(read_file)

    # Build the environment, exploration noise, replay buffer, and agent
    env = UnityEnvironment(file_name=os.path.join(*config["env_path"]))
    noise = OrnsteinUhlenbeckNoise(config["n_actions"], config["mu"], config["theta"], config["sigma"], config["seed"])
    replay_buffer = ReplayBuffer(config["buffer_size"], config["device"], config["seed"])
    agent = DDPGAgent(config, noise, replay_buffer)

    checkpoint_dir = os.path.join(".", *config["checkpoint_dir"], config["env_name"])
    if config["run_training"]:
        # Train, then persist the actor and critic weights
        session.train(agent, env, config)
        utils.save_state_dict(os.path.join(checkpoint_dir, "actor"), agent.actor.state_dict())
        utils.save_state_dict(os.path.join(checkpoint_dir, "critic"), agent.critic.state_dict())
    else:
        # Restore the latest saved actor/critic checkpoints and evaluate
        agent.actor.load_state_dict(utils.load_latest_available_state_dict(os.path.join(checkpoint_dir, "actor", "*")))
        agent.critic.load_state_dict(utils.load_latest_available_state_dict(os.path.join(checkpoint_dir, "critic", "*")))
        session.evaluate(agent, env, num_test_runs=1)

    env.close()


if __name__ == '__main__':
    main()
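
The script expects a JSON file in ./configs containing every key read above. A minimal sketch of such a config, inferred from those keys; all values are illustrative placeholders, not taken from the repository, and session.train/session.evaluate may read additional keys not visible in this file. "mu", "theta", and "sigma" parameterize the Ornstein-Uhlenbeck exploration noise; "env_path" and "checkpoint_dir" are lists of path segments joined with os.path.join.

    {
        "env_path": ["envs", "Reacher_Linux", "Reacher.x86_64"],
        "env_name": "Reacher",
        "n_actions": 4,
        "mu": 0.0,
        "theta": 0.15,
        "sigma": 0.2,
        "seed": 0,
        "buffer_size": 100000,
        "device": "cpu",
        "run_training": true,
        "checkpoint_dir": ["checkpoints"]
    }

Assuming that file is saved as ./configs/reacher.json (a hypothetical name), a run would then look like:

    python main.py --config reacher.json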