-
Notifications
You must be signed in to change notification settings - Fork 0
/
Bootstrapper.java
153 lines (124 loc) · 4.09 KB
/
Bootstrapper.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import engine.*;
import agent.*;
import agent.QL.*;
import java.util.List;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
/**
 * Command-line entry point: builds a grid and two agents, wraps them in a
 * {@code Game} and runs it.
 *
 * <p>The agent line-up is configured in {@link #main(String[])} through
 * parallel lists (type name, matrix file, learning rate, discount factor,
 * epsilon). The only command-line argument is an optional grid size.
 */
public class Bootstrapper
{
    /** Shared instance handed out by {@link #getInstance()}. */
    private static final Bootstrapper instance = new Bootstrapper();

    /** The game built by the most recent call to {@link #main(String[])}. */
    private static Game game;

    /**
     * Announces the start of the simulation and runs the given game.
     *
     * @param game the game to run
     */
    public void startGame(Game game)
    {
        System.out.println("Creating game simulator.");
        game.run();
    }

    /**
     * Returns the shared {@code Bootstrapper} instance.
     *
     * @return the instance created when the class was initialised
     */
    public static Bootstrapper getInstance()
    {
        return instance;
    }

    /**
     * Prints the run configuration as one space-separated line on standard
     * output: number of rounds, results interval, grid size, then the type
     * names of the first and second agent. The line ends with a trailing space.
     *
     * @param numRounds   number of rounds to play
     * @param interval    results interval
     * @param size        grid size
     * @param agentParams agent type names; the first two entries are printed
     * @throws IndexOutOfBoundsException if {@code agentParams} has fewer than two entries
     */
    public static void printGameInfo(int numRounds, int interval, int size, List<String> agentParams)
    {
        System.out.println(numRounds + " "
            + interval + " "
            + size + " "
            + agentParams.get(0) + " "
            + agentParams.get(1) + " ");
    }

    /**
     * Configures two agents, builds the game and runs it.
     *
     * <p>Returns early, without running the game, if a Q-learning agent's
     * matrix cannot be loaded or created.
     *
     * @param args optional; {@code args[0]} is the grid size. When it is
     *             missing or not a valid integer, {@code Grid.default_size} is used.
     * @throws IllegalArgumentException if a configured agent type name is not recognised
     */
    public static void main(String[] args)
    {
        int numRounds = 2000;
        int interval = 100;
        int size = parseGridSize(args, Grid.default_size);

        // Parallel lists: index i in every list describes agent i.
        List<String> agentParams = new ArrayList<String>();
        List<String> filenames = new ArrayList<String>();
        List<Float> lrs = new ArrayList<Float>();
        List<Float> discounts = new ArrayList<Float>();
        List<Float> epsilons = new ArrayList<Float>();
        List<Agent> agents = new ArrayList<Agent>();

        Bootstrapper bootstrapper = Bootstrapper.getInstance();
        // This grid only supplies the dimensions for a newly created matrix.
        Grid grid = new Grid(size);

        // agent 1
        agentParams.add("Random");
        filenames.add("./Q1");
        lrs.add(0.6f);
        discounts.add(0.5f);
        epsilons.add(0.1f);

        // agent 2
        agentParams.add("QLAgent");
        filenames.add("./Q3");
        lrs.add(0.6f);
        discounts.add(0.5f);
        epsilons.add(0.1f);

        for (int i = 0; i < agentParams.size(); i++)
        {
            Agent agent = createAgent(agentParams.get(i), filenames.get(i),
                lrs.get(i), discounts.get(i), epsilons.get(i), grid);
            if (agent == null)
            {
                // The failure has already been reported by createQLAgent.
                return;
            }
            agents.add(agent);
        }

        printGameInfo(numRounds, interval, size, agentParams);

        // The game is given its own, newly constructed grid.
        grid = new Grid(size);
        game = new Game(grid, agents.get(0), agents.get(1), numRounds, interval);
        bootstrapper.startGame(game);
    }

    /**
     * Reads the grid size from the first command-line argument.
     *
     * @param args        the command-line arguments; may be {@code null} or empty
     * @param defaultSize the value to use when no valid size is supplied
     * @return the parsed size, or {@code defaultSize} when the argument is
     *         missing or not a valid integer (a warning goes to standard error
     *         in the latter case)
     */
    private static int parseGridSize(String[] args, int defaultSize)
    {
        if (args == null || args.length == 0)
        {
            return defaultSize;
        }
        try
        {
            return Integer.parseInt(args[0]);
        }
        catch (NumberFormatException e)
        {
            // Keep the best-effort fallback, but do not hide the typo from the user.
            System.err.println("Invalid grid size '" + args[0] + "'; using default " + defaultSize + ".");
            return defaultSize;
        }
    }

    /**
     * Creates the agent for the given type name.
     *
     * @param className      one of {@code "QLAgent"}, {@code "Simple"} or {@code "Random"}
     * @param filename       matrix file; only used for {@code "QLAgent"}
     * @param learningRate   learning rate; only used for {@code "QLAgent"}
     * @param discountFactor discount factor; only used for {@code "QLAgent"}
     * @param epsilon        epsilon; only used for {@code "QLAgent"}
     * @param grid           grid whose dimensions size a newly created matrix
     * @return the agent, or {@code null} if a Q-learning agent's matrix could
     *         not be loaded or created
     * @throws IllegalArgumentException if {@code className} is not recognised
     */
    private static Agent createAgent(String className, String filename, float learningRate,
                                     float discountFactor, float epsilon, Grid grid)
    {
        if (className.equals("QLAgent"))
        {
            return createQLAgent(filename, learningRate, discountFactor, epsilon, grid);
        }
        if (className.equals("Simple"))
        {
            return new Simple();
        }
        if (className.equals("Random"))
        {
            return new RandomPlayer();
        }
        // Fail fast instead of handing a null agent to the game.
        throw new IllegalArgumentException("Unknown agent type: " + className);
    }

    /**
     * Creates a {@code QLAgent}, loading its matrix from {@code filename} when
     * that file exists and creating a new matrix otherwise.
     *
     * @return the agent, or {@code null} after printing a message when the
     *         matrix could not be loaded (I/O error) or created (out of memory)
     */
    private static Agent createQLAgent(String filename, float learningRate,
                                       float discountFactor, float epsilon, Grid grid)
    {
        StateMatrix matrix;
        System.out.println(filename);
        if (new File(filename).exists())
        {
            try
            {
                matrix = StateMatrix.load_table(filename);
                System.out.println("Loaded matrix...");
            }
            catch (IOException e)
            {
                // The file exists on this branch, so report a load failure.
                System.out.println("Failed to load matrix from " + filename + ": " + e);
                return null;
            }
        }
        else
        {
            try
            {
                matrix = StateMatrix.create_table(filename, grid.get_max_id() + 1, grid.get_size());
                System.out.println("Creating matrix...");
            }
            catch (OutOfMemoryError e)
            {
                System.out.println("Out of memory. Terminating.");
                return null;
            }
        }
        return new QLAgent(matrix, discountFactor, learningRate, epsilon, filename);
    }
}