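"""Evaluate a trained DDQN agent on Procgen levels.

Loads a saved checkpoint, restores the agent, and logs test-level and
train-level evaluation returns to Weights & Biases.
"""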
import argparse
from distutils.util import strtobool

import numpy as np
import torch
import wandb

from level_replay.algo.policy import DDQN
from train_rainbow import eval_policy, generate_seeds


def construct_class_from_dict(d):
    """Wrap a plain dict so its entries can be read as attributes (like argparse args)."""

    class Args:
        def __init__(self, dictionary):
            for k, v in dictionary.items():
                setattr(self, k, v)

    return Args(d)


def evaluate(args):
    """Restore a checkpointed agent and log its evaluation returns to wandb."""
    # The checkpoint bundles the Q-network weights with the original training arguments.
    dictionary = torch.load(args.model_path)
    model_state_dict = dictionary["model_state_dict"]
    arg_dict = dictionary["args"]
    model_args = construct_class_from_dict(arg_dict)
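
    # Start a wandb run for this evaluation, carrying over the training config.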
    wandb.init(
        settings=wandb.Settings(start_method="fork"),
        project="off-policy-procgen",
        entity="ucl-dark",
        config=vars(model_args),
        tags=["ddqn", "procgen"] + (args.wandb_tags.split(",") if args.wandb_tags else []),
        group="Evaluations",
    )
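
    # Rebuild the agent with its training-time hyperparameters, then restore the weights.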
    agent = DDQN(model_args)
    agent.Q.load_state_dict(model_state_dict)
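
    # Generalization performance: evaluate on held-out (unseen) test levels.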
    if args.test:
        eval_episode_rewards = eval_policy(model_args, agent, model_args.num_test_seeds)
        wandb.log(
            {
                "Test Evaluation Returns": np.mean(eval_episode_rewards),
            }
        )
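
    # Training performance: either a per-level breakdown or a single aggregate return.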
    if args.each_train_level:
        train_eval_episode_rewards = []
        for seed in range(model_args.num_train_seeds):
            rewards = eval_policy(
                model_args, agent, model_args.num_test_seeds, start_level=seed, num_levels=1, seeds=[seed]
            )
            train_eval_episode_rewards.append([seed, np.mean(rewards)])
        table = wandb.Table(data=train_eval_episode_rewards, columns=["Train Level", "Evaluation Rewards"])
        wandb.log(
            {
                "Train Evaluations for Each Training Level": wandb.plot.bar(
                    table,
                    "Train Level",
                    "Evaluation Rewards",
                    title="Train Evaluations for Each Training Level",
                )
            }
        )
    else:
        # num_train_seeds comes from the saved training args, not the CLI args.
        seeds = generate_seeds(model_args.num_train_seeds)
        train_eval_episode_rewards = eval_policy(
            model_args,
            agent,
            model_args.num_test_seeds,
            start_level=0,
            num_levels=model_args.num_train_seeds,
            seeds=seeds,
        )
        wandb.log(
            {
                "Train Evaluation Returns": np.mean(train_eval_episode_rewards),
            }
        )
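

# Example invocation (flags shown with their defaults):
#   python evaluation.py --model_path model/model.tar --test True --each_train_level False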
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DQN evaluation")
    parser.add_argument("--model_path", default="model/model.tar", help="Path to pre-trained model")
    parser.add_argument(
        "--test", type=lambda x: bool(strtobool(x)), default=True, help="Whether to evaluate on unseen levels"
    )
    parser.add_argument(
        "--each_train_level",
        type=lambda x: bool(strtobool(x)),
        default=False,
        help="Whether to get a score for each train level",
    )
    parser.add_argument("--wandb_tags", default="", help="Comma-separated list of extra wandb tags")
    parser.add_argument("--no_cuda", type=lambda x: bool(strtobool(x)), default=False, help="Disables GPU")
    args = parser.parse_args()
    evaluate(args)