SelfPlayAgent.pyx
# cython: language_level=3
import numpy as np
import torch
import torch.multiprocessing as mp

from MCTS import MCTS


class SelfPlayAgent(mp.Process):
    """Worker process that plays a batch of self-play games in parallel.

    Leaf evaluation is delegated to an external inference server: the agent
    writes candidate boards into the shared batch_tensor, announces itself on
    ready_queue, and waits on batch_ready for the server to fill
    policy_tensor and value_tensor.
    """

    def __init__(self, id, game, ready_queue, batch_ready, batch_tensor,
                 policy_tensor, value_tensor, output_queue, result_queue,
                 complete_count, games_played, args):
        super().__init__()
        self.id = id
        self.game = game
        # Queues and shared-memory tensors for talking to the inference server.
        self.ready_queue = ready_queue
        self.batch_ready = batch_ready
        self.batch_tensor = batch_tensor
        self.batch_size = self.batch_tensor.shape[0]
        self.policy_tensor = policy_tensor
        self.value_tensor = value_tensor
        self.output_queue = output_queue   # training samples
        self.result_queue = result_queue   # finished-game outcomes
        # Per-game state: one slot per entry in the batch.
        self.games = []
        self.canonical = []
        self.histories = []
        self.player = []
        self.turn = []
        self.mcts = []
        self.games_played = games_played
        self.complete_count = complete_count
        self.args = args
        self.valid = torch.zeros_like(self.policy_tensor)
        self.fast = False
        for _ in range(self.batch_size):
            self.games.append(self.game.getInitBoard())
            self.histories.append([])
            self.player.append(1)
            self.turn.append(1)
            self.mcts.append(MCTS(self.game, None, self.args))
            self.canonical.append(None)
    def run(self):
        # Reseed from OS entropy so forked workers don't share RNG state.
        np.random.seed()
        while self.games_played.value < self.args.gamesPerIteration:
            self.generateCanonical()
            # With probability probFastSim, play a cheap move using fewer
            # simulations; fast moves are not recorded as training samples.
            self.fast = np.random.random_sample() < self.args.probFastSim
            if self.fast:
                for i in range(self.args.numFastSims):
                    self.generateBatch()
                    self.processBatch()
            else:
                for i in range(self.args.numMCTSSims):
                    self.generateBatch()
                    self.processBatch()
            self.playMoves()
        with self.complete_count.get_lock():
            self.complete_count.value += 1
        self.output_queue.close()
        self.output_queue.join_thread()
    def generateBatch(self):
        # Collect one MCTS leaf per game into the shared batch tensor, then
        # tell the inference server this agent's batch is ready to evaluate.
        for i in range(self.batch_size):
            board = self.mcts[i].findLeafToProcess(self.canonical[i], True)
            if board is not None:
                self.batch_tensor[i] = torch.from_numpy(board)
        self.ready_queue.put(self.id)

    def processBatch(self):
        # Block until the server has filled policy_tensor and value_tensor,
        # then back the results up through each game's search tree.
        self.batch_ready.wait()
        self.batch_ready.clear()
        for i in range(self.batch_size):
            self.mcts[i].processResults(
                self.policy_tensor[i].data.numpy(), self.value_tensor[i][0])
    def playMoves(self):
        for i in range(self.batch_size):
            # Sample proportionally to visit counts early on, then greedily
            # (temp = 0) once past tempThreshold turns.
            temp = int(self.turn[i] < self.args.tempThreshold)
            policy = self.mcts[i].getExpertProb(
                self.canonical[i], temp, not self.fast)
            action = np.random.choice(len(policy), p=policy)
            if not self.fast:
                # Record (board, expert policy, expert value, player) for training.
                self.histories[i].append(
                    (self.canonical[i],
                     self.mcts[i].getExpertProb(self.canonical[i], prune=True),
                     self.mcts[i].getExpertValue(self.canonical[i]),
                     self.player[i]))
            self.games[i], self.player[i] = self.game.getNextState(
                self.games[i], self.player[i], action)
            self.turn[i] += 1
            winner = self.game.getGameEnded(self.games[i], 1)
            if winner != 0:
                self.result_queue.put(winner)
                lock = self.games_played.get_lock()
                lock.acquire()
                if self.games_played.value < self.args.gamesPerIteration:
                    self.games_played.value += 1
                    lock.release()
                    # Emit one training sample per stored position, blending the
                    # final game outcome with the expert value estimate.
                    for hist in self.histories[i]:
                        if self.args.symmetricSamples:
                            sym = self.game.getSymmetries(hist[0], hist[1])
                            for b, p in sym:
                                self.output_queue.put(
                                    (b, p,
                                     winner * hist[3] * (1 - self.args.expertValueWeight.current)
                                     + self.args.expertValueWeight.current * hist[2]))
                        else:
                            self.output_queue.put(
                                (hist[0], hist[1],
                                 winner * hist[3] * (1 - self.args.expertValueWeight.current)
                                 + self.args.expertValueWeight.current * hist[2]))
                    # Reset this slot and start a fresh game.
                    self.games[i] = self.game.getInitBoard()
                    self.histories[i] = []
                    self.player[i] = 1
                    self.turn[i] = 1
                    self.mcts[i] = MCTS(self.game, None, self.args)
                else:
                    lock.release()
    def generateCanonical(self):
        # Refresh each game's canonical (current-player-relative) board view.
        for i in range(self.batch_size):
            self.canonical[i] = self.game.getCanonicalForm(
                self.games[i], self.player[i])
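

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one possible driver loop for
# these agents, shown here to make the queue/event protocol concrete. It
# assumes the alpha-zero-general Game API (getBoardSize and getActionSize are
# assumed; only the other Game methods are used by SelfPlayAgent itself) and
# a hypothetical nnet.process(batch) helper that runs batched inference and
# returns (policy, value) tensors.
# ---------------------------------------------------------------------------
from queue import Empty


def example_self_play(game, nnet, args, num_agents=2, batch_size=64):
    board_x, board_y = game.getBoardSize()    # assumed Game method
    action_size = game.getActionSize()        # assumed Game method

    ready_queue = mp.Queue()     # agents announce a filled batch here
    output_queue = mp.Queue()    # training samples
    result_queue = mp.Queue()    # game outcomes
    complete_count = mp.Value('i', 0)
    games_played = mp.Value('i', 0)

    agents, events, batches, policies, values = [], [], [], [], []
    for i in range(num_agents):
        # share_memory_() lets agent and driver exchange data without pickling.
        batches.append(torch.zeros(batch_size, board_x, board_y).share_memory_())
        policies.append(torch.zeros(batch_size, action_size).share_memory_())
        values.append(torch.zeros(batch_size, 1).share_memory_())
        events.append(mp.Event())
        agents.append(SelfPlayAgent(i, game, ready_queue, events[i],
                                    batches[i], policies[i], values[i],
                                    output_queue, result_queue,
                                    complete_count, games_played, args))
        agents[i].start()

    # Serve network evaluations until every agent has finished its games.
    while complete_count.value < num_agents:
        try:
            id = ready_queue.get(timeout=1)
        except Empty:
            continue  # re-check complete_count instead of blocking forever
        policy, value = nnet.process(batches[id])  # hypothetical helper
        policies[id].copy_(policy)
        values[id].copy_(value)
        events[id].set()  # wake the agent blocked in processBatch()

    for agent in agents:
        agent.join()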