Commit 40c1c03

Refactor into a UCT base class and two child classes
The UCT base class contains the common logic for running the Monte Carlo Tree Search, while the subclasses customize which values are tracked (wins vs. points) and how they are displayed.
1 parent 01374f2 commit 40c1c03
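For context, a hedged usage sketch of the surface this commit leaves in place. `HighCard` below is an invented toy board, not part of this repository; its methods (`current_player`, `legal_actions`, `next_state`, `is_ended`, `win_values`) are simply the ones the diff below calls on the board object, and `update()`/`get_action()` are the player entry points kept by the base class.

# Hypothetical usage sketch -- HighCard is invented for illustration.
# Each player picks a number; the higher pick wins.  State is a tuple
# (player_to_move, pick1, pick2).
from mcts.uct import UCTWins

class HighCard(object):
    def starting_state(self):
        return (1, None, None)

    def current_player(self, state):
        return state[0]

    def legal_actions(self, history):
        return [] if self.is_ended(history) else [1, 2, 3]

    def next_state(self, state, action):
        to_move, p1, p2 = state
        return (2, action, p2) if to_move == 1 else (1, p1, action)

    def is_ended(self, history):
        _, p1, p2 = history[-1]
        return p1 is not None and p2 is not None

    def win_values(self, history):
        # Per the diff below: return a dict mapping each player to an
        # end-of-game value that gets back-propagated into the stats.
        _, p1, p2 = history[-1]
        if p1 == p2:
            return {1: 0.5, 2: 0.5}
        winner = 1 if p1 > p2 else 2
        return {winner: 1.0, 3 - winner: 0.0}

board = HighCard()
player = UCTWins(board, time=1)         # one second of simulations
player.update(board.starting_state())   # seed the search history
print player.get_action()               # should settle on the highest pick, 3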

File tree: 2 files changed, +36 −154 lines


mcts/uct.py (+34 −152)
@@ -13,7 +13,7 @@ def __init__(self, value=0, visits=0):
         self.visits = visits
 
 
-class MonteCarlo(object):
+class UCT(object):
     def __init__(self, board, **kwargs):
         self.board = board
         self.history = []
@@ -69,30 +69,13 @@ def get_action(self):
         print self.data['games'], self.data['time']
         print "Maximum depth searched:", self.max_depth
 
-        actions_states = [(p, self.board.next_state(state, p)) for p in legal]
-
-        # Display the stats for each possible action.
-        self.data['actions'] = sorted(
-            ({'action': p,
-              'percent': 100 * self.stats[(player, S)].value / self.stats[(player, S)].visits,
-              'wins': self.stats[(player, S)].value,
-              'plays': self.stats[(player, S)].visits}
-             for p, S in actions_states),
-            key=lambda x: (x['percent'], x['plays']),
-            reverse=True
-        )
+        # Store and display the stats for each possible action.
+        self.data['actions'] = self.calculate_action_values(state, player, legal)
         for m in self.data['actions']:
-            print "{action}: {percent:.2f}% ({wins} / {plays})".format(**m)
-
-        # Pick the action with the highest percentage of wins.
-        percent_wins, num_actions, action = max(
-            (self.stats[(player, S)].value / self.stats[(player, S)].visits,
-             self.stats[(player, S)].visits,
-             p)
-            for p, S in actions_states
-        )
+            print self.action_template.format(**m)
 
-        return action
+        # Pick the action with the highest average value.
+        return self.data['actions'][0]['action']
 
     def run_simulation(self):
         # Plays out a "random" game from the current position,
@@ -138,78 +121,49 @@ def run_simulation(self):
             visited_states.add((player, state))
 
             player = self.board.current_player(state)
-            winner = self.board.winner(history_copy)
-            if winner:
+            if self.board.is_ended(history_copy):
                 break
 
+        # Back-propagation
+        end_values = self.end_values(history_copy)
         for player, state in visited_states:
             if (player, state) not in stats:
                 continue
             S = stats[(player, state)]
             S.visits += 1
-            if player == winner:
-                S.value += 1
-
-
-class ValueMonteCarlo(object):
-    def __init__(self, board, **kwargs):
-        self.board = board
-        self.history = []
-        self.stats = {}
-
-        self.max_depth = 0
-        self.data = {}
-
-        self.calculation_time = float(kwargs.get('time', 30))
-        self.max_actions = int(kwargs.get('max_actions', 1000))
-
-        # Exploration constant, increase for more exploratory actions,
-        # decrease to prefer actions with known higher win rates.
-        self.C = float(kwargs.get('C', 1.4))
+            S.value += end_values[player]
 
-    def update(self, state):
-        self.history.append(state)
 
-    def display(self, state, action):
-        return self.board.display(state, action)
+class UCTWins(UCT):
+    action_template = "{action}: {percent:.2f}% ({wins} / {plays})"
 
-    def winner_message(self, msg):
-        return self.board.winner_message(msg)
-
-    def get_action(self):
-        # Causes the AI to calculate the best action from the
-        # current game state and return it.
-
-        self.max_depth = 0
-        self.data = {}
+    def __init__(self, board, **kwargs):
+        super(UCTWins, self).__init__(board, **kwargs)
+        self.end_values = board.win_values
 
-        state = self.history[-1]
-        player = self.board.current_player(state)
-        legal = self.board.legal_actions(self.history[:])
+    def calculate_action_values(self, state, player, legal):
+        actions_states = ((p, self.board.next_state(state, p)) for p in legal)
+        return sorted(
+            ({'action': p,
+              'percent': 100 * self.stats[(player, S)].value / self.stats[(player, S)].visits,
+              'wins': self.stats[(player, S)].value,
+              'plays': self.stats[(player, S)].visits}
+             for p, S in actions_states),
+            key=lambda x: (x['percent'], x['plays']),
+            reverse=True
+        )
 
-        # Bail out early if there is no real choice to be made.
-        if not legal:
-            return
-        if len(legal) == 1:
-            return legal[0]
 
-        games = 0
-        begin = time.time()
-        while time.time() - begin < self.calculation_time:
-            self.run_simulation()
-            games += 1
+class UCTValues(UCT):
+    action_template = "{action}: {average:.1f} ({sum} / {plays})"
 
-        # Display the number of calls of `run_simulation` and the
-        # time elapsed.
-        self.data.update(games=games, max_depth=self.max_depth,
-                         time=str(time.time() - begin))
-        print self.data['games'], self.data['time']
-        print "Maximum depth searched:", self.max_depth
-
-        actions_states = [(p, self.board.next_state(state, p)) for p in legal]
+    def __init__(self, board, **kwargs):
+        super(UCTValues, self).__init__(board, **kwargs)
+        self.end_values = board.points_values
 
-        # Display the stats for each possible action.
-        self.data['actions'] = sorted(
+    def calculate_action_values(self, state, player, legal):
+        actions_states = ((p, self.board.next_state(state, p)) for p in legal)
+        return sorted(
             ({'action': p,
              'average': self.stats[(player, S)].value / self.stats[(player, S)].visits,
              'sum': self.stats[(player, S)].value,
@@ -218,75 +172,3 @@ def get_action(self):
             key=lambda x: (x['average'], x['plays']),
             reverse=True
         )
-        for m in self.data['actions']:
-            print "{action}: {average:.1f} ({sum} / {plays})".format(**m)
-
-        # Pick the action with the highest average value.
-        average, num_actions, action = max(
-            (self.stats[(player, S)].value / self.stats[(player, S)].visits,
-             self.stats[(player, S)].visits,
-             p)
-            for p, S in actions_states
-        )
-
-        return action
-
-    def run_simulation(self):
-        # Plays out a "random" game from the current position,
-        # then updates the statistics tables with the result.
-
-        # A bit of an optimization here, so we have a local
-        # variable lookup instead of an attribute access each loop.
-        stats = self.stats
-
-        visited_states = set()
-        history_copy = self.history[:]
-        state = history_copy[-1]
-        player = self.board.current_player(state)
-
-        expand = True
-        for t in xrange(1, self.max_actions + 1):
-            legal = self.board.legal_actions(history_copy)
-            actions_states = [(p, self.board.next_state(state, p)) for p in legal]
-
-            if all((player, S) in stats for p, S in actions_states):
-                # If we have stats on all of the legal actions here, use UCB1.
-                log_total = log(
-                    sum(stats[(player, S)].visits for p, S in actions_states))
-                value, action, state = max(
-                    ((stats[(player, S)].value / stats[(player, S)].visits) +
-                     self.C * sqrt(log_total / stats[(player, S)].visits), p, S)
-                    for p, S in actions_states
-                )
-            else:
-                # Otherwise, just make an arbitrary decision.
-                action, state = choice(actions_states)
-
-            history_copy.append(state)
-
-            # `player` here and below refers to the player
-            # who moved into that particular state.
-            if expand and (player, state) not in stats:
-                expand = False
-                stats[(player, state)] = Stat()
-            if t > self.max_depth:
-                self.max_depth = t
-
-            visited_states.add((player, state))
-
-            player = self.board.current_player(state)
-            winner = self.board.winner(history_copy)
-            if winner:
-                break
-
-        player_values = {}
-        for player, state in visited_states:
-            if (player, state) not in stats:
-                continue
-            if player not in player_values:
-                player_values[player] = self.board.end_value(history_copy, player)
-
-            S = stats[(player, state)]
-            S.visits += 1
-            if player_values[player] is not None:
-                S.value += player_values[player]
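After this change the back-propagation step is simply S.value += end_values[player], with end_values bound to board.win_values for UCTWins and board.points_values for UCTValues; both therefore appear to return a dict keyed by player. A hedged sketch of the points-based variant, with an invented scoring helper:

# Hypothetical board method for the UCTValues variant; _final_scores is
# an invented helper returning {player: points}.  Handing back raw final
# scores makes the search maximize average points rather than win rate.
def points_values(self, history):
    scores = self._final_scores(history)
    return dict((p, float(s)) for p, s in scores.items())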

setup.py (+2 −2)
@@ -7,8 +7,8 @@
     author_email='[email protected]',
     packages=['mcts'],
     entry_points={
-        'jrb_board.players': ['jrb.mcts.uct = mcts.uct:MonteCarlo',
-                              'jrb.mcts.uctv = mcts.uct:ValueMonteCarlo'],
+        'jrb_board.players': ['jrb.mcts.uct = mcts.uct:UCTWins',
+                              'jrb.mcts.uctv = mcts.uct:UCTValues'],
     },
     license='LICENSE',
     description="An implementation of UCT Monte Carlo Tree Search.",
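The entry points register both classes as jrb_board.players plugins under their old names, so existing consumers keep working. A hedged sketch of how a host application could discover them (standard pkg_resources usage; the surrounding application and its board object are hypothetical):

# Standard pkg_resources lookup against the entry point group declared
# above; the host application and its board object are hypothetical.
import pkg_resources

players = {}
for ep in pkg_resources.iter_entry_points('jrb_board.players'):
    players[ep.name] = ep.load()    # e.g. 'jrb.mcts.uct' -> UCTWins

player = players['jrb.mcts.uct'](board, time=30)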
