Skip to content

Commit

Permalink
halfkav2_hm
Browse files Browse the repository at this point in the history
  • Loading branch information
Sopel97 committed Aug 9, 2021
1 parent b50c70d commit 39d379a
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 1 deletion.
3 changes: 2 additions & 1 deletion features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import halfkp
import halfka
import halfka_v2
import halfka_v2_hm

_feature_modules = [halfkp, halfka, halfka_v2]
_feature_modules = [halfkp, halfka, halfka_v2, halfka_v2_hm]

_feature_blocks_by_name = dict()

Expand Down
95 changes: 95 additions & 0 deletions halfka_v2_hm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import chess
import torch
import feature_block
from collections import OrderedDict
from feature_block import *

NUM_SQ = 64
NUM_PT_REAL = 11
NUM_PT_VIRTUAL = 12
NUM_PLANES_REAL = NUM_SQ * NUM_PT_REAL
NUM_PLANES_VIRTUAL = NUM_SQ * NUM_PT_VIRTUAL
NUM_INPUTS = NUM_PLANES_REAL * NUM_SQ // 2

KingBuckets = [
-1, -1, -1, -1, 31, 30, 29, 28,
-1, -1, -1, -1, 27, 26, 25, 24,
-1, -1, -1, -1, 23, 22, 21, 20,
-1, -1, -1, -1, 19, 18, 17, 16,
-1, -1, -1, -1, 15, 14, 13, 12,
-1, -1, -1, -1, 11, 10, 9, 8,
-1, -1, -1, -1, 7, 6, 5, 4,
-1, -1, -1, -1, 3, 2, 1, 0
]

def orient(is_white_pov: bool, sq: int, ksq: int):
# ksq must not be oriented
kfile = (ksq % 8)
return (7 * (kfile < 4)) ^ (56 * (not is_white_pov)) ^ sq

def halfka_idx(is_white_pov: bool, king_sq: int, sq: int, p: chess.Piece):
p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov)
o_ksq = orient(is_white_pov, king_sq, king_sq)
if p_idx == 11:
p_idx -= 1
return orient(is_white_pov, sq, king_sq) + p_idx * NUM_SQ + KingBuckets[o_ksq] * NUM_PLANES_REAL

def halfka_psqts():
# values copied from stockfish, in stockfish internal units
piece_values = {
chess.PAWN : 126,
chess.KNIGHT : 781,
chess.BISHOP : 825,
chess.ROOK : 1276,
chess.QUEEN : 2538
}

values = [0] * NUM_INPUTS

for ksq in range(64):
for s in range(64):
for pt, val in piece_values.items():
idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE))
idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK))
values[idxw] = val
values[idxb] = -val

return values

class Features(FeatureBlock):
def __init__(self):
super(Features, self).__init__('HalfKAv2_hm', 0x7f234cb8, OrderedDict([('HalfKAv2_hm', NUM_INPUTS)]))

def get_active_features(self, board: chess.Board):
raise Exception('Not supported yet, you must use the c++ data loader for support during training')

def get_initial_psqt_features(self):
return halfka_psqts()

class FactorizedFeatures(FeatureBlock):
def __init__(self):
super(FactorizedFeatures, self).__init__('HalfKAv2_hm^', 0x7f234cb8, OrderedDict([('HalfKAv2_hm', NUM_INPUTS), ('A', NUM_PLANES_VIRTUAL)]))

def get_active_features(self, board: chess.Board):
raise Exception('Not supported yet, you must use the c++ data loader for factorizer support during training')

def get_feature_factors(self, idx):
if idx >= self.num_real_features:
raise Exception('Feature must be real')

a_idx = idx % NUM_PLANES_REAL
k_idx = idx // NUM_PLANES_REAL

if a_idx // NUM_SQ == 10 and k_idx != KingBuckets[a_idx % NUM_SQ]:
a_idx += NUM_SQ

return [idx, self.get_factor_base_feature('A') + a_idx]

def get_initial_psqt_features(self):
return halfka_psqts() + [0] * NUM_PLANES_VIRTUAL

'''
This is used by the features module for discovery of feature blocks.
'''
def get_feature_block_clss():
return [Features, FactorizedFeatures]
103 changes: 103 additions & 0 deletions training_data_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,93 @@ struct HalfKAv2Factorized {
}
};

// ksq must not be oriented
static Square orient_flip_2(Color color, Square sq, Square ksq)
{
bool h = ksq.file() < fileE;
if (color == Color::Black)
sq = sq.flippedVertically();
if (h)
sq = sq.flippedHorizontally();
return sq;
}

struct HalfKAv2_hm {
static constexpr int NUM_SQ = 64;
static constexpr int NUM_PT = 11;
static constexpr int NUM_PLANES = NUM_SQ * NUM_PT;
static constexpr int INPUTS = NUM_PLANES * NUM_SQ / 2;

static constexpr int MAX_ACTIVE_FEATURES = 32;

static constexpr int KingBuckets[64] = {
-1, -1, -1, -1, 31, 30, 29, 28,
-1, -1, -1, -1, 27, 26, 25, 24,
-1, -1, -1, -1, 23, 22, 21, 20,
-1, -1, -1, -1, 19, 18, 17, 16,
-1, -1, -1, -1, 15, 14, 13, 12,
-1, -1, -1, -1, 11, 10, 9, 8,
-1, -1, -1, -1, 7, 6, 5, 4,
-1, -1, -1, -1, 3, 2, 1, 0
};

static int feature_index(Color color, Square ksq, Square sq, Piece p)
{
Square o_ksq = orient_flip_2(color, ksq, ksq);
auto p_idx = static_cast<int>(p.type()) * 2 + (p.color() != color);
if (p_idx == 11)
--p_idx; // pack the opposite king into the same NUM_SQ * NUM_SQ
return static_cast<int>(orient_flip_2(color, sq, ksq)) + p_idx * NUM_SQ + KingBuckets[static_cast<int>(o_ksq)] * NUM_PLANES;
}

static std::pair<int, int> fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color)
{
auto& pos = e.pos;
auto pieces = pos.piecesBB();
auto ksq = pos.kingSquare(color);

int j = 0;
for(Square sq : pieces)
{
auto p = pos.pieceAt(sq);
values[j] = 1.0f;
features[j] = feature_index(color, ksq, sq, p);
++j;
}

return { j, INPUTS };
}
};

struct HalfKAv2_hmFactorized {
// Factorized features
static constexpr int PIECE_INPUTS = HalfKAv2_hm::NUM_SQ * HalfKAv2_hm::NUM_PT;
static constexpr int INPUTS = HalfKAv2_hm::INPUTS + PIECE_INPUTS;

static constexpr int MAX_PIECE_FEATURES = 32;
static constexpr int MAX_ACTIVE_FEATURES = HalfKAv2_hm::MAX_ACTIVE_FEATURES + MAX_PIECE_FEATURES;

static std::pair<int, int> fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color)
{
const auto [start_j, offset] = HalfKAv2_hm::fill_features_sparse(e, features, values, color);
auto& pos = e.pos;
auto pieces = pos.piecesBB();
auto ksq = pos.kingSquare(color);

int j = start_j;
for(Square sq : pieces)
{
auto p = pos.pieceAt(sq);
auto p_idx = static_cast<int>(p.type()) * 2 + (p.color() != color);
values[j] = 1.0f;
features[j] = offset + (p_idx * HalfKAv2_hm::NUM_SQ) + static_cast<int>(orient_flip_2(color, sq, ksq));
++j;
}

return { j, INPUTS };
}
};

template <typename T, typename... Ts>
struct FeatureSet
{
Expand Down Expand Up @@ -797,6 +884,14 @@ extern "C" {
{
return new SparseBatch(FeatureSet<HalfKAv2Factorized>{}, entries);
}
else if (feature_set == "HalfKAv2_hm")
{
return new SparseBatch(FeatureSet<HalfKAv2_hm>{}, entries);
}
else if (feature_set == "HalfKAv2_hm^")
{
return new SparseBatch(FeatureSet<HalfKAv2_hmFactorized>{}, entries);
}
fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
return nullptr;
}
Expand Down Expand Up @@ -842,6 +937,14 @@ extern "C" {
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2Factorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
}
else if (feature_set == "HalfKAv2_hm")
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2_hm>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
}
else if (feature_set == "HalfKAv2_hm^")
{
return new FeaturedBatchStream<FeatureSet<HalfKAv2_hmFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
}
fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
return nullptr;
}
Expand Down

0 comments on commit 39d379a

Please sign in to comment.