Skip to content

Commit

Permalink
[Chess] Simplify hash implementation (#1246)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Sep 13, 2024
1 parent 753950b commit 624f645
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 61 deletions.
45 changes: 12 additions & 33 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,12 @@
jnp.array(x) for x in (FROM_PLANE, TO_PLANE, LEGAL_DEST, LEGAL_DEST_ANY, CAN_MOVE, BETWEEN, INIT_LEGAL_ACTION_MASK)
)

# Zobrist key tables: random uint32 values drawn from a fixed PRNG seed, so
# they are identical on every import. Every table has a trailing axis of
# size 2, i.e. each key is a pair of uint32 words.
# NOTE: jax.random.randint treats maxval as exclusive, so every value lies in
# [0, 2**31 - 1) and the top bit of each uint32 word is never set.
# NOTE(review): this span appears to interleave the pre-commit lines (seed
# 238290, separate ZOBRIST_CASTLING_QUEEN / ZOBRIST_CASTLING_KING tables) with
# their post-commit replacements (seed 12345, single ZOBRIST_CASTLING table);
# the diff +/- markers are not present in this view -- confirm against the
# repository before relying on either set.
key = jax.random.PRNGKey(238290)
key, subkey = jax.random.split(key)
ZOBRIST_BOARD = jax.random.randint(subkey, shape=(64, 13, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
key, subkey = jax.random.split(key)
ZOBRIST_SIDE = jax.random.randint(subkey, shape=(2,), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
key, subkey = jax.random.split(key)
ZOBRIST_CASTLING_QUEEN = jax.random.randint(subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
key, subkey = jax.random.split(key)
ZOBRIST_CASTLING_KING = jax.random.randint(subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
key, subkey = jax.random.split(key)
ZOBRIST_EN_PASSANT = jax.random.randint(subkey, shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
# Matches the value the _zobrist_hash doctest shows for a default GameState().
INIT_ZOBRIST_HASH = jnp.uint32([1172276016, 1112364556])
# One sub-key per table, all split from a single seed in one call.
keys = jax.random.split(jax.random.PRNGKey(12345), 4)
ZOBRIST_BOARD = jax.random.randint(keys[0], shape=(64, 13, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
ZOBRIST_SIDE = jax.random.randint(keys[1], shape=(2,), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
ZOBRIST_CASTLING = jax.random.randint(keys[2], shape=(4, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
ZOBRIST_EN_PASSANT = jax.random.randint(keys[3], shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
# test_zobrist_hash asserts that the initial state's hash equals this constant;
# it is hard-coded, so it must be regenerated whenever the seed or tables change.
INIT_ZOBRIST_HASH = jnp.uint32([1455170221, 1478960862])

MAX_TERMINATION_STEPS = 512 # from AZ paper
# fmt: on
Expand All @@ -157,10 +151,6 @@ class GameState(NamedTuple):
legal_action_mask: Array = INIT_LEGAL_ACTION_MASK
step_count: Array = jnp.int32(0)

# Read-only convenience accessor: recomputes the hash on every access by
# delegating to the module-level _zobrist_hash(self); nothing is cached here.
# NOTE(review): the commit summary reports more deletions than additions and
# the call sites below switch to _zobrist_hash(state), so this property looks
# like a removed line in the diff -- confirm against the repository.
@property
def zobrist_hash(self):
return _zobrist_hash(self)


class Action(NamedTuple):
from_: Array = jnp.int32(-1)
Expand Down Expand Up @@ -228,7 +218,7 @@ def is_terminal(self, state: GameState) -> Array:
terminated = ~state.legal_action_mask.any()
terminated |= state.halfmove_count >= 100
terminated |= has_insufficient_pieces(state)
rep = (state.hash_history == state.zobrist_hash).all(axis=1).sum() - 1
rep = (state.hash_history == _zobrist_hash(state)).all(axis=1).sum() - 1
terminated |= rep >= 2
terminated |= MAX_TERMINATION_STEPS <= state.step_count
return terminated
Expand All @@ -249,7 +239,7 @@ def _update_history(state: GameState):
state = state._replace(board_history=board_history)
# hash hist
hash_hist = jnp.roll(state.hash_history, 2)
hash_hist = hash_hist.at[0].set(state.zobrist_hash)
hash_hist = hash_hist.at[0].set(_zobrist_hash(state))
state = state._replace(hash_history=hash_hist)
return state

Expand Down Expand Up @@ -412,21 +402,10 @@ def _is_checked(state: GameState):


def _zobrist_hash(state: GameState) -> Array:
# Compute the Zobrist hash of `state` as a pair of uint32 words by XOR-ing
# together a side-to-move key, one key per board square, the castling keys
# and the en-passant key.
# NOTE(review): this span appears to interleave the pre-commit body (doctest,
# _flip-based board, ZOBRIST_CASTLING_QUEEN / ZOBRIST_CASTLING_KING selects)
# with the post-commit body (single select on ZOBRIST_SIDE, jnp.where over
# ZOBRIST_CASTLING); the diff +/- markers and the indentation are missing in
# this view. Read literally, castling rights would be mixed in twice and the
# first hash_/board assignments would be dead stores -- confirm against the
# repository which lines belong to which version.
"""
>>> state = GameState()
>>> _zobrist_hash(state)
Array([1172276016, 1112364556], dtype=uint32)
"""
hash_ = jnp.zeros(2, dtype=jnp.uint32)
hash_ = jax.lax.select(state.turn == 0, hash_, hash_ ^ ZOBRIST_SIDE)
board = jax.lax.select(state.turn == 0, state.board, _flip(state).board)
# 0, ..., 12 (white pawn, ..., black king)
to_reduce = ZOBRIST_BOARD[jnp.arange(64), board + 6]
# Side-to-move term: ZOBRIST_SIDE when turn == 0, all zeros otherwise.
hash_ = jax.lax.select(state.turn == 0, ZOBRIST_SIDE, jnp.zeros_like(ZOBRIST_SIDE))
# `+ 6` shifts the piece code into the 13-wide second axis of ZOBRIST_BOARD
# (assumes piece codes lie in [-6, 6] -- TODO confirm); one key is gathered
# per square, giving a (64, 2) array.
to_reduce = ZOBRIST_BOARD[jnp.arange(64), state.board + 6] # 0, ..., 12 (w:pawn, ..., b:king)
# XOR-fold the per-square keys along axis 0 down to a single pair of words.
hash_ ^= jax.lax.reduce(to_reduce, 0, jax.lax.bitwise_xor, (0,))
# Flatten the castling flags to a column so each flag selects its row of
# ZOBRIST_CASTLING (or zeros), then XOR-fold the rows
# (presumably castling_rights holds 4 booleans -- verify against GameState).
to_reduce = jnp.where(state.castling_rights.reshape(-1, 1), ZOBRIST_CASTLING, 0)
hash_ ^= jax.lax.reduce(to_reduce, 0, jax.lax.bitwise_xor, (0,))
zero = jnp.uint32([0, 0])
hash_ ^= jax.lax.select(state.castling_rights[0, 0], ZOBRIST_CASTLING_QUEEN[0], zero)
hash_ ^= jax.lax.select(state.castling_rights[0, 1], ZOBRIST_CASTLING_QUEEN[1], zero)
hash_ ^= jax.lax.select(state.castling_rights[1, 0], ZOBRIST_CASTLING_KING[0], zero)
hash_ ^= jax.lax.select(state.castling_rights[1, 1], ZOBRIST_CASTLING_KING[1], zero)
# en_passant indexes a 65-row table (presumably 64 squares plus one "no en
# passant" slot -- verify how GameState encodes the absent case).
hash_ ^= ZOBRIST_EN_PASSANT[state.en_passant]
return hash_
37 changes: 9 additions & 28 deletions tests/test_chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
import pgx
from pgx.chess import State, Chess
from pgx._src.games.chess import GameState, Action, KING, QUEEN, EMPTY, ROOK, PAWN, _legal_action_mask, CAN_MOVE, _zobrist_hash
from pgx._src.games.chess import GameState, Action, KING, QUEEN, EMPTY, ROOK, PAWN, _legal_action_mask, CAN_MOVE, _zobrist_hash, INIT_ZOBRIST_HASH
from pgx.experimental.utils import act_randomly
from pgx.experimental.chess import from_fen, to_fen

Expand All @@ -26,19 +26,6 @@ def p(s: str, b=False):
return x * 8 + offset


def test_zobrist_hash():
# Plays one random game from the initial state until termination and, after
# every step, checks that the `zobrist_hash` property on state._x agrees with
# a jitted call to _zobrist_hash on the same state.
# `init` and `step` are defined elsewhere in this test module (not visible
# in this view).
# NOTE(review): this version relies on the GameState.zobrist_hash property
# and looks like the removed side of the diff -- confirm against the repository.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
state = init(subkey)
assert (state._x.zobrist_hash == jax.jit(_zobrist_hash)(state._x)).all()
# for i in range(5):
while not state.terminated:
key, subkey = jax.random.split(key)
action = act_randomly(subkey, state.legal_action_mask)
state = step(state, action)
# state.save_svg("debug.svg")
assert (state._x.zobrist_hash == jax.jit(_zobrist_hash)(state._x)).all()

def test_action():
# See #704
# queen moves
Expand Down Expand Up @@ -717,20 +704,6 @@ def test_buggy_samples():
expected_legal_actions = [16, 30, 162, 1212, 1225, 1269, 1270, 1271, 1272, 1284, 1297, 1298, 1299, 1942, 1943, 1956, 2498, 2948, 2949, 3131, 3132, 3133, 3134, 3135, 3136, 3138, 3534, 3548, 3593, 3594, 4250]
assert state.legal_action_mask.sum() == len(expected_legal_actions), f"\nactual:{jnp.nonzero(state.legal_action_mask)[0]}\nexpected\n{expected_legal_actions}"

# wrong zobrist hash when a pawn is removed by en passant #1078
state = from_fen("1nbqkb1r/rp1p1pp1/8/p1pPp3/4P1n1/5Pp1/PPPBK1PP/RN1Q1B1R w k e6 0 11")
state.save_svg("tests/assets/chess/buggy_samples_015.svg")
state = step(state, jnp.int32(2088))
state.save_svg("tests/assets/chess/buggy_samples_016.svg")
assert (state._x.zobrist_hash == jax.jit(_zobrist_hash)(state._x)).all()

# wrong zobrist hash due to queen promotion #1078
state = from_fen("B7/8/8/1P6/1k3K2/5P2/6p1/1B6 b - - 1 102")
state.save_svg("tests/assets/chess/buggy_samples_017.svg")
state = step(state, jnp.int32(3958))
state.save_svg("tests/assets/chess/buggy_samples_018.svg")
assert (state._x.zobrist_hash == jax.jit(_zobrist_hash)(state._x)).all()


def test_observe():
state = init(jax.random.PRNGKey(0))
Expand Down Expand Up @@ -1093,6 +1066,14 @@ def test_observe():
assert (state.observation[:, :, 14 * 4 + 13] == 1.).all() # rep


def test_zobrist_hash():
# Checks the initial state only: both the newest entry of the stored hash
# history and a fresh _zobrist_hash computation must equal the hard-coded
# INIT_ZOBRIST_HASH constant.
# `init` is defined elsewhere in this test module (not visible in this view).
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
state = init(subkey)
# Row 0 of hash_history is where _update_history writes the latest hash.
assert (state._x.hash_history[0] == INIT_ZOBRIST_HASH).all()
assert (_zobrist_hash(state._x) == INIT_ZOBRIST_HASH).all()


def test_api():
import pgx
env = pgx.make("chess")
Expand Down

0 comments on commit 624f645

Please sign in to comment.