Skip to content

Commit da64100

Browse files
committed
Calculate the current player less often during simulations
Benchmark: under PyPy, playing Chong with a 30-second think time, the previous code completed fewer than 48k simulations; with this change it completes over 52k.
1 parent d18fb56 commit da64100

File tree

2 files changed

+18
-21
lines changed

2 files changed

+18
-21
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*~
2+
*.pyc
3+
*.egg-info

mcts/uct.py

+15-21
Original file line numberDiff line numberDiff line change
@@ -99,51 +99,45 @@ def run_simulation(self):
9999
visited_states = []
100100
history_copy = self.history[:]
101101
state = history_copy[-1]
102-
player = self.board.current_player(state)
103102

104103
expand = True
105104
for t in range(1, self.max_actions + 1):
106105
legal = self.board.legal_actions(history_copy)
107106
actions_states = [(a, self.board.next_state(history_copy, a)) for a in legal]
108107

109108
if expand:
110-
if not all((player, S) in stats for a, S in actions_states):
111-
# `player` here and below refers to the player who
112-
# moved into that particular state.
113-
stats.update(((player, S), Stat()) for a, S in actions_states
114-
if (player, S) not in stats)
109+
if not all(S in stats for a, S in actions_states):
110+
stats.update((S, Stat()) for a, S in actions_states if S not in stats)
115111
expand = False
116112
if t > self.max_depth:
117113
self.max_depth = t
118114

119115
# If we have stats on all of the legal actions here, use UCB1.
120116
log_total = log(
121-
sum(stats[(player, S)].visits for a, S in actions_states) or 1)
117+
sum(stats[S].visits for a, S in actions_states) or 1)
122118
values_actions = [
123119
(a, S, (e.value / (e.visits or 1)) + self.C * sqrt(log_total / (e.visits or 1)))
124-
for a, S, e in ((a, S, stats[(player, S)]) for a, S in actions_states)
120+
for a, S, e in ((a, S, stats[S]) for a, S in actions_states)
125121
]
126122
max_value = max(v for _, _, v in values_actions)
127123
# Filter down to only those actions with maximum value under UCB1.
128124
actions_states = [(a, S) for a, S, v in values_actions if v == max_value]
129125

130126
action, state = choice(actions_states)
131-
visited_states.append((player, state))
127+
visited_states.append(state)
132128
history_copy.append(state)
133-
# Who is the next player to take an action?
134-
player = self.board.current_player(state)
135129

136130
if self.board.is_ended(history_copy):
137131
break
138132

139133
# Back-propagation
140134
end_values = self.end_values(history_copy)
141-
for player, state in visited_states:
142-
if (player, state) not in stats:
135+
for state in visited_states:
136+
if state not in stats:
143137
continue
144-
S = stats[(player, state)]
138+
S = stats[state]
145139
S.visits += 1
146-
S.value += end_values[player]
140+
S.value += end_values[self.board.previous_player(state)]
147141

148142

149143
class UCTWins(UCT):
@@ -158,9 +152,9 @@ def calculate_action_values(self, history, player, legal):
158152
actions_states = ((a, self.board.next_state(history, a)) for a in legal)
159153
return sorted(
160154
({'action': a,
161-
'percent': 100 * self.stats[(player, S)].value / (self.stats[(player, S)].visits or 1),
162-
'wins': self.stats[(player, S)].value,
163-
'plays': self.stats[(player, S)].visits}
155+
'percent': 100 * self.stats[S].value / (self.stats[S].visits or 1),
156+
'wins': self.stats[S].value,
157+
'plays': self.stats[S].visits}
164158
for a, S in actions_states),
165159
key=lambda x: (x['percent'], x['plays']),
166160
reverse=True
@@ -179,9 +173,9 @@ def calculate_action_values(self, history, player, legal):
179173
actions_states = ((a, self.board.next_state(history, a)) for a in legal)
180174
return sorted(
181175
({'action': a,
182-
'average': self.stats[(player, S)].value / (self.stats[(player, S)].visits or 1),
183-
'sum': self.stats[(player, S)].value,
184-
'plays': self.stats[(player, S)].visits}
176+
'average': self.stats[S].value / (self.stats[S].visits or 1),
177+
'sum': self.stats[S].value,
178+
'plays': self.stats[S].visits}
185179
for a, S in actions_states),
186180
key=lambda x: (x['average'], x['plays']),
187181
reverse=True

0 commit comments

Comments
 (0)