diff --git a/pgx/_src/games/go.py b/pgx/_src/games/go.py index 1c86b94a8..9593a406a 100644 --- a/pgx/_src/games/go.py +++ b/pgx/_src/games/go.py @@ -30,7 +30,6 @@ class GameState(NamedTuple): consecutive_pass_count: Array = jnp.int32(0) ko: Array = jnp.int32(-1) # by SSK is_psk: Array = jnp.bool_(False) - hash: Array = jnp.zeros(2, dtype=jnp.uint32) hash_history: Array = jnp.zeros((19 * 19 * 2, 2), dtype=jnp.uint32) @property @@ -67,8 +66,8 @@ def step(self, state: GameState, action: Array) -> GameState: board_history = board_history.at[0].set(jnp.clip(state.board, -1, 1).astype(jnp.int32)) state = state._replace(board_history=board_history) # check PSK - hash = _compute_hash(state) - state = state._replace(hash=hash, hash_history=state.hash_history.at[state.step_count].set(hash)) + hash_ = _compute_hash(state) + state = state._replace(hash_history=state.hash_history.at[state.step_count].set(hash_)) state = state._replace(is_psk=_is_psk(state)) # increment turns state = state._replace(step_count=state.step_count + 1) @@ -210,7 +209,8 @@ def _compute_hash(state: GameState): def _is_psk(state: GameState): not_passed = state.consecutive_pass_count == 0 - has_same_hash = (state.hash == state.hash_history).all(axis=-1).sum() > 1 + curr_hash = state.hash_history[state.step_count] + has_same_hash = (curr_hash == state.hash_history).all(axis=-1).sum() > 1 return not_passed & has_same_hash