
Commit

fix(pu): polish all mlp model and related configs (#26)
* feature(pu): add discrete_action_encoding_type option

* fix(pu): don't use last_linear_layer_init_zero in the middle layers in mlp model

* fix(pu): add activation in the final layers of common fcn and representation network

* feature(pu): add norm type option in ez mlp model

* polish(pu): polish mz mlp model

* debug(pu): use the same common and muzero_model_mlp with branch main-before-release

* polish(pu): set output_norm=True for all MLP model

* polish(pu): polish output_norm and output_activation

* polish(pu): add res_connection_in_dynamics option in muzero_mlp_model

* fix(pu): fix output_norm, output_act and init_zero in mlp model

* polish(pu): rename hidden_state to latent_state in policy, polish mlp config

* style(pu): yapf format

* polish(pu): polish norm_type
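
Taken together, these commits adjust how the MLP-based MuZero/EfficientZero models are configured. A minimal sketch of how the options named above might sit in a model config follows; the option names come from the commit messages, while the grouping and example values are illustrative assumptions rather than the repository's actual defaults.

# Illustrative sketch only: option names follow the commit messages above,
# while the example values and grouping are assumptions.
model = dict(
    model_type='mlp',
    # how discrete actions are encoded before entering the dynamics network
    discrete_action_encoding_type='one_hot',
    # normalization used inside the MLP blocks (e.g. batch norm or layer norm)
    norm_type='BN',
    # zero-initialize only the last linear layer of each head, never middle layers
    last_linear_layer_init_zero=True,
    # add a residual connection inside the dynamics network
    res_connection_in_dynamics=True,
)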
puyuan1996 authored May 7, 2023
1 parent 0c6cc13 commit 1d2b8f2
Showing 30 changed files with 557 additions and 259 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1433,6 +1433,7 @@ events.*
/test_*
# LightZero special key
/lzero/mcts/**/*.cpp
/zoo/**/*.c
/lzero/mcts/**/*.so
/lzero/mcts/**/*.h
!/lzero/mcts/**/lib
4 changes: 3 additions & 1 deletion lzero/mcts/buffer/game_buffer_muzero.py
@@ -398,7 +398,9 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
if self._cfg.use_root_value:
# use the root values from MCTS, as in EfficientZero
# the root values have limited improvement but require many more GPU actors;
_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(
network_output, data_type='muzero'
)
reward_pool = reward_pool.squeeze().tolist()
policy_logits_pool = policy_logits_pool.tolist()
noises = [
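The truncated `noises = [` line above is where exploration noise for the re-expanded roots is usually built. A hedged sketch of that common pattern follows; the names root_dirichlet_alpha, action_space_size and batch_size are assumptions for illustration and are not read from this diff.

import numpy as np

# Sketch of per-root Dirichlet noise, one noise vector per environment in the batch;
# root_dirichlet_alpha, action_space_size and batch_size are illustrative assumptions.
root_dirichlet_alpha, action_space_size, batch_size = 0.3, 6, 8
noises = [
    np.random.dirichlet([root_dirichlet_alpha] * action_space_size).astype(np.float32).tolist()
    for _ in range(batch_size)
]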
11 changes: 6 additions & 5 deletions lzero/mcts/ptree/ptree_ez.py
@@ -35,8 +35,7 @@ def __init__(self, prior: float, legal_actions: List = None, action_space_size:
self.parent_value_prefix = 0 # only used in update_tree_q method

def expand(
self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float,
policy_logits: List[float]
self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float, policy_logits: List[float]
) -> None:
"""
Overview:
@@ -286,6 +285,7 @@ def __init__(self, num: int) -> None:
self.last_actions = []
self.search_lens = []


def select_child(
root: Node, min_max_stats: MinMaxStats, pb_c_base: float, pb_c_int: float, discount_factor: float,
mean_q: float, players: int
@@ -431,7 +431,6 @@ def batch_traverse(
is_root = 1
search_len = 0
results.search_paths[i].append(node)

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -515,7 +514,7 @@ def backpropagate(
path_len = len(search_path)
for i in range(path_len - 1, -1, -1):
node = search_path[i]

node.value_sum += bootstrap_value if node.to_play == to_play else -bootstrap_value

node.visit_count += 1
@@ -536,7 +535,9 @@ def backpropagate(
min_max_stats.update(true_reward + discount_factor * -node.value)

# true_reward is in the perspective of current player of node
bootstrap_value = (-true_reward if node.to_play == to_play else true_reward) + discount_factor * bootstrap_value
bootstrap_value = (
-true_reward if node.to_play == to_play else true_reward
) + discount_factor * bootstrap_value


def batch_backpropagate(
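The reflowed `bootstrap_value` expression above is the core of the two-player backup: both the bootstrapped value and the predicted reward are sign-adjusted depending on whether a node belongs to the player whose return is being propagated. A minimal self-contained sketch of that rule follows; this toy Node is an illustrative stand-in, not the `ptree_ez` class, and it omits value_prefix handling and min-max normalization.

# Minimal sketch of the two-player backup shown above; the Node here is a toy
# stand-in and omits value_prefix handling and min-max statistics.
class Node:
    def __init__(self, to_play: int, reward: float = 0.0):
        self.to_play = to_play      # player who acts at this node
        self.reward = reward        # reward predicted for reaching this node
        self.value_sum = 0.0
        self.visit_count = 0


def backpropagate(search_path, to_play: int, leaf_value: float, discount_factor: float) -> None:
    bootstrap_value = leaf_value
    for node in reversed(search_path):
        # accumulate the value from the perspective of player `to_play`
        node.value_sum += bootstrap_value if node.to_play == to_play else -bootstrap_value
        node.visit_count += 1
        # true_reward is in the perspective of the current player of this node
        true_reward = node.reward
        bootstrap_value = (
            -true_reward if node.to_play == to_play else true_reward
        ) + discount_factor * bootstrap_value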
8 changes: 4 additions & 4 deletions lzero/mcts/ptree/ptree_sez.py
@@ -46,8 +46,7 @@ def __init__(
self.batch_index = 0

def expand(
self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float,
policy_logits: List[float]
self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float, policy_logits: List[float]
) -> None:
"""
Overview:
@@ -614,7 +613,6 @@ def batch_traverse(
is_root = 1
search_len = 0
results.search_paths[i].append(node)

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -726,7 +724,9 @@ def backpropagate(
min_max_stats.update(true_reward + discount_factor * -node.value)

# true_reward is in the perspective of current player of node
bootstrap_value = (-true_reward if node.to_play == to_play else true_reward) + discount_factor * bootstrap_value
bootstrap_value = (
-true_reward if node.to_play == to_play else true_reward
) + discount_factor * bootstrap_value


def batch_backpropagate(
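Both tree files above reformat code inside the selection stage, where each simulation walks from the internal root state s0 to a leaf node s_l by repeatedly picking the child with the highest UCB score. As a reference, a hedged sketch of the MuZero-style PUCT score that such a `select_child` typically maximizes follows; it mirrors the published formula, not necessarily lzero's exact implementation.

import math

# Sketch of the MuZero-style PUCT score used during selection; it follows the
# published formula rather than lzero's exact code, and child_value is assumed
# to already be normalized by the search's min-max statistics.
def ucb_score(parent_visit_count: int, child_visit_count: int, prior: float,
              child_value: float, pb_c_base: float = 19652, pb_c_init: float = 1.25) -> float:
    pb_c = math.log((parent_visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(parent_visit_count) / (child_visit_count + 1)
    return pb_c * prior + child_value  # exploration term + exploitation term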
23 changes: 12 additions & 11 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -13,7 +13,6 @@
from lzero.mcts.ctree.ctree_efficientzero import ez_tree as ez_ctree
from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree


# ==============================================================
# EfficientZero
# ==============================================================
@@ -93,7 +92,7 @@ def search(
# prepare some constants
batch_size = roots.num
pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor

# the data storage of latent states: storing the latent state of all the nodes in one search.
latent_state_batch_in_search_path = [latent_state_roots]
# the data storage of value prefix hidden states in LSTM
@@ -118,7 +117,6 @@
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The index of the value prefix hidden state of the leaf node is obtained in the same manner.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -143,7 +141,6 @@
).unsqueeze(0)
# .long() is only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
@@ -156,7 +153,10 @@
)
if not model.training:
# if not in training, obtain the scalars of the value/value_prefix
[network_output.latent_state, network_output.policy_logits, network_output.value, network_output.value_prefix] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.value_prefix
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
@@ -187,7 +187,7 @@
reward_hidden_state_c_batch.append(reward_latent_state_batch[0])
reward_hidden_state_h_batch.append(reward_latent_state_batch[1])

# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

Expand Down Expand Up @@ -260,7 +260,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m

def search(
self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
List[Any]]
List[Any]]
) -> None:
"""
Overview:
@@ -296,7 +296,6 @@ def search(
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The index of the value prefix hidden state of the leaf node is obtained in the same manner.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -324,8 +323,10 @@

if not model.training:
# if not in training, obtain the scalars of the value/reward
[network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.reward] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.reward
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
@@ -340,7 +341,7 @@
value_batch = network_output.value.reshape(-1).tolist()
policy_logits_batch = network_output.policy_logits.tolist()

# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

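Every `if not model.training:` branch reformatted above does the same thing: at inference time it pulls the relevant network outputs off the autograd graph and onto the CPU as numpy arrays before handing them to the tree code. A hedged sketch of what such a `to_detach_cpu_numpy` helper amounts to follows; the real helper lives in lzero's torch utilities and may differ in detail.

from typing import List, Union

import numpy as np
import torch

# Sketch only: detach each tensor from the graph, move it to CPU and convert to
# numpy, preserving the list structure; the actual helper in lzero may differ.
def to_detach_cpu_numpy(data: Union[torch.Tensor, List[torch.Tensor]]) -> Union[np.ndarray, List[np.ndarray]]:
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    return [t.detach().cpu().numpy() for t in data]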
9 changes: 5 additions & 4 deletions lzero/mcts/tree_search/mcts_ctree_sampled.py
@@ -125,7 +124,6 @@ def search(
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The index of the value prefix hidden state of the leaf node is obtained in the same manner.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -153,7 +152,6 @@
else:
# discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(device).long()

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
@@ -166,7 +164,10 @@
)
if not model.training:
# if not in training, obtain the scalars of the value/value_prefix
[network_output.latent_state, network_output.policy_logits, network_output.value, network_output.value_prefix] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.value_prefix
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
@@ -196,7 +197,7 @@
reward_hidden_state_c_pool.append(reward_latent_state_batch[0])
reward_hidden_state_h_pool.append(reward_latent_state_batch[1])

# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

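The `.long()` cast in the discrete branch above exists because discrete actions are integer indices (later one-hot encoded or embedded by the dynamics network), whereas sampled continuous actions stay as float vectors. A small hedged sketch of that branching follows; the function and flag names are illustrative, not lzero's API.

import numpy as np
import torch

# Illustrative sketch of the discrete/continuous branching shown above;
# prepare_last_actions and continuous_action_space are assumed names.
def prepare_last_actions(last_actions, continuous_action_space: bool, device: str) -> torch.Tensor:
    actions = torch.from_numpy(np.asarray(last_actions)).to(device)
    if continuous_action_space:
        # continuous actions: real-valued vectors fed directly to the dynamics net
        return actions.float()
    # discrete actions: integer indices, hence the .long() cast
    return actions.long()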
31 changes: 16 additions & 15 deletions lzero/mcts/tree_search/mcts_ptree.py
@@ -122,7 +122,6 @@ def search(
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The index of the value prefix hidden state of the leaf node is obtained in the same manner.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -140,10 +139,10 @@
hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy])

latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)
).to(self._cfg.device).unsqueeze(0)
hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)
).to(self._cfg.device).unsqueeze(0)
hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device
).unsqueeze(0)
hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device
).unsqueeze(0)
# .long() is only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()
"""
@@ -159,8 +158,10 @@

if not model.training:
# if not in training, obtain the scalars of the value/reward
[network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.value_prefix] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.value_prefix
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
@@ -190,7 +191,7 @@
reward_hidden_state_c_batch.append(reward_latent_state_batch[0])
reward_hidden_state_h_batch.append(reward_latent_state_batch[1])

# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

@@ -297,10 +298,9 @@ def search(
# prepare a result wrapper to transport results between python and c++ parts
results = tree_muzero.SearchResults(num=batch_size)

# latent_state_index_in_search_path: The first index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, that is, the search depth.
# latent_state_index_in_batch: The second index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# latent_state_index_in_search_path: The first index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, that is, the search depth.
# latent_state_index_in_batch: The second index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -315,7 +315,6 @@
latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
# only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
@@ -327,8 +326,10 @@

if not model.training:
# if not in training, obtain the scalars of the value/reward
[network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.reward] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.reward
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
@@ -343,7 +344,7 @@
reward_batch = network_output.reward.reshape(-1).tolist()
policy_logits_batch = network_output.policy_logits.tolist()

# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

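The reflowed lines above gather the per-leaf LSTM cell/hidden states used by EfficientZero's value-prefix network and prepend a layer dimension with `unsqueeze(0)`, matching the (num_layers, batch, hidden_size) layout that `torch.nn.LSTM` expects for its states. A hedged sketch of that preparation step follows; the batch size and hidden size here are illustrative assumptions, not values taken from the diff.

import numpy as np
import torch

# Sketch of batching the value-prefix LSTM states before one expansion step;
# batch_size and lstm_hidden_size are illustrative assumptions.
batch_size, lstm_hidden_size = 8, 512
hidden_states_c_reward = [np.zeros(lstm_hidden_size, dtype=np.float32) for _ in range(batch_size)]
hidden_states_h_reward = [np.zeros(lstm_hidden_size, dtype=np.float32) for _ in range(batch_size)]

device = 'cpu'
c = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(device).unsqueeze(0)  # (1, batch, hidden)
h = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(device).unsqueeze(0)  # (1, batch, hidden)
reward_hidden_state = (c, h)  # the (cell, hidden) pair carried through the search, mirroring index 0/1 in the diff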