@@ -99,51 +99,45 @@ def run_simulation(self):
99
99
visited_states = []
100
100
history_copy = self .history [:]
101
101
state = history_copy [- 1 ]
102
- player = self .board .current_player (state )
103
102
104
103
expand = True
105
104
for t in range (1 , self .max_actions + 1 ):
106
105
legal = self .board .legal_actions (history_copy )
107
106
actions_states = [(a , self .board .next_state (history_copy , a )) for a in legal ]
108
107
109
108
if expand :
110
- if not all ((player , S ) in stats for a , S in actions_states ):
111
- # `player` here and below refers to the player who
112
- # moved into that particular state.
113
- stats .update (((player , S ), Stat ()) for a , S in actions_states
114
- if (player , S ) not in stats )
109
+ if not all (S in stats for a , S in actions_states ):
110
+ stats .update ((S , Stat ()) for a , S in actions_states if S not in stats )
115
111
expand = False
116
112
if t > self .max_depth :
117
113
self .max_depth = t
118
114
119
115
# If we have stats on all of the legal actions here, use UCB1.
120
116
log_total = log (
121
- sum (stats [( player , S ) ].visits for a , S in actions_states ) or 1 )
117
+ sum (stats [S ].visits for a , S in actions_states ) or 1 )
122
118
values_actions = [
123
119
(a , S , (e .value / (e .visits or 1 )) + self .C * sqrt (log_total / (e .visits or 1 )))
124
- for a , S , e in ((a , S , stats [( player , S ) ]) for a , S in actions_states )
120
+ for a , S , e in ((a , S , stats [S ]) for a , S in actions_states )
125
121
]
126
122
max_value = max (v for _ , _ , v in values_actions )
127
123
# Filter down to only those actions with maximum value under UCB1.
128
124
actions_states = [(a , S ) for a , S , v in values_actions if v == max_value ]
129
125
130
126
action , state = choice (actions_states )
131
- visited_states .append (( player , state ) )
127
+ visited_states .append (state )
132
128
history_copy .append (state )
133
- # Who is the next player to take an action?
134
- player = self .board .current_player (state )
135
129
136
130
if self .board .is_ended (history_copy ):
137
131
break
138
132
139
133
# Back-propagation
140
134
end_values = self .end_values (history_copy )
141
- for player , state in visited_states :
142
- if ( player , state ) not in stats :
135
+ for state in visited_states :
136
+ if state not in stats :
143
137
continue
144
- S = stats [( player , state ) ]
138
+ S = stats [state ]
145
139
S .visits += 1
146
- S .value += end_values [player ]
140
+ S .value += end_values [self . board . previous_player ( state ) ]
147
141
148
142
149
143
class UCTWins (UCT ):
@@ -158,9 +152,9 @@ def calculate_action_values(self, history, player, legal):
158
152
actions_states = ((a , self .board .next_state (history , a )) for a in legal )
159
153
return sorted (
160
154
({'action' : a ,
161
- 'percent' : 100 * self .stats [( player , S ) ].value / (self .stats [( player , S ) ].visits or 1 ),
162
- 'wins' : self .stats [( player , S ) ].value ,
163
- 'plays' : self .stats [( player , S ) ].visits }
155
+ 'percent' : 100 * self .stats [S ].value / (self .stats [S ].visits or 1 ),
156
+ 'wins' : self .stats [S ].value ,
157
+ 'plays' : self .stats [S ].visits }
164
158
for a , S in actions_states ),
165
159
key = lambda x : (x ['percent' ], x ['plays' ]),
166
160
reverse = True
@@ -179,9 +173,9 @@ def calculate_action_values(self, history, player, legal):
179
173
actions_states = ((a , self .board .next_state (history , a )) for a in legal )
180
174
return sorted (
181
175
({'action' : a ,
182
- 'average' : self .stats [( player , S ) ].value / (self .stats [( player , S ) ].visits or 1 ),
183
- 'sum' : self .stats [( player , S ) ].value ,
184
- 'plays' : self .stats [( player , S ) ].visits }
176
+ 'average' : self .stats [S ].value / (self .stats [S ].visits or 1 ),
177
+ 'sum' : self .stats [S ].value ,
178
+ 'plays' : self .stats [S ].visits }
185
179
for a , S in actions_states ),
186
180
key = lambda x : (x ['average' ], x ['plays' ]),
187
181
reverse = True
0 commit comments