-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
57 lines (47 loc) · 2.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
from utils import loggers
from Agents.AgentConfigs import *
from Agents.AgentBuilder import build_agent
from Enviroment.EnvBuilder import get_env_builder, get_env_goal
import train
from opt import *
from gym import wrappers
def get_train_function(agent_name):
    """Pick the training entry point to use for ``agent_name``.

    The name ``"PPOParallel"`` selects ``train.train_agent_multi_env``; any
    other agent name selects ``train.train_agent``. The chosen function is
    returned uncalled.
    """
    runs_parallel = agent_name == "PPOParallel"
    return train.train_agent_multi_env if runs_parallel else train.train_agent
def get_logger(logger_type, log_frequency, logdir):
    """Build and return a logger instance of the requested kind.

    ``'plt'`` builds ``loggers.plt_logger``, ``'tensorboard'`` builds
    ``loggers.TB_logger``, and any other value falls back to
    ``loggers.logger``. Every constructor is called positionally with
    ``(log_frequency, logdir)``.
    """
    if logger_type == 'plt':
        return loggers.plt_logger(log_frequency, logdir)
    if logger_type == 'tensorboard':
        return loggers.TB_logger(log_frequency, logdir)
    # Unrecognised types deliberately fall back to the basic logger.
    return loggers.logger(log_frequency, logdir)
if __name__ == '__main__':
    # Script entry point: build the env factory, hyper-parameters and agent,
    # then either train (TRAIN truthy) or evaluate the agent.
    # Upper-case settings (ENV_NAME, AGENT_NAME, TRAIN, ...) come from the star
    # imports at the top of the file (presumably `opt`) -- not defined here.
    # NOTE(review): RNG seeding (random / numpy / torch with SEED) used to be
    # commented out here; those modules are not imported in this file, so
    # re-enabling it requires adding the imports as well.
    env_builder = get_env_builder(ENV_NAME)
    hp = get_agent_configs(AGENT_NAME, ENV_NAME)
    # A fresh env instance is handed to the agent builder. It is not closed
    # here because the agent may keep a reference to it -- TODO confirm.
    agent = build_agent(AGENT_NAME, env_builder(), hp)
    if WEIGHTS_FILE:
        # Optionally start from previously saved weights.
        agent.load_state(WEIGHTS_FILE)
    # Run artifacts live under <TRAIN_ROOT>/<ENV_NAME>/<agent.name>.
    train_dir = os.path.join(TRAIN_ROOT, ENV_NAME, agent.name)

    if TRAIN:
        logger = get_logger(LOGGER_TYPE, log_frequency=LOG_FREQUENCY, logdir=train_dir)
        agent.set_reporter(logger)
        progress_manager = train.train_progress_manager(
            train_dir, get_env_goal(ENV_NAME), SCORE_SCOPE, logger,
            checkpoint_steps=CKP_STEP, train_episodes=TRAIN_EPISODES,
            temporal_frequency=TEMPORAL_FREQ)
        train_function = get_train_function(AGENT_NAME)
        train_function(env_builder, agent, progress_manager, test_frequency=TEST_FREQ,
                       test_episodes=TEST_EPISODES, save_videos=SAVE_VIDEOS)
    else:
        # Test: evaluate on a test-configured env; the Monitor wrapper records
        # every episode (video_callable always True) into <train_dir>/test,
        # overwriting previous results (force=True).
        env = env_builder(test_config=True)
        env = wrappers.Monitor(env, os.path.join(train_dir, "test"),
                               video_callable=lambda episode_id: True, force=True)
        try:
            score = train.test(env, agent, TEST_EPISODES, render=True)
        finally:
            # Close explicitly so the Monitor flushes its recorded data even if
            # evaluation raises, rather than relying on interpreter-exit hooks.
            env.close()
        print("Avg reward over %d episodes: %f" % (TEST_EPISODES, score))