Skip to content

Commit 9df3dd8

Browse files
committed
Call it a day.
1 parent 58c351c commit 9df3dd8

33 files changed

+708
-4
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
main.dSYM
55
.vscode
66
hierarchy/build
7-
.template
7+
.template
8+
build

hierarchy/CMakeLists.txt CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED True)
55

66
project(Game)
77

8-
AUX_SOURCE_DIRECTORY(${CMAKE_CURRENT_SOURCE_DIR} SRCS)
8+
AUX_SOURCE_DIRECTORY(src SRCS)
9+
AUX_SOURCE_DIRECTORY(thread SRCS)
910
add_executable(game ${SRCS})
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

hierarchy/board.cpp src/board.cpp

File renamed without changes.

hierarchy/board.hpp src/board.hpp

File renamed without changes.
File renamed without changes.

hierarchy/game.cpp src/game.cpp

File renamed without changes.

hierarchy/game.hpp src/game.hpp

File renamed without changes.
File renamed without changes.
File renamed without changes.

hierarchy/main.cpp src/main.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "minmax_player.hpp"
1414
#include "mcts_player.hpp"
1515
#include "mcts_multithread_player.hpp"
16+
#include "mcts_multithread_threadpool_player.hpp"
1617

1718
#include "board.hpp"
1819
#include "game.hpp"
@@ -50,6 +51,8 @@ unique_ptr<Player> create_player(const string& s, Piece p, int n) {
5051
return make_unique<MCTSPlayer>(p, n);
5152
if (s == "MCTSMT")
5253
return make_unique<MCTSMultiThreadPlayer>(p, n);
54+
if (s == "MCTSMTTP")
55+
return make_unique<MCTSMultiThreadThreadPoolPlayer>(p, n);
5356
if (s == "MINMAX")
5457
return make_unique<MinMaxPlayer>(p);
5558
throw invalid_argument("Invalid player type: " + s);

src/mcts_mt_lf_player.cpp

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#include <cmath>
2+
#include <limits>
3+
#include <random>
4+
#include <future>
5+
6+
#include "mcts_mt_lf_player.hpp"
7+
8+
using namespace std;
9+
using namespace game;
10+
11+
12+
MCTSMultiThreadLockFreePlayer::MCTSMultiThreadLockFreePlayer(Piece p, int nn, double cc):
13+
Player(p), n(nn), c(cc) {}
14+
15+
// Chooses the move to play from `state`.
//
// Launches `n` std::async tasks; each one performs a single search iteration
// (select/expand, simulate, backup) on a tree rooted at `state`. After all
// tasks have finished, returns the valid move whose root child has the highest
// visit count. If no child was visited, the first valid move is returned.
Point MCTSMultiThreadLockFreePlayer::get_move(const Game& game, const State& state) const {
    assert(!state.is_over());
    auto root = make_unique<Node>(state);
    vector<future<void>> v(n);
    for (auto i = 0; i != n; ++i) {
        v[i] = std::async([this, game, p=root.get()]() {
            auto leaf = select_expand(game, p);
            auto result = simulate(game, leaf->state);
            backup(leaf, result);
        });
    }
    // Waiting on every future guarantees the tree is no longer being modified
    // when the statistics are read below.
    for (auto& f: v)
        f.get();

    const auto& moves = state.get_valid_moves();
    const auto& children = root->children;
    assert(moves.size() == children.size());

    std::size_t best_idx = 0;
    std::size_t best_visits = 0;
    for (std::size_t i = 0; i != children.size(); ++i) {
        // A slot stays null when that move was never expanded.
        if (!children[i])
            continue;
        const std::size_t visits = children[i]->n.load(memory_order_relaxed);
        if (visits > best_visits) {
            // Track the running maximum; without this update the loop would
            // simply pick the last visited child.
            best_visits = visits;
            best_idx = i;
        }
    }
    return moves[best_idx];
}
39+
40+
// Returns the UCB score of node `p`:
//     utility / visits + c * sqrt(log(parent_visits) / visits)
//
// A null `p` (child not created yet) and a node with zero visits both score
// numeric_limits<double>::max() so that they are tried first; the zero-visit
// case would otherwise divide by zero and yield NaN.
// `p`, when non-null, must have a parent.
double MCTSMultiThreadLockFreePlayer::ucb(Node* p) const {
    if (!p)
        return numeric_limits<double>::max();

    // Load each counter once so the formula uses a single consistent value.
    const auto visits = p->n.load(memory_order_relaxed);
    if (visits == 0)
        return numeric_limits<double>::max();

    // backup() bumps the child before its parent, so the parent's count can
    // momentarily lag behind; clamp it to keep the log argument >= 1.
    auto parent_visits = p->parent->n.load(memory_order_relaxed);
    if (parent_visits < visits)
        parent_visits = visits;

    const double nv = static_cast<double>(visits);
    const double exploit = p->utility.load(memory_order_relaxed) / nv;
    const double explore = c * sqrt(log(static_cast<double>(parent_visits)) / nv);
    return exploit + explore;
}
48+
49+
// Walks down the tree from `p`, at each level following the child with the
// highest ucb() score, and returns the node to simulate from:
//   - a freshly created child, when the chosen slot was still empty, or
//   - a child whose state is already over.
// `p` must be non-null and must have at least one child slot.
MCTSMultiThreadLockFreePlayer::Node* MCTSMultiThreadLockFreePlayer::select_expand(const Game& game, Node* p) const {
    assert(p);
    assert(!p->children.empty());

    // make_child() fills child slots while holding p->m, so take a snapshot of
    // the raw pointers under the same mutex instead of reading the slots
    // concurrently with a writer.
    vector<Node*> kids(p->children.size());
    {
        lock_guard lk(p->m);
        for (std::size_t i = 0; i != kids.size(); ++i)
            kids[i] = p->children[i].get();
    }

    // lowest() rather than min(): min() is the smallest *positive* double, so
    // negative scores would never beat it.
    double max_val = numeric_limits<double>::lowest();
    std::size_t max_idx = 0;
    for (std::size_t i = 0; i != kids.size(); ++i) {
        const double score = ucb(kids[i]);
        if (score > max_val) {
            max_val = score;
            max_idx = i;
            // Nothing can outrank the "unvisited" sentinel; stop early.
            if (max_val == numeric_limits<double>::max())
                break;
        }
    }

    Node* chosen = kids[max_idx];
    if (!chosen) {
        // expansion happens here, introducing an additional expand could incur
        // additional O(N) memory, where N is the number of leaf nodes
        make_child(game, p, static_cast<int>(max_idx));
        // make_child() has locked p->m and the slot is only ever written once,
        // so it is safe to read it here.
        return p->children[max_idx].get();
    }
    if (chosen->state.is_over())
        return chosen;
    return select_expand(game, chosen);
}
75+
76+
int MCTSMultiThreadLockFreePlayer::simulate(const Game& game, const State& s) const {
77+
if (s.is_over())
78+
return s.get_utility(get_piece());
79+
assert(!s.get_valid_moves().empty());
80+
auto state = game.result(s, s.get_valid_moves()[0], s.to_move());
81+
while (!state.is_over()) {
82+
state = game.result(
83+
state,
84+
state.get_valid_moves()[0],
85+
state.to_move());
86+
}
87+
return state.get_utility(get_piece());
88+
}
89+
90+
// Propagates the simulation result `v` from node `p` up to the root: every
// node on the path gets its visit count incremented and `v` added to its
// accumulated utility. `p` must be non-null.
void MCTSMultiThreadLockFreePlayer::backup(Node* p, int v) const {
    assert(p);
    for (Node* node = p; node; node = node->parent) {
        node->n.fetch_add(1, memory_order_relaxed);

        // Atomic add via a CAS loop. The new value must be `seen + v`; storing
        // plain `v` would overwrite the accumulated utility on every backup.
        // On failure compare_exchange_weak refreshes `seen` with the current
        // value, so the sum is recomputed on each retry.
        auto seen = node->utility.load(memory_order_relaxed);
        while (!node->utility.compare_exchange_weak(seen, seen + v,
                    memory_order_relaxed, memory_order_relaxed)) {
        }
    }
}
101+
void MCTSMultiThreadLockFreePlayer::make_child(const Game& game, Node* p, int idx) const {
102+
assert(p);
103+
auto moves = p->state.get_valid_moves();
104+
auto piece = p->state.to_move();
105+
auto state = game.result(p->state, moves[idx], piece);
106+
lock_guard lk(p->m);
107+
if (!p->children[idx]) {
108+
p->children[idx] = make_unique<Node>(std::move(state), p);
109+
}
110+
}

src/mcts_mt_lf_player.hpp

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#ifndef GAME_MCTS_MULTITHREAD_LOCKFREE_PLAYER_H_
2+
#define GAME_MCTS_MULTITHREAD_LOCKFREE_PLAYER_H_
3+
4+
#include <atomic>
5+
#include <memory>
6+
#include <mutex>
7+
#include <utility>
8+
#include <vector>
9+
10+
#include "state.hpp"
11+
#include "player.hpp"
12+
#include "game.hpp"
13+
14+
namespace game {
15+
class MCTSMultiThreadLockFreePlayer: public Player {
16+
public:
17+
MCTSMultiThreadLockFreePlayer(Piece p, int nn, double cc=1.4);
18+
19+
Point get_move(const Game&, const State&) const override;
20+
private:
21+
struct Node {
22+
Node(const State& s, Node* p=nullptr):
23+
state(s), parent(p),
24+
children(state.get_valid_moves().size()),
25+
n(0), utility(0) {}
26+
Node(State&& s, Node* p=nullptr):
27+
state(std::move(s)), parent(p),
28+
children(state.get_valid_moves().size()),
29+
n(0), utility(0) {}
30+
31+
State state;
32+
Node* parent;
33+
std::vector<std::unique_ptr<Node>> children;
34+
std::atomic_size_t n;
35+
std::atomic<double> utility;
36+
std::mutex m;
37+
};
38+
double ucb(Node* p) const;
39+
Node* select_expand(const Game&, Node*) const;
40+
int simulate(const Game&, const State&) const;
41+
void backup(Node*, int) const;
42+
void make_child(const Game&, Node*, int) const;
43+
int n; // number of simulations
44+
double c;
45+
};
46+
}
47+
48+
#endif
File renamed without changes.

hierarchy/mcts_multithread_player.hpp src/mcts_multithread_player.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
#include <utility>
77
#include <vector>
88

9-
#include "state.hpp"
10-
#include "player.hpp"
119
#include "game.hpp"
10+
#include "player.hpp"
11+
#include "state.hpp"
1212

1313
namespace game {
1414
class MCTSMultiThreadPlayer: public Player {
+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#include <cmath>
2+
#include <limits>
3+
#include <random>
4+
#include <future>
5+
6+
#include "../thread/thread_pool.hpp"
7+
#include "mcts_multithread_threadpool_player.hpp"
8+
9+
using namespace std;
10+
using namespace game;
11+
using namespace thread_utils;
12+
13+
14+
MCTSMultiThreadThreadPoolPlayer::MCTSMultiThreadThreadPoolPlayer(
15+
Piece p, int nn, double cc): Player(p), n(nn), c(cc) {}
16+
17+
// Chooses the move to play from `state`.
//
// Splits the `n` simulations between the worker threads of a ThreadPool and
// the calling thread; each simulation is one search iteration (select/expand,
// simulate, backup) on a tree rooted at `state`. Once all work has finished,
// returns the valid move whose root child has the highest visit count. If no
// child was visited, the first valid move is returned.
Point MCTSMultiThreadThreadPoolPlayer::get_move(
        const Game& game, const State& state) const {
    assert(!state.is_over());
    auto root = make_unique<Node>(state);

    // hardware_concurrency() may return 0 (unknown) or 1; subtracting 1 from
    // that unsigned value would wrap around or leave no worker at all, so
    // always keep at least one worker.
    const unsigned hw = std::thread::hardware_concurrency();
    const unsigned n_threads = hw > 1 ? hw - 1 : 1;
    {
        ThreadPool<void()> tp(n_threads);
        vector<future<void>> v(n_threads);
        const int sim_per_thread = n / static_cast<int>(n_threads + 1);
        auto task = [this](const Game& g, Node* tree, int count) {
            for (int i = 0; i < count; ++i) {
                auto leaf = select_expand(g, tree);
                auto result = simulate(g, leaf->state);
                backup(leaf, result);
            }
        };
        // Submit one task per worker so that every future in `v` is valid and
        // the shares add up to exactly `n` simulations.
        for (unsigned i = 0; i != n_threads; ++i) {
            v[i] = tp.submit(task, game, root.get(), sim_per_thread);
        }
        // The calling thread takes its own share plus the division remainder.
        task(game, root.get(), n - sim_per_thread * static_cast<int>(n_threads));
        for (auto& f: v)
            f.get();
    }

    const auto& moves = state.get_valid_moves();
    const auto& children = root->children;
    assert(moves.size() == children.size());

    std::size_t best_idx = 0;
    std::size_t best_visits = 0;
    for (std::size_t i = 0; i != children.size(); ++i) {
        // A slot stays null when that move was never expanded.
        if (!children[i])
            continue;
        const std::size_t visits = children[i]->n;
        if (visits > best_visits) {
            // Track the running maximum; without this update the loop would
            // simply pick the last visited child.
            best_visits = visits;
            best_idx = i;
        }
    }
    return moves[best_idx];
}
51+
52+
// Returns the UCB score of node `p`:
//     utility / visits + c * sqrt(log(parent_visits) / visits)
//
// A null `p` (child not created yet) and a node with zero visits both score
// numeric_limits<double>::max() so that they are tried first; the zero-visit
// case would otherwise divide by zero and yield NaN.
// `p`, when non-null, must have a parent.
double MCTSMultiThreadThreadPoolPlayer::ucb(Node* p) const {
    if (!p)
        return numeric_limits<double>::max();

    // backup() modifies the parent's counter under the parent's own mutex, so
    // read it under that mutex. The two locks are taken one after the other,
    // never nested.
    std::size_t parent_visits = 0;
    {
        lock_guard lk{p->parent->m};
        parent_visits = p->parent->n;
    }

    lock_guard lk{p->m};
    if (p->n == 0)
        return numeric_limits<double>::max();

    // backup() bumps the child before its parent, so the parent's count can
    // momentarily lag behind; clamp it to keep the log argument >= 1.
    if (parent_visits < p->n)
        parent_visits = p->n;

    const double visits = static_cast<double>(p->n);
    return p->utility / visits
        + c * sqrt(log(static_cast<double>(parent_visits)) / visits);
}
60+
61+
// Walks down the tree from `p`, at each level following the child with the
// highest ucb() score, and returns the node to simulate from:
//   - a freshly created child, when the chosen slot was still empty, or
//   - a child whose state is already over.
// `p` must be non-null and must have at least one child slot.
MCTSMultiThreadThreadPoolPlayer::Node* MCTSMultiThreadThreadPoolPlayer::select_expand(const Game& game, Node* p) const {
    assert(p);
    assert(!p->children.empty());

    // make_child() fills child slots while holding p->m, so take a snapshot of
    // the raw pointers under the same mutex instead of reading the slots
    // concurrently with a writer. The lock is released before ucb() is called
    // because ucb() takes node mutexes itself.
    vector<Node*> kids(p->children.size());
    {
        lock_guard lk(p->m);
        for (std::size_t i = 0; i != kids.size(); ++i)
            kids[i] = p->children[i].get();
    }

    // lowest() rather than min(): min() is the smallest *positive* double, so
    // negative scores would never beat it.
    double max_val = numeric_limits<double>::lowest();
    std::size_t max_idx = 0;
    for (std::size_t i = 0; i != kids.size(); ++i) {
        const double score = ucb(kids[i]);
        if (score > max_val) {
            max_val = score;
            max_idx = i;
            // Nothing can outrank the "unvisited" sentinel; stop early.
            if (max_val == numeric_limits<double>::max())
                break;
        }
    }

    Node* chosen = kids[max_idx];
    if (!chosen) {
        // expansion happens here, introducing an additional expand could incur
        // additional O(N) memory, where N is the number of leaf nodes
        make_child(game, p, static_cast<int>(max_idx));
        // make_child() has locked p->m and the slot is only ever written once,
        // so it is safe to read it here.
        return p->children[max_idx].get();
    }
    if (chosen->state.is_over())
        return chosen;
    return select_expand(game, chosen);
}
87+
88+
int MCTSMultiThreadThreadPoolPlayer::simulate(const Game& game, const State& s) const {
89+
if (s.is_over())
90+
return s.get_utility(get_piece());
91+
assert(!s.get_valid_moves().empty());
92+
auto state = game.result(s, s.get_valid_moves()[0], s.to_move());
93+
while (!state.is_over()) {
94+
state = game.result(
95+
state,
96+
state.get_valid_moves()[0],
97+
state.to_move());
98+
}
99+
return state.get_utility(get_piece());
100+
}
101+
102+
// Propagates the simulation result `v` from node `p` up to the root: every
// node on the path gets its visit count incremented and `v` added to its
// utility, each update made while holding that node's mutex.
// `p` must be non-null.
void MCTSMultiThreadThreadPoolPlayer::backup(Node* p, int v) const {
    assert(p);
    Node* cursor = p;
    do {
        {
            // Only one node is locked at a time; the lock is dropped before
            // moving on to the parent.
            lock_guard guard(cursor->m);
            cursor->n += 1;
            cursor->utility += v;
        }
        cursor = cursor->parent;
    } while (cursor);
}
112+
void MCTSMultiThreadThreadPoolPlayer::make_child(const Game& game, Node* p, int idx) const {
113+
assert(p);
114+
auto moves = p->state.get_valid_moves();
115+
auto piece = p->state.to_move();
116+
auto state = game.result(p->state, moves[idx], piece);
117+
lock_guard lk(p->m);
118+
if (!p->children[idx]) {
119+
p->children[idx] = make_unique<Node>(std::move(state), p);
120+
}
121+
}
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifndef GAME_MCTS_MULTITHREAD_THREADPOOL_PLAYER_H_
2+
#define GAME_MCTS_MULTITHREAD_THREADPOOL_PLAYER_H_
3+
4+
#include <memory>
5+
#include <mutex>
6+
#include <utility>
7+
#include <vector>
8+
9+
#include "game.hpp"
10+
#include "player.hpp"
11+
#include "state.hpp"
12+
13+
namespace game {
14+
class MCTSMultiThreadThreadPoolPlayer: public Player {
15+
public:
16+
MCTSMultiThreadThreadPoolPlayer(Piece p, int nn, double cc=1.4);
17+
18+
Point get_move(const Game&, const State&) const override;
19+
private:
20+
struct Node {
21+
Node(const State& s, Node* p=nullptr):
22+
state(s), parent(p),
23+
children(state.get_valid_moves().size()),
24+
n(0), utility(0) {}
25+
Node(State&& s, Node* p=nullptr):
26+
state(std::move(s)), parent(p),
27+
children(state.get_valid_moves().size()),
28+
n(0), utility(0) {}
29+
30+
State state;
31+
Node* parent;
32+
std::vector<std::unique_ptr<Node>> children;
33+
std::size_t n;
34+
double utility;
35+
std::mutex m;
36+
};
37+
double ucb(Node* p) const;
38+
Node* select_expand(const Game&, Node*) const;
39+
int simulate(const Game&, const State&) const;
40+
void backup(Node*, int) const;
41+
void make_child(const Game&, Node*, int) const;
42+
int n; // number of simulations
43+
double c;
44+
};
45+
}
46+
47+
#endif
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

hierarchy/state.cpp src/state.cpp

File renamed without changes.

hierarchy/state.hpp src/state.hpp

File renamed without changes.

hierarchy/utils.hpp src/utils.hpp

File renamed without changes.

0 commit comments

Comments
 (0)