-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_train.py
61 lines (53 loc) · 1.5 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Entry point for MANTIS self-play training.
#
# Standard-library and third-party imports come first, then the import path is
# extended, and only then are the project-local packages imported.
import sys

import torch.multiprocessing as mp

# Make the MANTIS project root importable when the script is launched from
# outside it. These entries MUST be added before the project-local imports
# below: appending them afterwards has no effect, because those imports have
# already been resolved (or have already failed) by then.
# NOTE(review): both paths are machine-specific (a Jupyter container and a
# Windows workstation); consider deriving the root from __file__ instead.
sys.path.append("/home/jovyan/work/MANTIS")
sys.path.append("C:\\Users\\xiayi\\Desktop\\1. Duke University Classes\\MANTIS")

from chess_standard.board_chess_pypi import BoardPypiChess  # noqa: E402
from chess_standard.chessnet import ChessNet  # noqa: E402
from common.parallel_player import ParallelPlayer  # noqa: E402
# SerialPlayer, C4Net and BoardC4 are only referenced by the commented-out
# alternatives in the __main__ block; kept so those can be toggled back on.
from common.serial_player import SerialPlayer  # noqa: E402,F401
from common.training_loop import train  # noqa: E402
from connect4.board_c4 import BoardC4  # noqa: E402,F401
from connect4.c4net import C4Net  # noqa: E402,F401
if __name__ == "__main__":
    # Worker processes are started with "spawn" rather than the platform
    # default. NOTE(review): presumably because the players use torch/CUDA
    # state that is unsafe to fork — confirm.
    # force=True: without it, set_start_method raises RuntimeError when the
    # start method was already set in this interpreter (e.g. when the script
    # is re-run with %run inside the same Jupyter kernel).
    mp.set_start_method("spawn", force=True)

    # --- Training configuration ------------------------------------------
    MAX_ITERATIONS = 1000
    EPOCHS_PER_ITERATION = 10
    NUM_GENERATED = 10
    BATCH_SIZE = 15
    GAMES_TO_EVAL = 3
    MCTS_ITER = 5
    START_ITERATION = 0
    old_exists = False

    # Alternative, much smaller settings kept for quick debugging runs;
    # swap them in for the values above when needed:
    #   MAX_ITERATIONS = 1
    #   EPOCHS_PER_ITERATION = 1
    #   NUM_GENERATED = 6
    #   BATCH_SIZE = 1
    #   GAMES_TO_EVAL = 6
    #   MCTS_ITER = 50
    #   START_ITERATION = 0
    #   old_exists = False

    SAVE_DIR = "chessdata1"
    TEMP_NAME = "old.pt"
    multicore = 1

    # Game selection: the network class and the board class must belong to
    # the same game. For Connect-4 use Net = C4Net, Board = BoardC4 instead.
    Net = ChessNet
    Board = BoardPypiChess

    # ParallelPlayer and SerialPlayer take the same constructor arguments;
    # replace ParallelPlayer with SerialPlayer for a single-process run.
    player = ParallelPlayer(
        MCTS_ITER, old_exists, SAVE_DIR, TEMP_NAME, multicore, Net, Board
    )

    # Arguments are passed positionally, so this order must match the
    # parameter order of common.training_loop.train. old_exists, SAVE_DIR
    # and TEMP_NAME are also given to the player above; keep both in sync.
    train(
        player,
        Net,
        MAX_ITERATIONS,
        EPOCHS_PER_ITERATION,
        NUM_GENERATED,
        BATCH_SIZE,
        GAMES_TO_EVAL,
        START_ITERATION,
        old_exists,
        SAVE_DIR,
        TEMP_NAME,
    )