Skip to content

Commit e7a15ab

Browse files
committed
Call it a day.
1 parent 9df3dd8 commit e7a15ab

13 files changed

+227
-33
lines changed

src/game.cpp

+20-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
using namespace std;
99
using namespace game;
1010

11-
Game::Game(const Board& b, Piece p): init_state(b, p) {}
11+
Game::Game(const Board& b, Piece p):
12+
init_state(b, p), win1(0), win2(0) {}
1213

13-
Game::Game(unique_ptr<Board>&& b, Piece p): init_state(std::move(b), p) {}
14+
Game::Game(unique_ptr<Board>&& b, Piece p):
15+
init_state(std::move(b), p), win1(0), win2(0) {}
1416

1517
State Game::result(const State& state, Point m, Piece p) const {
1618
auto new_board = state.copy_board();
@@ -51,7 +53,23 @@ void Game::play(const Player* p1, const Player* p2) const {
5153
i = (i+1) % 2;
5254
}
5355
s.display();
56+
switch(s.get_piece()) {
57+
case Piece::White: ++win1; break;
58+
case Piece::Black: ++win2; break;
59+
default:
60+
throw logic_error("Game ends with Blank piece");
61+
}
5462
cout << get_piece_color(s.get_piece()) << " won!\n";
5563
cout << get_piece_color(p1->get_piece()) << " utility: " << s.get_utility(p1->get_piece()) << '\n';
5664
cout << get_piece_color(p2->get_piece()) << " utility: " << s.get_utility(p2->get_piece()) << '\n';
5765
}
66+
67+
// Returns the fraction of finished games won by piece `p`.
//
// win1 counts games that ended with Piece::White reported by the final state,
// win2 those that ended with Piece::Black (both are incremented in play()).
// Any other piece has no counter and therefore a rate of 0.
//
// Returns 0.0 when no game has been recorded yet; dividing by the zero total
// would otherwise yield NaN.
double Game::compute_win_rate(Piece p) const {
    const int total = win1 + win2;
    if (total == 0)
        return 0.0;

    int wins = 0;
    switch (p) {
        case Piece::White: wins = win1; break;
        case Piece::Black: wins = win2; break;
        default: break;
    }
    return static_cast<double>(wins) / total;
}

src/game.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ namespace game {
1616
Game(std::unique_ptr<Board>&& b, Piece p=Piece::Blank);
1717

1818
void play(const Player*, const Player*) const;
19-
// copy the state, update it and return
2019
State result(const State&, Point, Piece) const;
2120
State get_initial_state() const { return init_state; }
21+
double compute_win_rate(Piece) const;
2222
private:
2323
State init_state;
24+
mutable int win1;
25+
mutable int win2;
2426
};
2527
}
2628

src/hex_board.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,6 @@ double HexBoard::compute_utility(Piece piece) const {
9898
}
9999
auto white_score = static_cast<double>(down - top + 1) / n;
100100
auto black_score = static_cast<double>(right - left + 1) / n;
101-
// display();
102-
// cout << piece << down - top + 1 << " " << right - left + 1 << endl;
103-
// assert(false);
104101
return piece == Piece::White? white_score - black_score: black_score - white_score;
105102
}
106103

src/hex_board.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
namespace game {
1111
class HexBoard: public RhombusBoard {
1212
public:
13-
explicit HexBoard(std::size_t nn=11): RhombusBoard(nn) {}
13+
explicit HexBoard(std::size_t nn=9): RhombusBoard(nn) {}
1414

1515
void display(std::ostream& os=std::cout) const final;
1616
HexBoard* clone() const final { return new HexBoard(*this);}

src/main.cpp

+34-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "alphabeta_cutoff_player.hpp"
1313
#include "minmax_player.hpp"
1414
#include "mcts_player.hpp"
15+
#include "mcts_eval_multithread_player.hpp"
1516
#include "mcts_multithread_player.hpp"
1617
#include "mcts_multithread_threadpool_player.hpp"
1718

@@ -27,7 +28,7 @@ using namespace utility;
2728

2829
string ask_for_player_type(int player_idx) {
2930
string pt;
30-
cout << "Select player" << player_idx << "(AlphaBeta AlphaBetaCutOff MCTS MCTSMT MinMax): ";
31+
cout << "Select player" << player_idx << "(AlphaBeta AlphaBetaCutOff MCTS MCTSEMT MCTSMT MinMax): ";
3132
cin >> pt;
3233
transform(pt.begin(), pt.end(), pt.begin(), ::toupper);
3334
return pt;
@@ -36,7 +37,21 @@ string ask_for_player_type(int player_idx) {
3637
int ask_for_n(const string& p, int player_idx) {
3738
if (p == "ALPHABETA" || p == "MINMAX")
3839
return 0;
39-
cout << "Player" << player_idx <<"-the number of simulations/The depth to cutoff: ";
40+
string s;
41+
if (p == "ALPHABETACUTOFF") {
42+
s = "The depth to cutoff";
43+
}
44+
else {
45+
s = "The number of simulations";
46+
}
47+
cout << "(Player" << player_idx <<")" + s + ": ";
48+
int n;
49+
cin >> n;
50+
return n;
51+
}
52+
53+
int ask_for_experiments() {
54+
cout << "Number of experiments: ";
4055
int n;
4156
cin >> n;
4257
return n;
@@ -51,33 +66,44 @@ unique_ptr<Player> create_player(const string& s, Piece p, int n) {
5166
return make_unique<MCTSPlayer>(p, n);
5267
if (s == "MCTSMT")
5368
return make_unique<MCTSMultiThreadPlayer>(p, n);
69+
if (s == "MCTSEMT")
70+
return make_unique<MCTSEvaluationMultiThreadPlayer>(p, n);
5471
if (s == "MCTSMTTP")
5572
return make_unique<MCTSMultiThreadThreadPoolPlayer>(p, n);
5673
if (s == "MINMAX")
5774
return make_unique<MinMaxPlayer>(p);
5875
throw invalid_argument("Invalid player type: " + s);
5976
}
6077

78+
void print_win_rate(const Game& g, const string& p_type, const Player* p) {
79+
cout << get_piece_color(p->get_piece()) << "(" << p_type
80+
<< ") win rate: " << g.compute_win_rate(p->get_piece())
81+
<< "\n";
82+
}
83+
6184
int main() {
6285
cout << "Board size: ";
6386
size_t n;
6487
cin >> n;
65-
string p1_type = ask_for_player_type(1);
66-
string p2_type = ask_for_player_type(2);
88+
auto p1_type = ask_for_player_type(1);
89+
auto p2_type = ask_for_player_type(2);
6790
int n1 = ask_for_n(p1_type, 1);
6891
int n2 = ask_for_n(p2_type, 2);
6992

7093
auto p1 = create_player(p1_type, Piece::White, n1);
7194
auto p2 = create_player(p2_type, Piece::Black, n2);
72-
// int n = 5;
95+
int n_exps = ask_for_experiments();
96+
// int n = 9;
7397
// string p1_type = "MCTS";
7498
// string p2_type = "MCTS";
75-
// auto p1 = make_unique<MCTSPlayer>(Piece::White, 100);
76-
// auto p2 = make_unique<MCTSPlayer>(Piece::Black, 100);
99+
// auto p1 = make_unique<MCTSMultiThreadThreadPoolPlayer>(Piece::White, 1000);
100+
// auto p2 = make_unique<MCTSMultiThreadThreadPoolPlayer>(Piece::Black, 1000);
77101
string description = "Time to run (" + p1_type + " vs " + p2_type +
78102
") on (" + to_string(n) + " x " + to_string(n) + ") board: ";
79103
Game g(make_board("Hex", n));
80-
for (auto i = 0; i != 10; ++i) {
104+
for (auto i = 0; i != n_exps; ++i) {
81105
timeit(description, &g, &Game::play, p1.get(), p2.get());
82106
}
107+
print_win_rate(g, p1_type, p1.get());
108+
print_win_rate(g, p2_type, p2.get());
83109
}

src/mcts_eval_multithread_player.cpp

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#include <cmath>
2+
#include <limits>
3+
#include <random>
4+
#include <future>
5+
6+
#include "mcts_eval_multithread_player.hpp"
7+
8+
using namespace std;
9+
using namespace game;
10+
11+
12+
// Creates a player for piece `p`.
//   nn: number of search iterations launched per get_move() call.
//   cc: weight of the exploration term in ucb().
MCTSEvaluationMultiThreadPlayer::MCTSEvaluationMultiThreadPlayer(Piece p, int nn, double cc)
    : Player(p)
    , n(nn)
    , c(cc)
{
}
14+
15+
// Chooses a move for `state` (which must not be a finished game).
//
// Builds a fresh search tree rooted at `state`, launches n asynchronous tasks
// that each perform one select/expand -> simulate -> backup iteration on the
// shared tree, waits for all of them, and returns the move whose root child
// collected the most visits. If no child was visited, the first valid move is
// returned.
Point MCTSEvaluationMultiThreadPlayer::get_move(const Game& game, const State& state) const {
    assert(!state.is_over());
    auto root = make_unique<Node>(state);

    vector<future<void>> tasks(n);
    for (auto& task : tasks) {
        // `game` is captured by reference instead of copied per task: every
        // future is waited on below, so the reference outlives all tasks.
        task = std::async([this, &game, node = root.get()]() {
            auto leaf = select_expand(game, node);
            auto result = simulate(game, leaf->state);
            backup(leaf, result);
        });
    }
    for (auto& task : tasks)
        task.get();

    const auto& moves = state.get_valid_moves();
    const auto& children = root->children;
    assert(moves.size() == children.size());

    std::size_t best_idx = 0;
    std::size_t best_visits = 0;
    for (std::size_t i = 0; i != moves.size(); ++i) {
        if (children[i] && children[i]->n > best_visits) {
            // Track the running maximum; without this update the loop would
            // simply keep the last child that has any visit at all.
            best_visits = children[i]->n;
            best_idx = i;
        }
    }
    return moves[best_idx];
}
39+
40+
// Upper-confidence-bound score of child node `p`:
//   utility / n + c * sqrt(log(parent visits) / n)
//
// Returns numeric_limits<double>::max() for a child that does not exist yet
// (null) or that exists but has no recorded visit, so such children are tried
// first and the formula never divides by zero.
//
// The parent's visit count is read under the parent's mutex and the child's
// fields under the child's mutex. The two locks are taken one after the other,
// never nested, matching how backup() locks one node at a time.
double MCTSEvaluationMultiThreadPlayer::ucb(Node* p) const {
    if (!p)
        return numeric_limits<double>::max();

    double parent_visits = 1;
    if (p->parent) {
        lock_guard lk{p->parent->m};
        parent_visits = static_cast<double>(p->parent->n);
    }
    // backup() bumps the child before its parent, so the parent count can
    // briefly lag behind; clamp it so log() stays non-negative.
    if (parent_visits < 1)
        parent_visits = 1;

    lock_guard lk{p->m};
    if (p->n == 0)
        return numeric_limits<double>::max();
    const double visits = static_cast<double>(p->n);
    return p->utility / visits + c * sqrt(log(parent_visits) / visits);
}
48+
49+
// Walks down the tree from `p`, at each level following the child with the
// highest ucb() score, and returns the node to evaluate:
//   - a newly created child, when the chosen slot was still empty
//     (expansion happens here rather than in a separate step, which avoids
//     allocating every child of a leaf up front), or
//   - an existing child whose state is a finished game.
//
// `p` must be non-null and have at least one child slot.
MCTSEvaluationMultiThreadPlayer::Node* MCTSEvaluationMultiThreadPlayer::select_expand(const Game& game, Node* p) const {
    assert(p);
    Node* node = p;
    for (;;) {
        assert(!node->children.empty());

        // Snapshot the child pointers under the node's mutex: make_child()
        // fills these slots while holding the same mutex. The lock is released
        // before scoring because ucb() takes node locks itself.
        vector<Node*> snapshot(node->children.size());
        {
            lock_guard lk(node->m);
            for (std::size_t i = 0; i != snapshot.size(); ++i)
                snapshot[i] = node->children[i].get();
        }

        // lowest() rather than min(): min() is the smallest *positive* double,
        // which would make every non-positive score lose to the initial value.
        double best_val = numeric_limits<double>::lowest();
        std::size_t best_idx = 0;
        for (std::size_t i = 0; i != snapshot.size(); ++i) {
            const double val = ucb(snapshot[i]);
            if (val > best_val) {
                best_val = val;
                best_idx = i;
                if (best_val == numeric_limits<double>::max())
                    break;  // nothing can beat an unexplored child
            }
        }

        Node* chosen = snapshot[best_idx];
        if (!chosen) {
            make_child(game, node, static_cast<int>(best_idx));
            lock_guard lk(node->m);
            return node->children[best_idx].get();
        }
        if (chosen->state.is_over())
            return chosen;
        node = chosen;
    }
}
75+
76+
double MCTSEvaluationMultiThreadPlayer::simulate(const Game& game, const State& s) const {
77+
s.compute_utility();
78+
return s.get_utility(get_piece());
79+
}
80+
81+
// Propagates result `v` from node `p` up to the root: every node on the path
// gets one more visit and `v` added to its utility. Each node is updated under
// its own mutex, and that mutex is released before moving on to the parent.
void MCTSEvaluationMultiThreadPlayer::backup(Node* p, double v) const {
    assert(p);
    for (Node* node = p; node; node = node->parent) {
        lock_guard guard(node->m);
        ++node->n;
        node->utility += v;
    }
}
91+
void MCTSEvaluationMultiThreadPlayer::make_child(const Game& game, Node* p, int idx) const {
92+
assert(p);
93+
auto moves = p->state.get_valid_moves();
94+
auto piece = p->state.to_move();
95+
auto state = game.result(p->state, moves[idx], piece);
96+
lock_guard lk(p->m);
97+
if (!p->children[idx]) {
98+
p->children[idx] = make_unique<Node>(std::move(state), p);
99+
}
100+
}

src/mcts_eval_multithread_player.hpp

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifndef GAME_MCTS_EVALUATION_MULTITHREAD_PLAYER_H_
2+
#define GAME_MCTS_EVALUATION_MULTITHREAD_PLAYER_H_
3+
4+
#include <memory>
5+
#include <mutex>
6+
#include <utility>
7+
#include <vector>
8+
9+
#include "game.hpp"
10+
#include "player.hpp"
11+
#include "state.hpp"
12+
13+
namespace game {
14+
class MCTSEvaluationMultiThreadPlayer: public Player {
15+
public:
16+
MCTSEvaluationMultiThreadPlayer(Piece p, int nn, double cc=1.4);
17+
18+
Point get_move(const Game&, const State&) const override;
19+
private:
20+
struct Node {
21+
Node(const State& s, Node* p=nullptr):
22+
state(s), parent(p),
23+
children(state.get_valid_moves().size()),
24+
n(0), utility(0) {}
25+
Node(State&& s, Node* p=nullptr):
26+
state(std::move(s)), parent(p),
27+
children(state.get_valid_moves().size()),
28+
n(0), utility(0) {}
29+
30+
State state;
31+
Node* parent;
32+
std::vector<std::unique_ptr<Node>> children;
33+
std::size_t n;
34+
double utility;
35+
std::mutex m;
36+
};
37+
double ucb(Node* p) const;
38+
Node* select_expand(const Game&, Node*) const;
39+
double simulate(const Game&, const State&) const;
40+
void backup(Node*, double) const;
41+
void make_child(const Game&, Node*, int) const;
42+
int n; // number of simulations
43+
double c;
44+
};
45+
}
46+
47+
#endif

src/mcts_mt_lf_player.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Point MCTSMultiThreadLockFreePlayer::get_move(const Game& game, const State& sta
3131
const auto& children = root->children;
3232
assert(moves.size() == children.size());
3333
for (auto i = 0; i != moves.size(); ++i) {
34-
if (children[i]->n > max_n)
34+
if (children[i] && children[i]->n > max_n)
3535
max_i = i;
3636
}
3737
return moves[max_i];

src/mcts_multithread_player.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Point MCTSMultiThreadPlayer::get_move(const Game& game, const State& state) cons
3131
const auto& children = root->children;
3232
assert(moves.size() == children.size());
3333
for (auto i = 0; i != moves.size(); ++i) {
34-
if (children[i]->n > max_n)
34+
if (children[i] && children[i]->n > max_n)
3535
max_i = i;
3636
}
3737
return moves[max_i];

src/mcts_multithread_threadpool_player.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,22 @@ Point MCTSMultiThreadThreadPoolPlayer::get_move(
1818
const Game& game, const State& state) const {
1919
assert(!state.is_over());
2020
auto root = make_unique<Node>(state);
21-
auto n_threads = std::thread::hardware_concurrency() - 1;
2221
{
22+
auto n_threads = std::thread::hardware_concurrency() - 1;
23+
int sim_per_thread = n / (n_threads+1);
2324
ThreadPool<void()> tp(n_threads);
2425
vector<future<void>> v(n_threads);
25-
int sim_per_thread = n / (n_threads+1);
2626
auto task = [this](const Game& game, Node* root, int sim_per_thread) {
2727
for (auto i = 0; i != sim_per_thread; ++i) {
28+
assert(root);
2829
auto leaf = select_expand(game, root);
30+
assert(leaf);
2931
auto result = simulate(game, leaf->state);
32+
assert(result != 0);
3033
backup(leaf, result);
3134
}
3235
};
33-
for (auto i = 0; i != n_threads - 1; ++i) {
36+
for (auto i = 0; i != n_threads; ++i) {
3437
v[i] = tp.submit(task, game, root.get(), sim_per_thread);
3538
}
3639
task(game, root.get(), n - sim_per_thread * n_threads);
@@ -43,7 +46,7 @@ Point MCTSMultiThreadThreadPoolPlayer::get_move(
4346
const auto& children = root->children;
4447
assert(moves.size() == children.size());
4548
for (auto i = 0; i != moves.size(); ++i) {
46-
if (children[i]->n > max_n)
49+
if (children[i] && children[i]->n > max_n)
4750
max_i = i;
4851
}
4952
return moves[max_i];

src/mcts_player.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Point MCTSPlayer::get_move(const Game& game, const State& state) const {
2525
const auto& children = root->children;
2626
assert(moves.size() == children.size());
2727
for (auto i = 0; i != moves.size(); ++i) {
28-
if (children[i]->n > max_n)
28+
if (children[i] && children[i]->n > max_n)
2929
max_i = i;
3030
}
3131
return moves[max_i];

0 commit comments

Comments
 (0)