Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chess] Simplify castling attributes #1245

Merged
merged 7 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 14 additions & 37 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,13 @@

class GameState(NamedTuple):
turn: Array = jnp.int32(0)
board: Array = INIT_BOARD # From top left. like FEN
# (curr, opp) Flips every turn
can_castle_queen_side: Array = jnp.ones(2, dtype=jnp.bool_)
can_castle_king_side: Array = jnp.ones(2, dtype=jnp.bool_)
en_passant: Array = jnp.int32(-1) # En passant target. Flips.
# # of moves since the last piece capture or pawn move
halfmove_count: Array = jnp.int32(0)
board: Array = INIT_BOARD
castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king
en_passant: Array = jnp.int32(-1) # En passant target.
halfmove_count: Array = jnp.int32(0) # # of moves since the last piece capture or pawn move
fullmove_count: Array = jnp.int32(1) # increase every black move
# zobrist_hash: Array = INIT_ZOBRIST_HASH
hash_history: Array = jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH)
board_history: Array = jnp.zeros((8, 64), dtype=jnp.int32).at[0, :].set(INIT_BOARD)
# index to possible piece positions for speeding up. Flips every turn.
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK
step_count: Array = jnp.int32(0)

Expand Down Expand Up @@ -221,10 +216,7 @@ def piece_feat(p):
jax.vmap(make)(jnp.arange(8)).reshape(-1, 8, 8), # board feature
color * ones, # color
(state.step_count / MAX_TERMINATION_STEPS) * ones, # total move count
state.can_castle_queen_side[0] * ones, # my castling right (queen side)
state.can_castle_king_side[0] * ones, # my castling right (king side)
state.can_castle_queen_side[1] * ones, # opp castling right (queen side)
state.can_castle_king_side[1] * ones, # opp castling right (king side)
state.castling_rights.flatten()[:, None, None] * ones, # (my queen, my king, opp queen, opp king)
(state.halfmove_count.astype(jnp.float32) / 100.0) * ones, # no progress count
]
).transpose((1, 2, 0))
Expand Down Expand Up @@ -310,22 +302,8 @@ def _apply_move(state: GameState, a: Action) -> GameState:
board = jax.lax.select(is_king_side_castling, board.at[56].set(EMPTY).at[40].set(ROOK), board)
state = state._replace(board=board)
# update castling rights
state = state._replace(
can_castle_queen_side=state.can_castle_queen_side.at[0].set(
jax.lax.select((a.from_ == 32) | (a.from_ == 0), False, state.can_castle_queen_side[0])
),
can_castle_king_side=state.can_castle_king_side.at[0].set(
jax.lax.select((a.from_ == 32) | (a.from_ == 56), False, state.can_castle_king_side[0])
),
)
state = state._replace(
can_castle_queen_side=state.can_castle_queen_side.at[1].set(
jax.lax.select((a.to == 7), False, state.can_castle_queen_side[1])
),
can_castle_king_side=state.can_castle_king_side.at[1].set(
jax.lax.select((a.to == 63), False, state.can_castle_king_side[1])
),
)
cond = jnp.bool_([[(a.from_ != 32) & (a.from_ != 0), (a.from_ != 32) & (a.from_ != 56)], [a.to != 7, a.to != 63]])
state = state._replace(castling_rights=state.castling_rights & cond)
# promotion to queen
piece = jax.lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, piece)
# underpromotion
Expand All @@ -345,8 +323,7 @@ def _flip(state: GameState) -> GameState:
board=-jnp.flip(state.board.reshape(8, 8), axis=1).flatten(),
turn=(state.turn + 1) % 2,
en_passant=_flip_pos(state.en_passant),
can_castle_queen_side=state.can_castle_queen_side[::-1],
can_castle_king_side=state.can_castle_king_side[::-1],
castling_rights=state.castling_rights[::-1],
board_history=-jnp.flip(state.board_history.reshape(-1, 8, 8), axis=-1).reshape(-1, 64),
)

Expand Down Expand Up @@ -402,9 +379,9 @@ def is_not_checked(label):

# castling
b = state.board
can_castle_queen_side = state.can_castle_queen_side[0]
can_castle_queen_side = state.castling_rights[0, 0]
can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING)
can_castle_king_side = state.can_castle_king_side[0]
can_castle_king_side = state.castling_rights[0, 1]
can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK)
not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48]))
mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all()))
Expand Down Expand Up @@ -447,9 +424,9 @@ def _zobrist_hash(state: GameState) -> Array:
to_reduce = ZOBRIST_BOARD[jnp.arange(64), board + 6]
hash_ ^= jax.lax.reduce(to_reduce, 0, jax.lax.bitwise_xor, (0,))
zero = jnp.uint32([0, 0])
hash_ ^= jax.lax.select(state.can_castle_queen_side[0], ZOBRIST_CASTLING_QUEEN[0], zero)
hash_ ^= jax.lax.select(state.can_castle_queen_side[1], ZOBRIST_CASTLING_QUEEN[1], zero)
hash_ ^= jax.lax.select(state.can_castle_king_side[0], ZOBRIST_CASTLING_KING[0], zero)
hash_ ^= jax.lax.select(state.can_castle_king_side[1], ZOBRIST_CASTLING_KING[1], zero)
hash_ ^= jax.lax.select(state.castling_rights[0, 0], ZOBRIST_CASTLING_QUEEN[0], zero)
hash_ ^= jax.lax.select(state.castling_rights[0, 1], ZOBRIST_CASTLING_QUEEN[1], zero)
hash_ ^= jax.lax.select(state.castling_rights[1, 0], ZOBRIST_CASTLING_KING[0], zero)
hash_ ^= jax.lax.select(state.castling_rights[1, 1], ZOBRIST_CASTLING_KING[1], zero)
hash_ ^= ZOBRIST_EN_PASSANT[state.en_passant]
return hash_
33 changes: 10 additions & 23 deletions pgx/experimental/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,9 @@ def from_fen(fen: str):
if str.islower(c):
ix *= -1
arr.append(ix)
can_castle_queen_side = jnp.zeros(2, dtype=jnp.bool_)
can_castle_king_side = jnp.zeros(2, dtype=jnp.bool_)
if "Q" in castling:
can_castle_queen_side = can_castle_queen_side.at[0].set(TRUE)
if "q" in castling:
can_castle_queen_side = can_castle_queen_side.at[1].set(TRUE)
if "K" in castling:
can_castle_king_side = can_castle_king_side.at[0].set(TRUE)
if "k" in castling:
can_castle_king_side = can_castle_king_side.at[1].set(TRUE)
castling_rights = jnp.bool_([["Q" in castling, "K" in castling], ["q" in castling, "k" in castling]])
if turn == "b":
can_castle_queen_side = can_castle_queen_side[::-1]
can_castle_king_side = can_castle_king_side[::-1]
castling_rights = castling_rights[::-1]
mat = jnp.int32(arr).reshape(8, 8)
if turn == "b":
mat = -jnp.flip(mat, axis=0)
Expand All @@ -71,8 +61,7 @@ def from_fen(fen: str):
x = GameState(
board=jnp.rot90(mat, k=3).flatten(),
turn=jnp.int32(0) if turn == "w" else jnp.int32(1),
can_castle_queen_side=can_castle_queen_side,
can_castle_king_side=can_castle_king_side,
castling_rights=castling_rights,
en_passant=ep,
halfmove_count=jnp.int32(halfmove_cnt),
fullmove_count=jnp.int32(fullmove_cnt),
Expand Down Expand Up @@ -142,21 +131,19 @@ def to_fen(state: State):
# turn
fen += "w " if state._x.turn == 0 else "b "
# castling
can_castle_queen_side = state._x.can_castle_queen_side
can_castle_king_side = state._x.can_castle_king_side
castling_rights = state._x.castling_rights
if state._x.turn == 1:
can_castle_queen_side = can_castle_queen_side[::-1]
can_castle_king_side = can_castle_king_side[::-1]
if not (np.any(can_castle_queen_side) | np.any(can_castle_king_side)):
castling_rights = castling_rights[::-1]
if not (np.any(castling_rights)):
fen += "-"
else:
if can_castle_king_side[0]:
if castling_rights[0, 1]:
fen += "K"
if can_castle_queen_side[0]:
if castling_rights[0, 0]:
fen += "Q"
if can_castle_king_side[1]:
if castling_rights[1, 1]:
fen += "k"
if can_castle_queen_side[1]:
if castling_rights[1, 0]:
fen += "q"
fen += " "
# en passant
Expand Down
Loading