Skip to content

Commit

Permalink
HalfKAv2
Browse files Browse the repository at this point in the history
  • Loading branch information
Sopel97 committed May 18, 2021
1 parent 6465d35 commit 11b755f
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 1 deletion.
3 changes: 2 additions & 1 deletion features.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
'''
import halfkp
import halfka
import halfka_v2

_feature_modules = [halfkp, halfka]
_feature_modules = [halfkp, halfka, halfka_v2]

_feature_blocks_by_name = dict()

Expand Down
86 changes: 86 additions & 0 deletions halfka_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import chess
import torch
import feature_block
from collections import OrderedDict
from feature_block import *

NUM_SQ = 64
NUM_PT_REAL = 11
NUM_PT_VIRTUAL = 12
NUM_PLANES_REAL = NUM_SQ * NUM_PT_REAL
NUM_PLANES_VIRTUAL = NUM_SQ * NUM_PT_VIRTUAL
NUM_INPUTS = NUM_PLANES_REAL * NUM_SQ

def orient(is_white_pov: bool, sq: int):
return (56 * (not is_white_pov)) ^ sq

def halfka_idx(is_white_pov: bool, king_sq: int, sq: int, p: chess.Piece):
p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov)
if p_idx == 11:
p_idx -= 1
return orient(is_white_pov, sq) + p_idx * NUM_SQ + king_sq * NUM_PLANES_REAL

def halfka_psqts():
# values copied from stockfish, in stockfish internal units
piece_values = {
chess.PAWN : 126,
chess.KNIGHT : 781,
chess.BISHOP : 825,
chess.ROOK : 1276,
chess.QUEEN : 2538
}

values = [0] * (NUM_PLANES_REAL * NUM_SQ)

for ksq in range(64):
for s in range(64):
for pt, val in piece_values.items():
idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE))
idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK))
values[idxw] = val
values[idxb] = -val

return values

class Features(FeatureBlock):
def __init__(self):
super(Features, self).__init__('HalfKAv2', 0x5f234cb8, OrderedDict([('HalfKAv2', NUM_PLANES_REAL * NUM_SQ)]))

def get_active_features(self, board: chess.Board):
def piece_features(turn):
indices = torch.zeros(NUM_PLANES_REAL * NUM_SQ)
for sq, p in board.piece_map().items():
indices[halfka_idx(turn, orient(turn, board.king(turn)), sq, p)] = 1.0
return indices
return (piece_features(chess.WHITE), piece_features(chess.BLACK))

def get_initial_psqt_features(self):
return halfka_psqts()

class FactorizedFeatures(FeatureBlock):
def __init__(self):
super(FactorizedFeatures, self).__init__('HalfKAv2^', 0x5f234cb8, OrderedDict([('HalfKAv2', NUM_PLANES_REAL * NUM_SQ), ('A', NUM_PLANES_VIRTUAL)]))

def get_active_features(self, board: chess.Board):
raise Exception('Not supported yet, you must use the c++ data loader for factorizer support during training')

def get_feature_factors(self, idx):
if idx >= self.num_real_features:
raise Exception('Feature must be real')

a_idx = idx % NUM_PLANES_REAL
k_idx = idx // NUM_PLANES_REAL

if a_idx // NUM_SQ == 10 and k_idx != a_idx % NUM_SQ:
a_idx += NUM_SQ

return [idx, self.get_factor_base_feature('A') + a_idx]

def get_initial_psqt_features(self):
return halfka_psqts() + [0] * NUM_PLANES_VIRTUAL

'''
This is used by the features module for discovery of feature blocks.
'''
def get_feature_block_clss():
return [Features, FactorizedFeatures]
79 changes: 79 additions & 0 deletions training_data_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,69 @@ struct HalfKAFactorized {
}
};

struct HalfKAv2 {
static constexpr int NUM_SQ = 64;
static constexpr int NUM_PT = 11;
static constexpr int NUM_PLANES = NUM_SQ * NUM_PT;
static constexpr int INPUTS = NUM_PLANES * NUM_SQ;

static constexpr int MAX_ACTIVE_FEATURES = 32;

static int feature_index(Color color, Square ksq, Square sq, Piece p)
{
auto p_idx = static_cast<int>(p.type()) * 2 + (p.color() != color);
if (p_idx == 11)
--p_idx; // pack the opposite king into the same NUM_SQ * NUM_SQ
return static_cast<int>(orient_flip(color, sq)) + p_idx * NUM_SQ + static_cast<int>(ksq) * NUM_PLANES;
}

static std::pair<int, int> fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color)
{
auto& pos = e.pos;
auto pieces = pos.piecesBB();
auto ksq = pos.kingSquare(color);

int j = 0;
for(Square sq : pieces)
{
auto p = pos.pieceAt(sq);
values[j] = 1.0f;
features[j] = feature_index(color, orient_flip(color, ksq), sq, p);
++j;
}

return { j, INPUTS };
}
};

struct HalfKAv2Factorized {
// Factorized features
static constexpr int PIECE_INPUTS = HalfKAv2::NUM_SQ * HalfKAv2::NUM_PT;
static constexpr int INPUTS = HalfKAv2::INPUTS + PIECE_INPUTS;

static constexpr int MAX_PIECE_FEATURES = 32;
static constexpr int MAX_ACTIVE_FEATURES = HalfKAv2::MAX_ACTIVE_FEATURES + MAX_PIECE_FEATURES;

static std::pair<int, int> fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color)
{
const auto [start_j, offset] = HalfKAv2::fill_features_sparse(e, features, values, color);
auto& pos = e.pos;
auto pieces = pos.piecesBB();

int j = start_j;
for(Square sq : pieces)
{
auto p = pos.pieceAt(sq);
auto p_idx = static_cast<int>(p.type()) * 2 + (p.color() != color);
values[j] = 1.0f;
features[j] = offset + (p_idx * HalfKAv2::NUM_SQ) + static_cast<int>(orient_flip(color, sq));
++j;
}

return { j, INPUTS };
}
};

template <typename T, typename... Ts>
struct FeatureSet
{
Expand Down Expand Up @@ -512,6 +575,14 @@ extern "C" {
{
return new SparseBatch(FeatureSet<HalfKAFactorized>{}, entries);
}
else if (feature_set == "HalfKAv2")
{
return new SparseBatch(FeatureSet<HalfKAv2>{}, entries);
}
else if (feature_set == "HalfKAv2^")
{
return new SparseBatch(FeatureSet<HalfKAv2Factorized>{}, entries);
}
fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
return nullptr;
}
Expand Down Expand Up @@ -559,6 +630,14 @@ extern "C" {
{
return new FeaturedBatchStream<FeatureSet<HalfKAFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
}
else if (feature_set == "HalfKAv2")
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
}
else if (feature_set == "HalfKAv2^")
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2Factorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
}
fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
return nullptr;
}
Expand Down

0 comments on commit 11b755f

Please sign in to comment.