Skip to content

Commit

Permalink
[Chess] Remove _is_pseudo_legal (#1237)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Sep 12, 2024
1 parent 139e61a commit 34c7cb7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 26 deletions.
32 changes: 9 additions & 23 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,14 +408,18 @@ def _flip(state: GameState) -> GameState:

def _legal_action_mask(state: GameState) -> Array:
@jax.vmap
def legal_norml_moves(from_):
def legal_normal_moves(from_):
piece = state.board[from_]

@jax.vmap
def legal_label(to):
a = Action(from_=from_, to=to)
ok = (from_ >= 0) & (piece > 0) & (to >= 0) & _is_pseudo_legal(state, a)
return jax.lax.select(ok, a._to_label(), -1)
ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (state.board[to] <= 0)
between_ixs = BETWEEN[from_, to]
ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all()
c0, c1 = from_ // 8, to // 8
pawn_should = ((c1 == c0) & (state.board[to] == EMPTY)) | ((c1 != c0) & (state.board[to] < 0))
ok &= (piece != PAWN) | pawn_should
return jax.lax.select(ok, Action(from_=from_, to=to)._to_label(), -1)

return legal_label(LEGAL_DEST[piece, from_])

Expand Down Expand Up @@ -449,7 +453,7 @@ def is_not_checked(label):

# normal move and en passant
possible_piece_positions = jnp.nonzero(state.board > 0, size=16, fill_value=-1)[0].astype(jnp.int32)
a1 = legal_norml_moves(possible_piece_positions).flatten()
a1 = legal_normal_moves(possible_piece_positions).flatten()
a2 = legal_en_passants()
actions = jnp.hstack((a1, a2)) # include -1
actions = jnp.where(is_not_checked(actions), actions, -1)
Expand Down Expand Up @@ -484,7 +488,6 @@ def can_move(to):
between_ixs = BETWEEN[pos, to]
ok &= ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all()
# For pawn, it should be the forward diagonal direction
ok &= ~((piece == PAWN) & ((to % 8) < (pos % 8))) # should move forward
ok &= ~((piece == PAWN) & (to // 8 == pos // 8)) # should move diagnally to capture the king
return ok

Expand All @@ -497,23 +500,6 @@ def _is_checked(state: GameState):
return _is_attacked(state, king_pos)


def _is_pseudo_legal(state: GameState, a: Action):
piece = state.board[a.from_]
ok = (piece >= 0) & (state.board[a.to] <= 0)
ok &= CAN_MOVE[piece, a.from_, a.to]
between_ixs = BETWEEN[a.from_, a.to]
ok &= ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all()
# filter pawn move
ok &= ~((piece == PAWN) & ((a.to % 8) < (a.from_ % 8))) # should move forward
ok &= ~(
(piece == PAWN) & (jnp.abs(a.to - a.from_) <= 2) & (state.board[a.to] < 0)
) # cannot move up if occupied by opponent
ok &= ~(
(piece == PAWN) & (jnp.abs(a.to - a.from_) > 2) & (state.board[a.to] >= 0)
) # cannot move diagnally without capturing
return (a.to >= 0) & ok


def _zobrist_hash(state: GameState) -> Array:
"""
>>> state = GameState()
Expand Down
6 changes: 3 additions & 3 deletions pgx/_src/games/chess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
# Note that the board is not symmetric about the center (different from shogi)
# You can imagine that the viewpoint is always from the white side.
# Except PAWN, the moves are symmetric about the center.
# We define PAWN as a piece that can move up, down, and diagonally, and filter it according to the turn.
# We define PAWN as a piece that can move left-up, up, right-up.


# PAWN
Expand All @@ -81,10 +81,10 @@
legal_dst = []
for to in range(64):
r1, c1 = to % 8, to // 8
if np.abs(r1 - r0) == 1 and np.abs(c1 - c0) <= 1:
if r1 - r0 == 1 and np.abs(c1 - c0) <= 1:
legal_dst.append(to)
# init move
if (r0 == 1 or r0 == 6) and (np.abs(c1 - c0) == 0 and np.abs(r1 - r0) == 2):
if r0 == 1 and r1 == 3 and np.abs(c1 - c0) == 0:
legal_dst.append(to)
assert len(legal_dst) <= 8
LEGAL_DEST[1, from_, : len(legal_dst)] = legal_dst
Expand Down

0 comments on commit 34c7cb7

Please sign in to comment.