diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index e86ae79c2..6a8c7a24b 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -296,15 +296,10 @@ def _apply_move(state: GameState, a: Action) -> GameState: state = state._replace( board=state.board.at[removed_pawn_pos].set(jax.lax.select(is_en_passant, EMPTY, state.board[removed_pawn_pos])) ) - state = state._replace( - en_passant=jax.lax.select( - (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2), - jnp.int32((a.to + a.from_) // 2), - jnp.int32(-1), - ) - ) + is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2) + state = state._replace(en_passant=jax.lax.select(is_en_passant, (a.to + a.from_) // 2, -1)) # update counters - captured = (state.board[a.to] < 0) | (is_en_passant) + captured = (state.board[a.to] < 0) | is_en_passant state = state._replace( halfmove_count=jax.lax.select(captured | (piece == PAWN), 0, state.halfmove_count + 1), fullmove_count=state.fullmove_count + jnp.int32(state.turn == 1), @@ -313,67 +308,35 @@ def _apply_move(state: GameState, a: Action) -> GameState: # Whether castling is possible or not is not checked here. # We assume that if castling is not possible, it is filtered out. # left - state = state._replace( - board=jax.lax.cond( - (piece == KING) & (a.from_ == 32) & (a.to == 16), - lambda: state.board.at[0].set(EMPTY).at[24].set(ROOK), - lambda: state.board, - ), - ) + board = state.board + is_queen_side_castling = (piece == KING) & (a.from_ == 32) & (a.to == 16) + board = jax.lax.select(is_queen_side_castling, board.at[0].set(EMPTY).at[24].set(ROOK), board) # right - state = state._replace( - board=jax.lax.cond( - (piece == KING) & (a.from_ == 32) & (a.to == 48), - lambda: state.board.at[56].set(EMPTY).at[40].set(ROOK), - lambda: state.board, - ), - ) - # update my can_castle_xxx_side + is_king_side_castling = (piece == KING) & (a.from_ == 32) & (a.to == 48) + board = jax.lax.select(is_king_side_castling, board.at[56].set(EMPTY).at[40].set(ROOK), board) + state = state._replace(board=board) + # update castling rights state = state._replace( can_castle_queen_side=state.can_castle_queen_side.at[0].set( - jax.lax.select( - (a.from_ == 32) | (a.from_ == 0), - FALSE, - state.can_castle_queen_side[0], - ) + jax.lax.select((a.from_ == 32) | (a.from_ == 0), FALSE, state.can_castle_queen_side[0]) ), can_castle_king_side=state.can_castle_king_side.at[0].set( - jax.lax.select( - (a.from_ == 32) | (a.from_ == 56), - FALSE, - state.can_castle_king_side[0], - ) + jax.lax.select((a.from_ == 32) | (a.from_ == 56), FALSE, state.can_castle_king_side[0]) ), ) - # update opp can_castle_xxx_side + # update opp castling rights state = state._replace( can_castle_queen_side=state.can_castle_queen_side.at[1].set( - jax.lax.select( - (a.to == 7), - FALSE, - state.can_castle_queen_side[1], - ) + jax.lax.select((a.to == 7), FALSE, state.can_castle_queen_side[1]) ), can_castle_king_side=state.can_castle_king_side.at[1].set( - jax.lax.select( - (a.to == 63), - FALSE, - state.can_castle_king_side[1], - ) + jax.lax.select((a.to == 63), FALSE, state.can_castle_king_side[1]) ), ) # promotion to queen - piece = jax.lax.select( - piece == PAWN & (a.from_ % 8 == 6) & (a.underpromotion < 0), - QUEEN, - piece, - ) + piece = jax.lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, piece) # underpromotion - piece = jax.lax.select( - a.underpromotion < 0, - piece, - jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion], - ) + piece = jax.lax.select(a.underpromotion < 0, piece, jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion]) # actually move state = state._replace(board=state.board.at[a.from_].set(EMPTY).at[a.to].set(piece)) # type: ignore return state