Skip to content

Commit

Permalink
[Chess] Simplify castling attributes (#1245)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Sep 13, 2024
1 parent c54171f commit 753950b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 60 deletions.
51 changes: 14 additions & 37 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,13 @@

class GameState(NamedTuple):
turn: Array = jnp.int32(0)
board: Array = INIT_BOARD # From top left. like FEN
# (curr, opp) Flips every turn
can_castle_queen_side: Array = jnp.ones(2, dtype=jnp.bool_)
can_castle_king_side: Array = jnp.ones(2, dtype=jnp.bool_)
en_passant: Array = jnp.int32(-1) # En passant target. Flips.
# # of moves since the last piece capture or pawn move
halfmove_count: Array = jnp.int32(0)
board: Array = INIT_BOARD
castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king
en_passant: Array = jnp.int32(-1) # En passant target.
halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move
fullmove_count: Array = jnp.int32(1) # increase every black move
# zobrist_hash: Array = INIT_ZOBRIST_HASH
hash_history: Array = jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH)
board_history: Array = jnp.zeros((8, 64), dtype=jnp.int32).at[0, :].set(INIT_BOARD)
# index to possible piece positions for speeding up. Flips every turn.
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK
step_count: Array = jnp.int32(0)

Expand Down Expand Up @@ -221,10 +216,7 @@ def piece_feat(p):
jax.vmap(make)(jnp.arange(8)).reshape(-1, 8, 8), # board feature
color * ones, # color
(state.step_count / MAX_TERMINATION_STEPS) * ones, # total move count
state.can_castle_queen_side[0] * ones, # my castling right (queen side)
state.can_castle_king_side[0] * ones, # my castling right (king side)
state.can_castle_queen_side[1] * ones, # opp castling right (queen side)
state.can_castle_king_side[1] * ones, # opp castling right (king side)
state.castling_rights.flatten()[:, None, None] * ones, # (my queen, my king, opp queen, opp king)
(state.halfmove_count.astype(jnp.float32) / 100.0) * ones, # no progress count
]
).transpose((1, 2, 0))
Expand Down Expand Up @@ -310,22 +302,8 @@ def _apply_move(state: GameState, a: Action) -> GameState:
board = jax.lax.select(is_king_side_castling, board.at[56].set(EMPTY).at[40].set(ROOK), board)
state = state._replace(board=board)
# update castling rights
state = state._replace(
can_castle_queen_side=state.can_castle_queen_side.at[0].set(
jax.lax.select((a.from_ == 32) | (a.from_ == 0), False, state.can_castle_queen_side[0])
),
can_castle_king_side=state.can_castle_king_side.at[0].set(
jax.lax.select((a.from_ == 32) | (a.from_ == 56), False, state.can_castle_king_side[0])
),
)
state = state._replace(
can_castle_queen_side=state.can_castle_queen_side.at[1].set(
jax.lax.select((a.to == 7), False, state.can_castle_queen_side[1])
),
can_castle_king_side=state.can_castle_king_side.at[1].set(
jax.lax.select((a.to == 63), False, state.can_castle_king_side[1])
),
)
cond = jnp.bool_([[(a.from_ != 32) & (a.from_ != 0), (a.from_ != 32) & (a.from_ != 56)], [a.to != 7, a.to != 63]])
state = state._replace(castling_rights=state.castling_rights & cond)
# promotion to queen
piece = jax.lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, piece)
# underpromotion
Expand All @@ -345,8 +323,7 @@ def _flip(state: GameState) -> GameState:
board=-jnp.flip(state.board.reshape(8, 8), axis=1).flatten(),
turn=(state.turn + 1) % 2,
en_passant=_flip_pos(state.en_passant),
can_castle_queen_side=state.can_castle_queen_side[::-1],
can_castle_king_side=state.can_castle_king_side[::-1],
castling_rights=state.castling_rights[::-1],
board_history=-jnp.flip(state.board_history.reshape(-1, 8, 8), axis=-1).reshape(-1, 64),
)

Expand Down Expand Up @@ -402,9 +379,9 @@ def is_not_checked(label):

# castling
b = state.board
can_castle_queen_side = state.can_castle_queen_side[0]
can_castle_queen_side = state.castling_rights[0, 0]
can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING)
can_castle_king_side = state.can_castle_king_side[0]
can_castle_king_side = state.castling_rights[0, 1]
can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK)
not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48]))
mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all()))
Expand Down Expand Up @@ -447,9 +424,9 @@ def _zobrist_hash(state: GameState) -> Array:
to_reduce = ZOBRIST_BOARD[jnp.arange(64), board + 6]
hash_ ^= jax.lax.reduce(to_reduce, 0, jax.lax.bitwise_xor, (0,))
zero = jnp.uint32([0, 0])
hash_ ^= jax.lax.select(state.can_castle_queen_side[0], ZOBRIST_CASTLING_QUEEN[0], zero)
hash_ ^= jax.lax.select(state.can_castle_queen_side[1], ZOBRIST_CASTLING_QUEEN[1], zero)
hash_ ^= jax.lax.select(state.can_castle_king_side[0], ZOBRIST_CASTLING_KING[0], zero)
hash_ ^= jax.lax.select(state.can_castle_king_side[1], ZOBRIST_CASTLING_KING[1], zero)
hash_ ^= jax.lax.select(state.castling_rights[0, 0], ZOBRIST_CASTLING_QUEEN[0], zero)
hash_ ^= jax.lax.select(state.castling_rights[0, 1], ZOBRIST_CASTLING_QUEEN[1], zero)
hash_ ^= jax.lax.select(state.castling_rights[1, 0], ZOBRIST_CASTLING_KING[0], zero)
hash_ ^= jax.lax.select(state.castling_rights[1, 1], ZOBRIST_CASTLING_KING[1], zero)
hash_ ^= ZOBRIST_EN_PASSANT[state.en_passant]
return hash_
33 changes: 10 additions & 23 deletions pgx/experimental/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,9 @@ def from_fen(fen: str):
if str.islower(c):
ix *= -1
arr.append(ix)
can_castle_queen_side = jnp.zeros(2, dtype=jnp.bool_)
can_castle_king_side = jnp.zeros(2, dtype=jnp.bool_)
if "Q" in castling:
can_castle_queen_side = can_castle_queen_side.at[0].set(TRUE)
if "q" in castling:
can_castle_queen_side = can_castle_queen_side.at[1].set(TRUE)
if "K" in castling:
can_castle_king_side = can_castle_king_side.at[0].set(TRUE)
if "k" in castling:
can_castle_king_side = can_castle_king_side.at[1].set(TRUE)
castling_rights = jnp.bool_([["Q" in castling, "K" in castling], ["q" in castling, "k" in castling]])
if turn == "b":
can_castle_queen_side = can_castle_queen_side[::-1]
can_castle_king_side = can_castle_king_side[::-1]
castling_rights = castling_rights[::-1]
mat = jnp.int32(arr).reshape(8, 8)
if turn == "b":
mat = -jnp.flip(mat, axis=0)
Expand All @@ -71,8 +61,7 @@ def from_fen(fen: str):
x = GameState(
board=jnp.rot90(mat, k=3).flatten(),
turn=jnp.int32(0) if turn == "w" else jnp.int32(1),
can_castle_queen_side=can_castle_queen_side,
can_castle_king_side=can_castle_king_side,
castling_rights=castling_rights,
en_passant=ep,
halfmove_count=jnp.int32(halfmove_cnt),
fullmove_count=jnp.int32(fullmove_cnt),
Expand Down Expand Up @@ -142,21 +131,19 @@ def to_fen(state: State):
# turn
fen += "w " if state._x.turn == 0 else "b "
# castling
can_castle_queen_side = state._x.can_castle_queen_side
can_castle_king_side = state._x.can_castle_king_side
castling_rights = state._x.castling_rights
if state._x.turn == 1:
can_castle_queen_side = can_castle_queen_side[::-1]
can_castle_king_side = can_castle_king_side[::-1]
if not (np.any(can_castle_queen_side) | np.any(can_castle_king_side)):
castling_rights = castling_rights[::-1]
if not (np.any(castling_rights)):
fen += "-"
else:
if can_castle_king_side[0]:
if castling_rights[0, 1]:
fen += "K"
if can_castle_queen_side[0]:
if castling_rights[0, 0]:
fen += "Q"
if can_castle_king_side[1]:
if castling_rights[1, 1]:
fen += "k"
if can_castle_queen_side[1]:
if castling_rights[1, 0]:
fen += "q"
fen += " "
# en passant
Expand Down

0 comments on commit 753950b

Please sign in to comment.