Skip to content

Commit f26e9e3

Browse files
authored
[Game] Rename returns to rewards (#1161)
1 parent 18de9a8 commit f26e9e3

File tree

7 files changed: +8 additions, -8 deletions (16 lines changed).

pgx/_src/games/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
@runtime_checkable
11-
class TwoPlayerPerfectInfoGame(Protocol[T]):
11+
class GameProtocol(Protocol[T]):
1212
def init(self) -> T:
1313
...
1414

@@ -24,5 +24,5 @@ def legal_action_mask(self, state: T) -> Array:
2424
def is_terminal(self, state: T) -> Array:
2525
...
2626

27-
def returns(self, state: T) -> Array:
27+
def rewards(self, state: T) -> Array:
2828
...

pgx/_src/games/connect_four.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def is_terminal(self, state: GameState) -> Array:
6363
board2d = state.board.reshape(6, 7)
6464
return (state.winner >= 0) | jnp.all((board2d >= 0).sum(axis=0) == 6)
6565

66-
def returns(self, state: GameState) -> Array:
66+
def rewards(self, state: GameState) -> Array:
6767
return jax.lax.select(
6868
state.winner >= 0,
6969
jnp.float32([-1, -1]).at[state.winner].set(1),

pgx/_src/games/go.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def is_terminal(self, state: GameState) -> Array:
126126
timeover = self.max_termination_steps <= state.step_count
127127
return two_consecutive_pass | state.is_psk | timeover
128128

129-
def returns(self, state: GameState) -> Array:
129+
def rewards(self, state: GameState) -> Array:
130130
score = _count_point(state, self.size)
131131
rewards = jax.lax.select(
132132
score[0] - self.komi > score[1],

pgx/_src/games/tic_tac_toe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def legal_action_mask(self, state: GameState) -> Array:
6060
def is_terminal(self, state: GameState) -> Array:
6161
return (state.winner >= 0) | jnp.all(state.board != -1)
6262

63-
def returns(self, state: GameState) -> Array:
63+
def rewards(self, state: GameState) -> Array:
6464
return jax.lax.select(
6565
state.winner >= 0,
6666
jnp.float32([-1, -1]).at[state.winner].set(1),

pgx/connect_four.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _step(self, state: core.State, action: Array, key) -> State:
5757
assert isinstance(state, State)
5858
legal_action_mask = self._game.legal_action_mask(state._x)
5959
terminated = self._game.is_terminal(state._x)
60-
rewards = self._game.returns(state._x)
60+
rewards = self._game.rewards(state._x)
6161
should_flip = state.current_player != state._x.color
6262
rewards = jax.lax.select(should_flip, jnp.flip(rewards), rewards)
6363
rewards = jax.lax.select(terminated, rewards, jnp.zeros(2, jnp.float32))

pgx/go.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _step(self, state: core.State, action: Array, key) -> State:
7373
return state.replace( # type:ignore
7474
current_player=state._player_order[x.color],
7575
legal_action_mask=self._game.legal_action_mask(x),
76-
rewards=self._game.returns(x)[state._player_order],
76+
rewards=self._game.rewards(x)[state._player_order],
7777
terminated=self._game.is_terminal(x),
7878
_x=x,
7979
)

pgx/tic_tac_toe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _step(self, state: core.State, action: Array, key) -> State:
5656
assert isinstance(state, State)
5757
legal_action_mask = self._game.legal_action_mask(state._x)
5858
terminated = self._game.is_terminal(state._x)
59-
rewards = self._game.returns(state._x)
59+
rewards = self._game.rewards(state._x)
6060
should_flip = state.current_player != state._x.color
6161
rewards = jax.lax.select(should_flip, jnp.flip(rewards), rewards)
6262
rewards = jax.lax.select(terminated, rewards, jnp.zeros(2, jnp.float32))

Comments (0): no commit comments.