diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 758dc655f..9c0fb5d6c 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import NamedTuple, Optional +from typing import NamedTuple import jax import jax.numpy as jnp @@ -178,7 +178,7 @@ def step(self, state: GameState, action: Array) -> GameState: state = state._replace(step_count=state.step_count + 1) return state - def observe(self, state: GameState, color: Optional[Array] = None) -> Array: + def observe(self, state: GameState, color: Array | None = None) -> Array: if color is None: color = state.color ones = jnp.ones((1, 8, 8), dtype=jnp.float32)