Skip to content

Commit

Permalink
[Chess] Tidy _apply_move (#1238)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Sep 12, 2024
1 parent 34c7cb7 commit 8246434
Showing 1 changed file with 17 additions and 54 deletions.
71 changes: 17 additions & 54 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,10 @@ def _apply_move(state: GameState, a: Action) -> GameState:
state = state._replace(
board=state.board.at[removed_pawn_pos].set(jax.lax.select(is_en_passant, EMPTY, state.board[removed_pawn_pos]))
)
state = state._replace(
en_passant=jax.lax.select(
(piece == PAWN) & (jnp.abs(a.to - a.from_) == 2),
jnp.int32((a.to + a.from_) // 2),
jnp.int32(-1),
)
)
is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2)
state = state._replace(en_passant=jax.lax.select(is_en_passant, (a.to + a.from_) // 2, -1))
# update counters
captured = (state.board[a.to] < 0) | (is_en_passant)
captured = (state.board[a.to] < 0) | is_en_passant
state = state._replace(
halfmove_count=jax.lax.select(captured | (piece == PAWN), 0, state.halfmove_count + 1),
fullmove_count=state.fullmove_count + jnp.int32(state.turn == 1),
Expand All @@ -313,67 +308,35 @@ def _apply_move(state: GameState, a: Action) -> GameState:
# Whether castling is possible or not is not checked here.
# We assume that if castling is not possible, it is filtered out.
# left
state = state._replace(
board=jax.lax.cond(
(piece == KING) & (a.from_ == 32) & (a.to == 16),
lambda: state.board.at[0].set(EMPTY).at[24].set(ROOK),
lambda: state.board,
),
)
board = state.board
is_queen_side_castling = (piece == KING) & (a.from_ == 32) & (a.to == 16)
board = jax.lax.select(is_queen_side_castling, board.at[0].set(EMPTY).at[24].set(ROOK), board)
# right
state = state._replace(
board=jax.lax.cond(
(piece == KING) & (a.from_ == 32) & (a.to == 48),
lambda: state.board.at[56].set(EMPTY).at[40].set(ROOK),
lambda: state.board,
),
)
# update my can_castle_xxx_side
is_king_side_castling = (piece == KING) & (a.from_ == 32) & (a.to == 48)
board = jax.lax.select(is_king_side_castling, board.at[56].set(EMPTY).at[40].set(ROOK), board)
state = state._replace(board=board)
# update castling rights
state = state._replace(
can_castle_queen_side=state.can_castle_queen_side.at[0].set(
jax.lax.select(
(a.from_ == 32) | (a.from_ == 0),
FALSE,
state.can_castle_queen_side[0],
)
jax.lax.select((a.from_ == 32) | (a.from_ == 0), FALSE, state.can_castle_queen_side[0])
),
can_castle_king_side=state.can_castle_king_side.at[0].set(
jax.lax.select(
(a.from_ == 32) | (a.from_ == 56),
FALSE,
state.can_castle_king_side[0],
)
jax.lax.select((a.from_ == 32) | (a.from_ == 56), FALSE, state.can_castle_king_side[0])
),
)
# update opp can_castle_xxx_side
# update opp castling rights
state = state._replace(
can_castle_queen_side=state.can_castle_queen_side.at[1].set(
jax.lax.select(
(a.to == 7),
FALSE,
state.can_castle_queen_side[1],
)
jax.lax.select((a.to == 7), FALSE, state.can_castle_queen_side[1])
),
can_castle_king_side=state.can_castle_king_side.at[1].set(
jax.lax.select(
(a.to == 63),
FALSE,
state.can_castle_king_side[1],
)
jax.lax.select((a.to == 63), FALSE, state.can_castle_king_side[1])
),
)
# promotion to queen
piece = jax.lax.select(
piece == PAWN & (a.from_ % 8 == 6) & (a.underpromotion < 0),
QUEEN,
piece,
)
piece = jax.lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, piece)
# underpromotion
piece = jax.lax.select(
a.underpromotion < 0,
piece,
jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion],
)
piece = jax.lax.select(a.underpromotion < 0, piece, jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion])
# actually move
state = state._replace(board=state.board.at[a.from_].set(EMPTY).at[a.to].set(piece)) # type: ignore
return state
Expand Down

0 comments on commit 8246434

Please sign in to comment.