Skip to content

Commit

Permalink
feature(pu): add dummy_any_game_env to test alphazero in non-zero-sum…
Browse files Browse the repository at this point in the history
… two-player games or single player games
  • Loading branch information
puyuan1996 committed Jul 18, 2024
1 parent a0bf161 commit 2a331da
Show file tree
Hide file tree
Showing 6 changed files with 691 additions and 104 deletions.
159 changes: 78 additions & 81 deletions lzero/mcts/ptree/ptree_az.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""
Overview:
This code implements the Monte Carlo Tree Search (MCTS) algorithm with the integration of neural networks.
The Node class represents a node in the Monte Carlo tree and implements the basic functionalities expected in a node.
The MCTS class implements the specific search functionality and provides the optimal action through the ``get_next_action`` method.
This code implements the Monte Carlo Tree Search (MCTS) algorithm with the integration of neural networks.
The Node class represents a node in the Monte Carlo tree and implements the basic functionalities expected in a node.
The MCTS class implements the specific search functionality and provides the optimal action through the ``get_next_action`` method.
Compared to traditional MCTS, the introduction of value networks and policy networks brings several advantages.
During the expansion of nodes, it is no longer necessary to explore every single child node, but instead,
the child nodes are directly selected based on the prior probabilities provided by the neural network.
This reduces the breadth of the search. When estimating the value of leaf nodes, there is no need for a rollout;
During the expansion of nodes, it is no longer necessary to explore every single child node, but instead,
the child nodes are directly selected based on the prior probabilities provided by the neural network.
This reduces the breadth of the search. When estimating the value of leaf nodes, there is no need for a rollout;
instead, the value output by the neural network is used, which saves the depth of the search.
"""

Expand All @@ -24,9 +24,9 @@
class Node(object):
"""
Overview:
A class for a node in a Monte Carlo Tree. The properties of this class store basic information about the node,
such as its parent node, child nodes, and the number of times the node has been visited.
The methods of this class implement basic functionalities that a node should have, such as propagating the value back,
A class for a node in a Monte Carlo Tree. The properties of this class store basic information about the node,
such as its parent node, child nodes, and the number of times the node has been visited.
The methods of this class implement basic functionalities that a node should have, such as propagating the value back,
checking if the node is the root node, and determining if it is a leaf node.
"""

Expand Down Expand Up @@ -75,46 +75,36 @@ def update(self, value: float) -> None:
# Updates the sum of the values of all child nodes of this node.
self._value_sum += value

def update_recursive(self, leaf_value: float, battle_mode_in_simulation_env: str) -> None:
def update_recursive(self, leaf_value: float, negate: bool = True) -> None:
"""
Overview:
Update node information recursively.
The same game state has opposite values in the eyes of two players playing against each other.
The value of a node is evaluated from the perspective of the player corresponding to its parent node.
In ``self_play_mode``, because the player corresponding to a node changes every step during the backpropagation process, the value needs to be negated once.
In ``play_with_bot_mode``, since all nodes correspond to the same player, the value does not need to be negated.
The same game state has opposite values in the eyes of two players playing against each other.
The value of a node is evaluated from the perspective of the player corresponding to its parent node.
In self-play mode, because the player corresponding to a node changes every step during the backpropagation process,
the value needs to be negated once.
In play-with-bot mode or single-player mode, since all nodes correspond to the same player,
the value does not need to be negated.
Arguments:
- leaf_value (:obj:`Float`): The value of the node.
- battle_mode_in_simulation_env (:obj:`str`): The mode of MCTS, can be 'self_play_mode' or 'play_with_bot_mode'.
"""
# Update the node information recursively based on the MCTS mode.
if battle_mode_in_simulation_env == 'self_play_mode':
# Update the current node's information.
self.update(leaf_value)
# If the current node is the root node, return.
if self.is_root():
return
# Update the parent node's information recursively. When propagating the value back to the parent node,
# the value needs to be negated once because the perspective of evaluation has changed.
self._parent.update_recursive(-leaf_value, battle_mode_in_simulation_env)
if battle_mode_in_simulation_env == 'play_with_bot_mode':
# Update the current node's information.
self.update(leaf_value)
# If the current node is the root node, return.
if self.is_root():
return
# Update the parent node's information recursively. In ``play_with_bot_mode``, since the nodes' values
# are always evaluated from the perspective of the agent player, there is no need to negate the value
# during value propagation.
self._parent.update_recursive(leaf_value, battle_mode_in_simulation_env)
- negate (:obj:`bool`): Whether to negate the value during backpropagation.
"""
# Update the current node's information.
self.update(leaf_value)
# If the current node is the root node, return.
if self.is_root():
return
# Update the parent node's information recursively.
# If negate is True, the perspective of evaluation has changed, so the value needs to be negated.
self._parent.update_recursive(-leaf_value if negate else leaf_value, negate)

def is_leaf(self) -> bool:
"""
Overview:
Check if the current node is a leaf node or not.
Returns:
- output (:obj:`Bool`): If self._children is empty, it means that the node has not
- output (:obj:`Bool`): If self._children is empty, it means that the node has not
been expanded yet, which indicates that the node is a leaf node.
"""
# Returns True if the node is a leaf node (i.e., has no children), and False otherwise.
Expand All @@ -131,7 +121,7 @@ def is_root(self) -> bool:
return self._parent is None

@property
def parent(self) -> None:
def parent(self) -> "Node":
"""
Overview:
Get the parent node of the current node.
Expand All @@ -141,17 +131,17 @@ def parent(self) -> None:
return self._parent

@property
def children(self) -> None:
def children(self) -> Dict[int, "Node"]:
"""
Overview:
Get the dictionary of children nodes of the current node.
Returns:
- output (:obj:`dict`): A dictionary representing the children of the current node.
- output (:obj:`dict`): A dictionary representing the children of the current node.
"""
return self._children

@property
def visit_count(self) -> None:
def visit_count(self) -> int:
"""
Overview:
Get the number of times the current node has been visited.
Expand All @@ -164,8 +154,8 @@ def visit_count(self) -> None:
class MCTS(object):
"""
Overview:
A class for Monte Carlo Tree Search (MCTS). The methods in this class implement the steps involved in MCTS, such as selection and expansion.
Based on this, the ``_simulate`` method is used to traverse from the root node to a leaf node.
A class for Monte Carlo Tree Search (MCTS). The methods in this class implement the steps involved in MCTS, such as selection and expansion.
Based on this, the ``_simulate`` method is used to traverse from the root node to a leaf node.
Finally, by repeatedly calling ``_simulate`` through ``get_next_action``, the optimal action is obtained.
"""

Expand Down Expand Up @@ -200,7 +190,7 @@ def get_next_action(
self,
state_config_for_simulate_env_reset: Dict[str, Any],
policy_forward_fn: Callable,
temperature: int = 1.0,
temperature: float = 1.0,
sample: bool = True
) -> Tuple[int, List[float]]:
"""
Expand All @@ -220,9 +210,9 @@ def get_next_action(
root = Node()

self.simulate_env.reset(
start_player_index=state_config_for_simulate_env_reset.start_player_index,
init_state=state_config_for_simulate_env_reset.init_state,
)
start_player_index=state_config_for_simulate_env_reset.get('start_player_index', 0),
init_state=state_config_for_simulate_env_reset.get('init_state', None),
)
# Expand the root node by adding children to it.
self._expand_leaf_node(root, self.simulate_env, policy_forward_fn)

Expand All @@ -234,8 +224,8 @@ def get_next_action(
for n in range(self._num_simulations):
# Initialize the simulated environment and reset it to the root node.
self.simulate_env.reset(
start_player_index=state_config_for_simulate_env_reset.start_player_index,
init_state=state_config_for_simulate_env_reset.init_state,
start_player_index=state_config_for_simulate_env_reset.get('start_player_index', 0),
init_state=state_config_for_simulate_env_reset.get('init_state', None),
)
# Set the battle mode adopted by the environment during the MCTS process.
# In ``self_play_mode``, when the environment calls the step function once, it will play one move based on the incoming action.
Expand All @@ -261,11 +251,9 @@ def get_next_action(
# Calculate the action probabilities based on the visit counts and temperature.
# When the visit count of a node is 0, then the corresponding action probability will be 0 in order to prevent the selection of illegal actions.
visits_t = torch.as_tensor(visits, dtype=torch.float32)
visits_t = torch.pow(visits_t, 1/temperature)
visits_t = torch.pow(visits_t, 1 / temperature)
action_probs = (visits_t / visits_t.sum()).numpy()

# action_probs = nn.functional.softmax(1.0 / temperature * np.log(torch.as_tensor(visits) + 1e-10), dim=0).numpy()

# Choose the next action to take based on the action probabilities.
if sample:
action = np.random.choice(actions, p=action_probs)
Expand Down Expand Up @@ -293,10 +281,17 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
break
simulate_env.step(action)

done, winner = simulate_env.get_done_winner()
result = simulate_env.get_done_winner()
# Determine the output format based on the length of the result
if len(result) == 2:
done, winner = result
eval_return_current_player = None
else:
done, winner, _, eval_return_current_player = result

"""
in ``self_play_mode``, the leaf_value is calculated from the perspective of player ``simulate_env.current_player``.
in ``play_with_bot_mode``, the leaf_value is calculated from the perspective of player 1.
in ``play_with_bot_mode`` or single-player mode, the leaf_value is calculated from the perspective of the agent.
"""

if not done:
Expand All @@ -306,28 +301,32 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
# game state from the perspective of player 1.
leaf_value = self._expand_leaf_node(node, simulate_env, policy_forward_fn)
else:
if simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
# In a tie game, the value corresponding to a terminal node is 0.
if winner == -1:
leaf_value = 0
else:
# To maintain consistency with the perspective of the neural network, the value of a terminal
# node is also calculated from the perspective of the current_player of the terminal node,
# which is convenient for subsequent updates.
leaf_value = 1 if simulate_env.current_player == winner else -1

if simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
# in ``play_with_bot_mode``, the leaf_value should be transformed to the perspective of player 1.
if winner == -1:
leaf_value = 0
elif winner == 1:
leaf_value = 1
elif winner == 2:
leaf_value = -1
if simulate_env.non_zero_sum:
# NOTE: for single-player game and non_zero_sum two-player game
leaf_value = eval_return_current_player
else:
if simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
# In a tie game, the value corresponding to a terminal node is 0.
if winner == -1:
leaf_value = 0
else:
# To maintain consistency with the perspective of the neural network, the value of a terminal
# node is also calculated from the perspective of the current_player of the terminal node,
# which is convenient for subsequent updates.
leaf_value = 1 if simulate_env.current_player == winner else -1

elif simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
# in ``play_with_bot_mode``, the leaf_value should be transformed to the perspective of player 1.
if winner == -1:
leaf_value = 0
elif winner == 1:
leaf_value = 1
elif winner == 2:
leaf_value = -1

# Update value and visit count of nodes in this traversal.
if simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
node.update_recursive(leaf_value, simulate_env.battle_mode_in_simulation_env)
if simulate_env.battle_mode_in_simulation_env in ['play_with_bot_mode', 'single_player_mode']:
node.update_recursive(leaf_value, negate=False)
elif simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
# NOTE: e.g.
# to_play: 1 ----------> 2 ----------> 1 ----------> 2
Expand All @@ -337,7 +336,7 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
# leaf_value is calculated from the perspective of player 1, leaf_value = value_func(s3),
# but node.value should be the value of E[q(s2, action)], i.e. calculated from the perspective of player 2.
# thus we add the negative when call update_recursive().
node.update_recursive(-leaf_value, simulate_env.battle_mode_in_simulation_env)
node.update_recursive(-leaf_value, negate=True)

def _select_child(self, node: Node, simulate_env: Type[BaseEnv]) -> Tuple[Union[int, float], Node]:
"""
Expand All @@ -349,14 +348,13 @@ def _select_child(self, node: Node, simulate_env: Type[BaseEnv]) -> Tuple[Union[
- action (:obj:`Int`): choose the action with the highest ucb score.
- child (:obj:`Node`): the child node reached by executing the action with the highest ucb score.
"""
# assert list(node.children.keys()) == simulate_env.legal_actions
action = None
child = None
best_score = -9999999
# Iterate over each child of the current node.
for action_tmp, child_tmp in node.children.items():
"""
Check if the action is present in the list of legal actions for the current environment.
Check if the action is present in the list of legal actions for the current environment.
This check is relevant only when the agent is training in "play_with_bot_mode" and the bot's actions involve strong randomness.
"""
if action_tmp in simulate_env.legal_actions:
Expand Down Expand Up @@ -393,7 +391,6 @@ def _expand_leaf_node(self, node: Node, simulate_env: Type[BaseEnv], policy_forw
# the current node.
if action in simulate_env.legal_actions:
node.children[action] = Node(parent=node, prior_p=prior_p)

# Return the value of the leaf node.
return leaf_value

Expand All @@ -402,7 +399,7 @@ def _ucb_score(self, parent: Node, child: Node) -> float:
Overview:
Compute UCB score. The score for a node is based on its value, plus an exploration bonus based on the prior.
For more details, please refer to this paper: http://gauss.ececs.uc.edu/Workshops/isaim2010/papers/rosin.pdf
UCB = Q(s,a) + P(s,a) \cdot \frac{ \sqrt{N(\text{parent})}}{1+N(\text{child})} \cdot \left(c_1 + \log\left(\frac{N(\text{parent})+c_2+1}{c_2}\right)\right)
UCB = Q(s,a) + P(s,a) * sqrt(N(parent) / (1 + N(child))) * (c_1 + log((N(parent) + c_2 + 1) / c_2))
- Q(s,a): value of a child node.
- P(s,a): The prior of a child node.
- N(parent): The number of the visiting of the parent node.
Expand All @@ -413,7 +410,7 @@ def _ucb_score(self, parent: Node, child: Node) -> float:
- parent (:obj:`Class Node`): Current node.
- child (:obj:`Class Node`): Current node's child.
Returns:
- score (:obj:`Bool`): The UCB score.
- score (:obj:`Float`): The UCB score.
"""
# Compute the value of parameter pb_c using the formula of the UCB algorithm.
pb_c = math.log((parent.visit_count + self._pb_c_base + 1) / self._pb_c_base) + self._pb_c_init
Expand Down Expand Up @@ -441,4 +438,4 @@ def _add_exploration_noise(self, node: Node) -> None:
frac = self._root_noise_weight
# Update the prior probability of each child node with the exploration noise.
for a, n in zip(actions, noise):
node.children[a].prior_p = node.children[a].prior_p * (1 - frac) + n * frac
node.children[a].prior_p = node.children[a].prior_p * (1 - frac) + n * frac
11 changes: 11 additions & 0 deletions lzero/policy/alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,17 @@ def _get_simulation_env(self):
else:
raise NotImplementedError
self.simulate_env = Connect4Env(connect4_alphazero_config.env)
elif self._cfg.simulation_env_id == 'dummy_any_game':
from zoo.board_games.tictactoe.envs.dummy_any_game_env import AnyGameEnv
if self._cfg.simulation_env_config_type == 'single_player_mode':
from zoo.board_games.tictactoe.config.dummy_any_game_alphazero_single_player_mode_config import \
dummy_any_game_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.tictactoe.config.dummy_any_game_alphazero_self_play_mode_config import \
dummy_any_game_alphazero_config
else:
raise NotImplementedError
self.simulate_env = AnyGameEnv(dummy_any_game_alphazero_config.env)
else:
raise NotImplementedError

Expand Down
Loading

0 comments on commit 2a331da

Please sign in to comment.