Skip to content

Commit ada4d0f

Browse files
committed
Merge the wins and plays tables into one, storing flyweight objects.
It turns out that even at low thinking times this is competitive, and at higher thinking times the reduced need to reallocate large dictionaries may dominate. Some example numbers of games managed (more is better) under 30 seconds and 300 seconds of thinking time:

          old     new
       ------  ------
 30s    31611   31545
300s   296359  306282
1 parent 5333d7e commit ada4d0f

File tree

1 file changed

+42
-42
lines changed

1 file changed

+42
-42
lines changed

mcts/uct.py

+42-42
Original file line number · Diff line number · Diff line change
@@ -5,12 +5,19 @@
55
from random import choice
66

77

8+
class Stat(object):
9+
__slots__ = ('value', 'visits')
10+
11+
def __init__(self, value=0, visits=0):
12+
self.value = value
13+
self.visits = visits
14+
15+
816
class MonteCarlo(object):
917
def __init__(self, board, **kwargs):
1018
self.board = board
1119
self.history = []
12-
self.wins = {}
13-
self.plays = {}
20+
self.stats = {}
1421

1522
self.max_depth = 0
1623
self.data = {}
@@ -66,9 +73,9 @@ def get_play(self):
6673
# Display the stats for each possible play.
6774
self.data['moves'] = sorted(
6875
({'move': p,
69-
'percent': 100 * self.wins.get((player, S), 0) / self.plays.get((player, S), 1),
70-
'wins': self.wins.get((player, S), 0),
71-
'plays': self.plays.get((player, S), 0)}
76+
'percent': 100 * self.stats[(player, S)].value / self.stats[(player, S)].visits,
77+
'wins': self.stats[(player, S)].value,
78+
'plays': self.stats[(player, S)].visits}
7279
for p, S in moves_states),
7380
key=lambda x: (x['percent'], x['plays']),
7481
reverse=True
@@ -78,9 +85,8 @@ def get_play(self):
7885

7986
# Pick the move with the highest percentage of wins.
8087
percent_wins, num_moves, move = max(
81-
(self.wins.get((player, S), 0) /
82-
self.plays.get((player, S), 1),
83-
self.plays.get((player, S), 0),
88+
(self.stats[(player, S)].value / self.stats[(player, S)].visits,
89+
self.stats[(player, S)].visits,
8490
p)
8591
for p, S in moves_states
8692
)
@@ -93,7 +99,7 @@ def run_simulation(self):
9399

94100
# A bit of an optimization here, so we have a local
95101
# variable lookup instead of an attribute access each loop.
96-
plays, wins = self.plays, self.wins
102+
stats = self.stats
97103

98104
visited_states = set()
99105
history_copy = self.history[:]
@@ -105,13 +111,13 @@ def run_simulation(self):
105111
legal = self.board.legal_plays(history_copy)
106112
moves_states = [(p, self.board.next_state(state, p)) for p in legal]
107113

108-
if all(plays.get((player, S)) for p, S in moves_states):
114+
if all((player, S) in stats for p, S in moves_states):
109115
# If we have stats on all of the legal moves here, use UCB1.
110116
log_total = log(
111-
sum(plays[(player, S)] for p, S in moves_states))
117+
sum(stats[(player, S)].visits for p, S in moves_states))
112118
value, move, state = max(
113-
((wins[(player, S)] / plays[(player, S)]) +
114-
self.C * sqrt(log_total / plays[(player, S)]), p, S)
119+
((stats[(player, S)].value / stats[(player, S)].visits) +
120+
self.C * sqrt(log_total / stats[(player, S)].visits), p, S)
115121
for p, S in moves_states
116122
)
117123
else:
@@ -122,10 +128,9 @@ def run_simulation(self):
122128

123129
# `player` here and below refers to the player
124130
# who moved into that particular state.
125-
if expand and (player, state) not in plays:
131+
if expand and (player, state) not in stats:
126132
expand = False
127-
plays[(player, state)] = 0
128-
wins[(player, state)] = 0
133+
stats[(player, state)] = Stat()
129134
if t > self.max_depth:
130135
self.max_depth = t
131136

@@ -137,20 +142,19 @@ def run_simulation(self):
137142
break
138143

139144
for player, state in visited_states:
140-
if (player, state) not in plays:
145+
if (player, state) not in stats:
141146
continue
142-
plays[(player, state)] += 1
147+
S = stats[(player, state)]
148+
S.visits += 1
143149
if player == winner:
144-
wins[(player, state)] += 1
150+
S.value += 1
145151

146152

147153
class ValueMonteCarlo(object):
148154
def __init__(self, board, **kwargs):
149155
self.board = board
150156
self.history = []
151-
152-
self.values = {}
153-
self.plays = {}
157+
self.stats = {}
154158

155159
self.max_depth = 0
156160
self.data = {}
@@ -175,9 +179,6 @@ def get_play(self):
175179
# Causes the AI to calculate the best move from the
176180
# current game state and return it.
177181

178-
self.values.clear()
179-
self.plays.clear()
180-
181182
self.max_depth = 0
182183
self.data = {}
183184

@@ -209,9 +210,9 @@ def get_play(self):
209210
# Display the stats for each possible play.
210211
self.data['moves'] = sorted(
211212
({'move': p,
212-
'average': self.values.get((player, S), 0) / self.plays.get((player, S), 1),
213-
'sum': self.values.get((player, S), 0),
214-
'plays': self.plays.get((player, S), 0)}
213+
'average': self.stats[(player, S)].value / self.stats[(player, S)].visits,
214+
'sum': self.stats[(player, S)].value,
215+
'plays': self.stats[(player, S)].visits}
215216
for p, S in moves_states),
216217
key=lambda x: (x['average'], x['plays']),
217218
reverse=True
@@ -221,9 +222,8 @@ def get_play(self):
221222

222223
# Pick the move with the highest average value.
223224
average, num_moves, move = max(
224-
(self.values.get((player, S), 0) /
225-
self.plays.get((player, S), 1),
226-
self.plays.get((player, S), 0),
225+
(self.stats[(player, S)].value / self.stats[(player, S)].visits,
226+
self.stats[(player, S)].visits,
227227
p)
228228
for p, S in moves_states
229229
)
@@ -236,7 +236,7 @@ def run_simulation(self):
236236

237237
# A bit of an optimization here, so we have a local
238238
# variable lookup instead of an attribute access each loop.
239-
plays, values = self.plays, self.values
239+
stats = self.stats
240240

241241
visited_states = set()
242242
history_copy = self.history[:]
@@ -248,13 +248,13 @@ def run_simulation(self):
248248
legal = self.board.legal_plays(history_copy)
249249
moves_states = [(p, self.board.next_state(state, p)) for p in legal]
250250

251-
if all(plays.get((player, S)) for p, S in moves_states):
251+
if all((player, S) in stats for p, S in moves_states):
252252
# If we have stats on all of the legal moves here, use UCB1.
253253
log_total = log(
254-
sum(plays[(player, S)] for p, S in moves_states))
254+
sum(stats[(player, S)].visits for p, S in moves_states))
255255
value, move, state = max(
256-
((values[(player, S)] / plays[(player, S)]) +
257-
self.C * sqrt(log_total / plays[(player, S)]), p, S)
256+
((stats[(player, S)].value / stats[(player, S)].visits) +
257+
self.C * sqrt(log_total / stats[(player, S)].visits), p, S)
258258
for p, S in moves_states
259259
)
260260
else:
@@ -265,10 +265,9 @@ def run_simulation(self):
265265

266266
# `player` here and below refers to the player
267267
# who moved into that particular state.
268-
if expand and (player, state) not in plays:
268+
if expand and (player, state) not in stats:
269269
expand = False
270-
plays[(player, state)] = 0
271-
values[(player, state)] = 0
270+
stats[(player, state)] = Stat()
272271
if t > self.max_depth:
273272
self.max_depth = t
274273

@@ -281,11 +280,12 @@ def run_simulation(self):
281280

282281
player_values = {}
283282
for player, state in visited_states:
284-
if (player, state) not in plays:
283+
if (player, state) not in stats:
285284
continue
286285
if player not in player_values:
287286
player_values[player] = self.board.end_value(history_copy, player)
288287

289-
plays[(player, state)] += 1
288+
S = stats[(player, state)]
289+
S.visits += 1
290290
if player_values[player] is not None:
291-
values[(player, state)] += player_values[player]
291+
S.value += player_values[player]

0 commit comments

Comments
 (0)