@@ -13,7 +13,7 @@ def __init__(self, value=0, visits=0):
         self.visits = visits


-class MonteCarlo(object):
+class UCT(object):
     def __init__(self, board, **kwargs):
         self.board = board
         self.history = []
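Note: the rename from MonteCarlo to UCT matches the selection rule the class already uses, UCB1 applied to tree search (hence "Upper Confidence bounds applied to Trees"). For reference, the score computed in run_simulation's selection step for each child action, written with the same names the code uses, is

    value / visits + C * sqrt(log(total_visits) / visits)

where value / visits is that child's average simulation result, total_visits is the combined visit count of all sibling children, and C is the exploration constant read from kwargs ('C', default 1.4).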
@@ -69,30 +69,13 @@ def get_action(self):
         print self.data['games'], self.data['time']
         print "Maximum depth searched:", self.max_depth

-        actions_states = [(p, self.board.next_state(state, p)) for p in legal]
-
-        # Display the stats for each possible action.
-        self.data['actions'] = sorted(
-            ({'action': p,
-              'percent': 100 * self.stats[(player, S)].value / self.stats[(player, S)].visits,
-              'wins': self.stats[(player, S)].value,
-              'plays': self.stats[(player, S)].visits}
-             for p, S in actions_states),
-            key=lambda x: (x['percent'], x['plays']),
-            reverse=True
-        )
+        # Store and display the stats for each possible action.
+        self.data['actions'] = self.calculate_action_values(state, player, legal)

         for m in self.data['actions']:
-            print "{action}: {percent:.2f}% ({wins} / {plays})".format(**m)
-
-        # Pick the action with the highest percentage of wins.
-        percent_wins, num_actions, action = max(
-            (self.stats[(player, S)].value / self.stats[(player, S)].visits,
-             self.stats[(player, S)].visits,
-             p)
-            for p, S in actions_states
-        )
+            print self.action_template.format(**m)

-        return action
+        # Pick the action with the highest average value.
+        return self.data['actions'][0]['action']


     def run_simulation(self):
         # Plays out a "random" game from the current position,
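Note: get_action now relies on an implicit contract with the subclasses introduced below: calculate_action_values(state, player, legal) must return one dict per legal action, already sorted best-first, containing every key that the subclass's action_template formats. A hypothetical illustration (the actions and numbers are invented):

    example_actions = [
        {'action': (0, 1), 'percent': 62.50, 'wins': 10.0, 'plays': 16},
        {'action': (2, 2), 'percent': 41.67, 'wins': 5.0, 'plays': 12},
    ]
    # Because the list is sorted in descending order, get_action can simply
    # return example_actions[0]['action'].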
@@ -138,78 +121,49 @@ def run_simulation(self):
             visited_states.add((player, state))

             player = self.board.current_player(state)
-            winner = self.board.winner(history_copy)
-            if winner:
+            if self.board.is_ended(history_copy):
                 break

+        # Back-propagation
+        end_values = self.end_values(history_copy)
         for player, state in visited_states:
             if (player, state) not in stats:
                 continue
             S = stats[(player, state)]
             S.visits += 1
-            if player == winner:
-                S.value += 1
-
-
-class ValueMonteCarlo(object):
-    def __init__(self, board, **kwargs):
-        self.board = board
-        self.history = []
-        self.stats = {}
-
-        self.max_depth = 0
-        self.data = {}
-
-        self.calculation_time = float(kwargs.get('time', 30))
-        self.max_actions = int(kwargs.get('max_actions', 1000))
-
-        # Exploration constant, increase for more exploratory actions,
-        # decrease to prefer actions with known higher win rates.
-        self.C = float(kwargs.get('C', 1.4))
+            S.value += end_values[player]

-    def update(self, state):
-        self.history.append(state)

-    def display(self, state, action):
-        return self.board.display(state, action)
+class UCTWins(UCT):
+    action_template = "{action}: {percent:.2f}% ({wins} / {plays})"

-    def winner_message(self, msg):
-        return self.board.winner_message(msg)
-
-    def get_action(self):
-        # Causes the AI to calculate the best action from the
-        # current game state and return it.
-
-        self.max_depth = 0
-        self.data = {}
+    def __init__(self, board, **kwargs):
+        super(UCTWins, self).__init__(board, **kwargs)
+        self.end_values = board.win_values

-        state = self.history[-1]
-        player = self.board.current_player(state)
-        legal = self.board.legal_actions(self.history[:])
+    def calculate_action_values(self, state, player, legal):
+        actions_states = ((p, self.board.next_state(state, p)) for p in legal)
+        return sorted(
+            ({'action': p,
+              'percent': 100 * self.stats[(player, S)].value / self.stats[(player, S)].visits,
+              'wins': self.stats[(player, S)].value,
+              'plays': self.stats[(player, S)].visits}
+             for p, S in actions_states),
+            key=lambda x: (x['percent'], x['plays']),
+            reverse=True
+        )

-        # Bail out early if there is no real choice to be made.
-        if not legal:
-            return
-        if len(legal) == 1:
-            return legal[0]

-        games = 0
-        begin = time.time()
-        while time.time() - begin < self.calculation_time:
-            self.run_simulation()
-            games += 1
+class UCTValues(UCT):
+    action_template = "{action}: {average:.1f} ({sum} / {plays})"

-        # Display the number of calls of `run_simulation` and the
-        # time elapsed.
-        self.data.update(games=games, max_depth=self.max_depth,
-                         time=str(time.time() - begin))
-        print self.data['games'], self.data['time']
-        print "Maximum depth searched:", self.max_depth
-
-        actions_states = [(p, self.board.next_state(state, p)) for p in legal]
+    def __init__(self, board, **kwargs):
+        super(UCTValues, self).__init__(board, **kwargs)
+        self.end_values = board.points_values

-        # Display the stats for each possible action.
-        self.data['actions'] = sorted(
+    def calculate_action_values(self, state, player, legal):
+        actions_states = ((p, self.board.next_state(state, p)) for p in legal)
+        return sorted(
             ({'action': p,
               'average': self.stats[(player, S)].value / self.stats[(player, S)].visits,
               'sum': self.stats[(player, S)].value,
@@ -218,75 +172,3 @@ def get_action(self):
             key=lambda x: (x['average'], x['plays']),
             reverse=True
         )
-        for m in self.data['actions']:
-            print "{action}: {average:.1f} ({sum} / {plays})".format(**m)
-
-        # Pick the action with the highest average value.
-        average, num_actions, action = max(
-            (self.stats[(player, S)].value / self.stats[(player, S)].visits,
-             self.stats[(player, S)].visits,
-             p)
-            for p, S in actions_states
-        )
-
-        return action
-
-    def run_simulation(self):
-        # Plays out a "random" game from the current position,
-        # then updates the statistics tables with the result.
-
-        # A bit of an optimization here, so we have a local
-        # variable lookup instead of an attribute access each loop.
-        stats = self.stats
-
-        visited_states = set()
-        history_copy = self.history[:]
-        state = history_copy[-1]
-        player = self.board.current_player(state)
-
-        expand = True
-        for t in xrange(1, self.max_actions + 1):
-            legal = self.board.legal_actions(history_copy)
-            actions_states = [(p, self.board.next_state(state, p)) for p in legal]
-
-            if all((player, S) in stats for p, S in actions_states):
-                # If we have stats on all of the legal actions here, use UCB1.
-                log_total = log(
-                    sum(stats[(player, S)].visits for p, S in actions_states))
-                value, action, state = max(
-                    ((stats[(player, S)].value / stats[(player, S)].visits) +
-                     self.C * sqrt(log_total / stats[(player, S)].visits), p, S)
-                    for p, S in actions_states
-                )
-            else:
-                # Otherwise, just make an arbitrary decision.
-                action, state = choice(actions_states)
-
-            history_copy.append(state)
-
-            # `player` here and below refers to the player
-            # who moved into that particular state.
-            if expand and (player, state) not in stats:
-                expand = False
-                stats[(player, state)] = Stat()
-                if t > self.max_depth:
-                    self.max_depth = t
-
-            visited_states.add((player, state))
-
-            player = self.board.current_player(state)
-            winner = self.board.winner(history_copy)
-            if winner:
-                break
-
-        player_values = {}
-        for player, state in visited_states:
-            if (player, state) not in stats:
-                continue
-            if player not in player_values:
-                player_values[player] = self.board.end_value(history_copy, player)
-
-            S = stats[(player, state)]
-            S.visits += 1
-            if player_values[player] is not None:
-                S.value += player_values[player]
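Note: a minimal usage sketch of the refactored classes. `SomeBoard` and `start()` are placeholders; the board is assumed to provide the methods this file calls (current_player, legal_actions, next_state, is_ended, and win_values or points_values returning a {player: value} mapping once the game has ended), and the keyword names are the ones read from kwargs in the removed ValueMonteCarlo.__init__, assumed unchanged in UCT.

    board = SomeBoard()
    ai = UCTWins(board, time=30, max_actions=1000, C=1.4)   # or UCTValues(board, ...)

    ai.update(board.start())   # seed the history with the initial state (start() is assumed)
    action = ai.get_action()   # runs simulations until the time budget expires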