Skip to content

Commit

Permalink
[Chess] Tidy (#1247)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Sep 13, 2024
1 parent 624f645 commit c7274bd
Showing 1 changed file with 41 additions and 52 deletions.
93 changes: 41 additions & 52 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from jax import Array

# fmt: off
EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # * -1 for opponent
EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # opponent: -1 * piece

# board index:
INIT_BOARD = jnp.int32([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4])
# 8 7 15 23 31 39 47 55 63
# 7 6 14 22 30 38 46 54 62
# 6 5 13 21 29 37 45 53 61
Expand All @@ -32,30 +32,28 @@
# 2 1 9 17 25 33 41 49 57
# 1 0 8 16 24 32 40 48 56
# a b c d e f g h
INIT_BOARD = jnp.int32([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4])

# Action
# We use AlphaZero style label with channel-last representation: (8, 8, 73)
# Action: AlphaZero style label (4672 = 64 x 73)
# 73 = underpromotions (3 * 3) + queen moves (56) + knight moves (8)
# 0 ~ 8: underpromotions
# * 0 ~ 8: underpromotions
# plane // 3 == 0: rook, 1: bishop, 2: knight
# plane % 3 == 0: up , 1: right, 2: left
# 9 ~ 72: normal moves (queen + knight)
# 51 22 50
# 52 21 49
# 53 20 48
# 54 19 47
# 55 18 46
# 56 17 45
# 57 16 44
# 23 24 25 26 27 28 29 X 30 31 32 33 34 35 36
# 43 15 58
# 42 14 59
# 41 13 60
# 40 12 61
# 39 11 62
# 38 10 64
# 37 9 64
# * 9 ~ 72: normal moves (queen + knight)
# 51 22 50
# 52 21 49
# 53 20 48
# 54 19 47
# 55 18 46
# 56 17 45
# 57 16 44
# 23 24 25 26 27 28 29 X 30 31 32 33 34 35 36
# 43 15 58
# 42 14 59
# 41 13 60
# 40 12 61
# 39 11 62
# 38 10 64
# 37 9 64
FROM_PLANE = -np.ones((64, 73), dtype=np.int32)
TO_PLANE = -np.ones((64, 64), dtype=np.int32) # ignores underpromotion
zeros, seq, rseq = [0] * 7, list(range(1, 8)), list(range(-7, 0))
Expand Down Expand Up @@ -135,16 +133,16 @@
ZOBRIST_EN_PASSANT = jax.random.randint(keys[3], shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
INIT_ZOBRIST_HASH = jnp.uint32([1455170221, 1478960862])

MAX_TERMINATION_STEPS = 512 # from AZ paper
MAX_TERMINATION_STEPS = 512 # from AlphaZero paper
# fmt: on


class GameState(NamedTuple):
turn: Array = jnp.int32(0)
board: Array = INIT_BOARD
castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king
en_passant: Array = jnp.int32(-1) # En passant target.
halfmove_count: Array = jnp.int32(0) # # of moves since the last piece capture or pawn move
en_passant: Array = jnp.int32(-1)
halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move
fullmove_count: Array = jnp.int32(1) # increase every black move
hash_history: Array = jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH)
board_history: Array = jnp.zeros((8, 64), dtype=jnp.int32).at[0, :].set(INIT_BOARD)
Expand Down Expand Up @@ -233,33 +231,27 @@ def rewards(self, state: GameState) -> Array:


def _update_history(state: GameState):
# board history
board_history = jnp.roll(state.board_history, 64)
board_history = board_history.at[0].set(state.board)
state = state._replace(board_history=board_history)
# hash hist
hash_hist = jnp.roll(state.hash_history, 2)
hash_hist = hash_hist.at[0].set(_zobrist_hash(state))
state = state._replace(hash_history=hash_hist)
return state
return state._replace(board_history=board_history, hash_history=hash_hist)


def has_insufficient_pieces(state: GameState):
# Uses the same condition as OpenSpiel.
# See https://github.com/deepmind/open_spiel/blob/master/open_spiel/games/chess/chess_board.cc#L724
# uses the same condition as OpenSpiel
num_pieces = (state.board != EMPTY).sum()
num_pawn_rook_queen = ((jnp.abs(state.board) >= ROOK) | (jnp.abs(state.board) == PAWN)).sum() - 2 # two kings
num_bishop = (jnp.abs(state.board) == 3).sum()
num_bishop = (jnp.abs(state.board) == BISHOP).sum()
coords = jnp.arange(64).reshape((8, 8))
# [ 0 2 4 6 16 18 20 22 32 34 36 38 48 50 52 54 9 11 13 15 25 27 29 31 41 43 45 47 57 59 61 63]
black_coords = jnp.hstack((coords[::2, ::2].ravel(), coords[1::2, 1::2].ravel()))
num_bishop_on_black = (jnp.abs(state.board[black_coords]) == BISHOP).sum()
is_insufficient = False
# King vs King
# king vs king
is_insufficient |= num_pieces <= 2
# King + X vs King. X == KNIGHT or BISHOP
# king vs king + (knight or bishop)
is_insufficient |= (num_pieces == 3) & (num_pawn_rook_queen == 0)
# King + Bishop* vs King + Bishop* (Bishops are on same color tile)
# king + bishop* vs king + bishop* (bishops are on same color tile)
is_bishop_all_on_black = num_bishop_on_black == num_bishop
is_bishop_all_on_white = num_bishop_on_black == 0
is_insufficient |= (num_pieces == num_bishop + 2) & (is_bishop_all_on_black | is_bishop_all_on_white)
Expand All @@ -268,7 +260,6 @@ def has_insufficient_pieces(state: GameState):


def _apply_move(state: GameState, a: Action) -> GameState:
# apply move action
piece = state.board[a.from_]
# en passant
is_en_passant = (state.en_passant >= 0) & (piece == PAWN) & (state.en_passant == a.to)
Expand Down Expand Up @@ -303,8 +294,7 @@ def _apply_move(state: GameState, a: Action) -> GameState:
return state


def _flip_pos(x):
# e.g., 37 <-> 34, -1 <-> -1
def _flip_pos(x): # e.g., 37 <-> 34, -1 <-> -1
return jax.lax.select(x == -1, x, (x // 8) * 8 + (7 - (x % 8)))


Expand Down Expand Up @@ -333,17 +323,6 @@ def legal_label(to):

return jax.vmap(legal_label)(LEGAL_DEST[piece, from_])

def legal_underpromotions(mask):
def legal_labels(label):
a = Action._from_label(label)
ok = (state.board[a.from_] == PAWN) & (a.to >= 0)
ok &= mask[Action(from_=a.from_, to=a.to)._to_label()]
return jax.lax.select(ok, label, -1)

# from_ = 6 14 ... 62, plane = 0 1 ... 8
labels = jnp.int32([from_ * 73 + i for i in range(9) for from_ in [6, 14, 22, 30, 38, 46, 54, 62]])
return jax.vmap(legal_labels)(labels)

def legal_en_passants():
to = state.en_passant

Expand All @@ -358,6 +337,16 @@ def is_not_checked(label):
a = Action._from_label(label)
return ~_is_checked(_apply_move(state, a))

def legal_underpromotions(mask):
def legal_labels(label):
a = Action._from_label(label)
ok = (state.board[a.from_] == PAWN) & (a.to >= 0)
ok &= mask[Action(from_=a.from_, to=a.to)._to_label()]
return jax.lax.select(ok, label, -1)

labels = jnp.int32([from_ * 73 + i for i in range(9) for from_ in [6, 14, 22, 30, 38, 46, 54, 62]])
return jax.vmap(legal_labels)(labels)

# normal move and en passant
possible_piece_positions = jnp.nonzero(state.board > 0, size=16, fill_value=-1)[0]
a1 = jax.vmap(legal_normal_moves)(possible_piece_positions).flatten()
Expand Down Expand Up @@ -390,7 +379,7 @@ def can_move(to):
piece = jnp.abs(state.board[to])
between_ixs = BETWEEN[pos, to]
ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all()
ok &= ~((piece == PAWN) & (to // 8 == pos // 8)) # should move diagnally to capture the king
ok &= ~((piece == PAWN) & (to // 8 == pos // 8)) # should move diagonally to capture
return ok

return jax.vmap(can_move)(LEGAL_DEST_ANY[pos, :]).any()
Expand Down

0 comments on commit c7274bd

Please sign in to comment.